change / sglang · Commit 96a2093e

[Fix] Compatibility of window attention and cuda graph (#1090)

Unverified commit 96a2093e, authored Aug 14, 2024 by Ying Sheng, committed via GitHub on Aug 14, 2024.
Parent: a34dd86a

Showing 7 changed files with 70 additions and 39 deletions (+70, -39):
python/sglang/srt/layers/radix_attention.py             +11   -5
python/sglang/srt/model_executor/cuda_graph_runner.py   +43  -12
python/sglang/srt/model_executor/forward_batch_info.py   +3   -7
python/sglang/srt/model_executor/model_runner.py         +9  -13
python/sglang/srt/server_args.py                         +3   -1
python/sglang/test/long_prompt.txt (moved)               +0   -0
python/sglang/test/runners.py                            +1   -1
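Taken together, the diffs below make CUDA graph decoding work for sliding-window models (e.g. gemma-2): the runner now discovers the model's window size up front, allocates two sets of FlashInfer workspace/index buffers and decode wrappers (one for windowed layers, one for full-attention layers), and server_args.py stops force-disabling CUDA graphs for such models. A minimal sketch of the buffer-duplication idea, with hypothetical names, not the committed code:

import torch
from typing import Optional

device = "cuda" if torch.cuda.is_available() else "cpu"  # fallback so the sketch runs anywhere

class DualBufferSketch:
    """Hypothetical: duplicate index buffers when the model has a sliding window."""

    def __init__(self, max_bs: int, sliding_window_size: Optional[int]):
        kv_indptr = torch.zeros((max_bs + 1,), dtype=torch.int32, device=device)
        if sliding_window_size is None:
            # One attention kind -> one buffer set, one wrapper.
            self.kv_indptr = kv_indptr
        else:
            # Two attention kinds -> one buffer set per kind, mirroring the
            # lists built in cuda_graph_runner.py below.
            self.kv_indptr = [kv_indptr, kv_indptr.clone()]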
python/sglang/srt/layers/radix_attention.py (view file @ 96a2093e)

@@ -34,6 +34,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
         reuse: bool = False,
+        sliding_window_size: int = -1,
         logit_cap: int = -1,
         v_head_dim: int = -1,

@@ -47,6 +48,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
         self.reuse = reuse
+        self.sliding_window_size = sliding_window_size
         if (

@@ -127,8 +129,9 @@ class RadixAttention(nn.Module):
         if isinstance(prefill_wrapper_paged, list):
             prefill_wrapper_paged = prefill_wrapper_paged[1]
-        if not input_metadata.flashinfer_use_ragged:
-            self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.flashinfer_use_ragged or self.reuse:
+            if not self.reuse:
+                self.store_kv_cache(k, v, input_metadata)
         o = prefill_wrapper_paged.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),

@@ -179,7 +182,8 @@ class RadixAttention(nn.Module):
         if isinstance(decode_wrapper, list):
             decode_wrapper = decode_wrapper[1]
-        self.store_kv_cache(k, v, input_metadata)
+        if not self.reuse:
+            self.store_kv_cache(k, v, input_metadata)
         o = decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),

@@ -191,8 +195,10 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+        if k is not None:
+            assert v is not None
+            k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+            v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
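A note on the unwrap pattern above: when the runner passes a list of wrappers (the sliding-window case), the layer picks one element before calling forward. These hunks always take index 1; a hedged sketch of the selection rule this implies, with a made-up helper name:

def unwrap(wrapper, use_sliding_window: bool = False):
    # Hypothetical helper mirroring the extend/decode paths above. Index 1
    # appears to be the full-attention wrapper; a layer that applies a window
    # would presumably select index 0 instead.
    if isinstance(wrapper, list):
        return wrapper[0] if use_sliding_window else wrapper[1]
    return wrapper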
python/sglang/srt/model_executor/cuda_graph_runner.py (view file @ 96a2093e)

@@ -107,9 +107,6 @@ class CudaGraphRunner:
         )

         # FlashInfer inputs
-        self.flashinfer_workspace_buffer = (
-            self.model_runner.flashinfer_workspace_buffers[0]
-        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )

@@ -121,6 +118,23 @@ class CudaGraphRunner:
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
+        if model_runner.sliding_window_size is None:
+            self.flashinfer_workspace_buffer = (
+                self.model_runner.flashinfer_workspace_buffers[0]
+            )
+        else:
+            self.flashinfer_workspace_buffers = [
+                self.model_runner.flashinfer_workspace_buffers[0],
+                self.model_runner.flashinfer_workspace_buffers[2],
+            ]
+            self.flashinfer_kv_indptr = [
+                self.flashinfer_kv_indptr,
+                self.flashinfer_kv_indptr.clone(),
+            ]
+            self.flashinfer_kv_indices = [
+                self.flashinfer_kv_indices,
+                self.flashinfer_kv_indices.clone(),
+            ]

         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

@@ -171,15 +185,32 @@ class CudaGraphRunner:
             use_tensor_cores = True
         else:
             use_tensor_cores = False

-        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer,
-            "NHD",
-            use_cuda_graph=True,
-            use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
-            paged_kv_indices_buffer=self.flashinfer_kv_indices,
-            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
-        )
+        if self.model_runner.sliding_window_size is None:
+            flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self.flashinfer_workspace_buffer,
+                "NHD",
+                use_cuda_graph=True,
+                use_tensor_cores=use_tensor_cores,
+                paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
+                paged_kv_indices_buffer=self.flashinfer_kv_indices,
+                paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+            )
+        else:
+            flashinfer_decode_wrapper = []
+            for i in range(2):
+                flashinfer_decode_wrapper.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_workspace_buffers[i],
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=use_tensor_cores,
+                        paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+                    )
+                )

         update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
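The wrappers above receive their paged_kv_* buffers at construction time because CUDA graph replay re-executes the recorded kernels against the exact memory addresses captured, so every tensor must exist at a fixed address before capture and be updated in place afterwards. A minimal capture/replay sketch in plain PyTorch, independent of FlashInfer and only meaningful on a CUDA machine:

import torch

if torch.cuda.is_available():
    # Static tensors: their addresses get baked into the captured graph.
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")

    # Warm up eagerly once so lazy CUDA initialization happens before capture.
    static_out.copy_(static_in * 2)
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2)  # recorded, not executed eagerly

    static_in.fill_(3.0)  # mutate the input in place ...
    g.replay()            # ... and replay the recorded kernels
    assert torch.all(static_out == 6.0)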
python/sglang/srt/model_executor/forward_batch_info.py (view file @ 96a2093e)

@@ -154,7 +154,6 @@ class InputMetadata:
         model_runner: "ModelRunner",
         batch: ScheduleBatch,
         forward_mode: ForwardMode,
-        sliding_window_size: Optional[int] = None,
     ):
         ret = cls(
             forward_mode=forward_mode,

@@ -198,7 +197,7 @@ class InputMetadata:
         ):
             flashinfer_use_ragged = True
         ret.init_flashinfer_handlers(
-            model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size
+            model_runner, prefix_lens, flashinfer_use_ragged
         )

         return ret

@@ -221,7 +220,6 @@ class InputMetadata:
         model_runner,
         prefix_lens,
         flashinfer_use_ragged,
-        sliding_window_size=None,
     ):
         update_flashinfer_indices(
             self.forward_mode,

@@ -230,7 +228,6 @@ class InputMetadata:
             self.seq_lens,
             prefix_lens,
             flashinfer_use_ragged=flashinfer_use_ragged,
-            sliding_window_size=sliding_window_size,
         )

         (

@@ -254,7 +251,6 @@ def update_flashinfer_indices(
     prefix_lens,
     flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
-    sliding_window_size=None,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size

@@ -262,7 +258,7 @@ def update_flashinfer_indices(
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)

-    if sliding_window_size is None:
+    if model_runner.sliding_window_size is None:
         if flashinfer_use_ragged:
             paged_kernel_lens = prefix_lens
         else:

@@ -335,7 +331,7 @@ def update_flashinfer_indices(
             if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
                 paged_kernel_lens = torch.minimum(
-                    paged_kernel_lens, torch.tensor(sliding_window_size)
+                    paged_kernel_lens, torch.tensor(model_runner.sliding_window_size)
                 )
                 kv_start_idx = seq_lens - paged_kernel_lens
             else:
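The last hunk is the core of windowed decoding: for the sliding-window wrapper (wrapper_id == 0) the paged KV length is clamped to the window, and kv_start_idx marks where that window starts inside each cached sequence. The same arithmetic in isolation, with invented example values:

import torch

seq_lens = torch.tensor([5, 100, 4096])  # cached tokens per request (made up)
sliding_window_size = 1024               # example window size

paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
kv_start_idx = seq_lens - paged_kernel_lens

print(paged_kernel_lens)  # tensor([   5,  100, 1024])
print(kv_start_idx)       # tensor([   0,    0, 3072])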
python/sglang/srt/model_executor/model_runner.py (view file @ 96a2093e)

@@ -187,6 +187,11 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.sliding_window_size = (
+            self.model.get_window_size()
+            if hasattr(self.model, "get_window_size")
+            else None
+        )
         self.is_generation = is_generation_model(
             self.model_config.hf_config.architectures
         )

@@ -295,12 +300,6 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
-        self.sliding_window_size = (
-            self.model.get_window_size()
-            if hasattr(self.model, "get_window_size")
-            else None
-        )
-
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None

@@ -339,7 +338,7 @@ class ModelRunner:
                 use_tensor_cores=use_tensor_cores,
             )
         else:
-            workspace_buffers = torch.empty(
+            self.flashinfer_workspace_buffers = torch.empty(
                 4,
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,

@@ -351,17 +350,17 @@ class ModelRunner:
             for i in range(2):
                 self.flashinfer_prefill_wrapper_ragged.append(
                     BatchPrefillWithRaggedKVCacheWrapper(
-                        workspace_buffers[2 * i + 0], "NHD"
+                        self.flashinfer_workspace_buffers[2 * i + 0], "NHD"
                     )
                 )
                 self.flashinfer_prefill_wrapper_paged.append(
                     BatchPrefillWithPagedKVCacheWrapper(
-                        workspace_buffers[2 * i + 1], "NHD"
+                        self.flashinfer_workspace_buffers[2 * i + 1], "NHD"
                     )
                 )
                 self.flashinfer_decode_wrapper.append(
                     BatchDecodeWithPagedKVCacheWrapper(
-                        workspace_buffers[2 * i + 0],
+                        self.flashinfer_workspace_buffers[2 * i + 0],
                         "NHD",
                         use_tensor_cores=use_tensor_cores,
                     )

@@ -404,7 +403,6 @@ class ModelRunner:
             self,
             batch,
             ForwardMode.DECODE,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(

@@ -417,7 +415,6 @@ class ModelRunner:
             self,
             batch,
             forward_mode=ForwardMode.EXTEND,
-            sliding_window_size=self.sliding_window_size,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata

@@ -429,7 +426,6 @@ class ModelRunner:
             self,
             batch,
             forward_mode=ForwardMode.EXTEND,
-            sliding_window_size=self.sliding_window_size,
        )
        return self.model.forward(
            batch.input_ids,
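The window size is now probed once at model-load time rather than inside init_flashinfer, so both the FlashInfer setup and the CUDA graph runner see the same value. The probe is plain duck typing; a self-contained sketch with hypothetical model classes:

class FullAttentionModel:
    """Hypothetical model with no sliding window."""

class WindowedModel:
    """Hypothetical model exposing a window, as gemma-2 does in sglang."""

    def get_window_size(self) -> int:
        return 4096

for model in (FullAttentionModel(), WindowedModel()):
    sliding_window_size = (
        model.get_window_size() if hasattr(model, "get_window_size") else None
    )
    print(type(model).__name__, sliding_window_size)  # None, then 4096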
python/sglang/srt/server_args.py (view file @ 96a2093e)

@@ -453,10 +453,12 @@ class ServerArgs:
             logger.info(
                 f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
             )
             # FIXME: compatibility with radix attention
             self.disable_radix_cache = True
             # FIXME: compatibility with jump forward
             self.disable_regex_jump_forward = True
             self.disable_flashinfer = False
-            self.disable_cuda_graph = True
+            # FIXME: compatibility with chunked prefill
+            self.chunked_prefill_size = None
python/sglang/test/long_prompt → python/sglang/test/long_prompt.txt (view file @ 96a2093e)

File moved without content changes.
python/sglang/test/runners.py (view file @ 96a2093e)

@@ -36,7 +36,7 @@ DEFAULT_PROMPTS = [
 ]

 dirpath = os.path.dirname(__file__)
-with open(os.path.join(dirpath, "long_prompt"), "r") as f:
+with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
     long_prompt = f.read()
 DEFAULT_PROMPTS.append(long_prompt)