change / sglang / Commits / 93d4e354

Commit 93d4e354 (unverified)
[Fix] Window attention compatible with RadixAttention and chunked prefill (#1112)
Authored by Ying Sheng on Aug 15, 2024; committed by GitHub on Aug 15, 2024
Parent: 9195d136
Changes: 5 changed files with 37 additions and 56 deletions (+37 −56)
    python/sglang/srt/layers/radix_attention.py               +19  −19
    python/sglang/srt/model_executor/forward_batch_info.py    +15  −22
    python/sglang/srt/model_executor/model_runner.py            +1   −5
    python/sglang/srt/models/gemma2.py                          +1   −1
    python/sglang/srt/server_args.py                            +1   −9
python/sglang/srt/layers/radix_attention.py

@@ -15,6 +15,8 @@ limitations under the License.
 """Radix attention."""

+from typing import Optional
+
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -34,8 +36,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        reuse: bool = False,
-        sliding_window_size: int = -1,
+        sliding_window_size: Optional[int] = None,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):
@@ -48,8 +49,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
-        self.reuse = reuse
-        self.sliding_window_size = sliding_window_size
+        self.sliding_window_size = sliding_window_size if sliding_window_size else -1

         if (
             not global_server_args_dict.get("disable_flashinfer", False)
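A minimal standalone sketch (not part of the commit) of the constructor behavior introduced above: a None (or 0) window collapses to the internal -1 sentinel that the FlashInfer code paths check, while a positive window size is kept as-is. The helper name is hypothetical.

from typing import Optional

def normalize_window_size(sliding_window_size: Optional[int]) -> int:
    # Mirrors `self.sliding_window_size = sliding_window_size if sliding_window_size else -1`:
    # None and 0 both become -1 ("no window"); a positive size is kept.
    return sliding_window_size if sliding_window_size else -1

assert normalize_window_size(None) == -1    # full-attention layer
assert normalize_window_size(4096) == 4096  # sliding-window layer
assert normalize_window_size(0) == -1       # falsy edge case also maps to -1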
@@ -118,16 +118,16 @@ class RadixAttention(nn.Module):
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1 or self.reuse:
+        if self.sliding_window_size != -1:
             prefill_wrapper_paged = prefill_wrapper_paged[0]
         else:
             if isinstance(prefill_wrapper_paged, list):
                 prefill_wrapper_paged = prefill_wrapper_paged[1]

-        if not input_metadata.flashinfer_use_ragged or self.reuse:
-            if not self.reuse:
+        if not input_metadata.flashinfer_use_ragged:
             if k is not None:
                 assert v is not None
                 self.store_kv_cache(k, v, input_metadata)

             o = prefill_wrapper_paged.forward(
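The wrapper-selection convention above can be illustrated with a standalone sketch; the class and function names below are hypothetical stand-ins, not sglang APIs. Index 0 serves sliding-window layers, index 1 serves full-attention layers, and models without any window keep a single wrapper instead of a list.

from typing import List, Union

# Hypothetical stand-in for a FlashInfer paged prefill/decode wrapper.
class FakeWrapper:
    def __init__(self, name: str):
        self.name = name

def select_wrapper(wrapper: Union[FakeWrapper, List[FakeWrapper]],
                   sliding_window_size: int) -> FakeWrapper:
    # Mirrors the branch in extend_forward_flashinfer / decode_forward_flashinfer:
    # window layers take wrapper[0]; full-attention layers take wrapper[1]
    # when a pair was created, otherwise the single wrapper is used directly.
    if sliding_window_size != -1:
        return wrapper[0]
    if isinstance(wrapper, list):
        return wrapper[1]
    return wrapper

pair = [FakeWrapper("window"), FakeWrapper("full")]
assert select_wrapper(pair, 4096).name == "window"
assert select_wrapper(pair, -1).name == "full"
assert select_wrapper(FakeWrapper("single"), -1).name == "single"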
@@ -139,21 +139,20 @@ class RadixAttention(nn.Module):
                 logits_soft_cap=self.logit_cap,
             )
         else:
-            o1, s1 = prefill_wrapper_ragged.forward_return_lse(
-                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                causal=True,
-                sm_scale=self.scaling,
-                window_left=self.sliding_window_size,
-                logits_soft_cap=self.logit_cap,
-            )
+            o1, s1 = (
+                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                    causal=True,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )
+            )

             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                # TODO window attention + radix attention will come up in next PR
-                assert self.sliding_window_size == -1
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                     input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
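The o1/s1 and o2/s2 pair above is combined by merge_state, imported from flashinfer.cascade at the top of this file, which merges attention computed over the new tokens with attention computed over the cached prefix using each part's log-sum-exp. Below is a plain-PyTorch reference of that merge written only to illustrate the math; the shapes and log base are assumptions for illustration, not taken from this diff.

import torch

def merge_attention_states(o1: torch.Tensor, s1: torch.Tensor,
                           o2: torch.Tensor, s2: torch.Tensor):
    # o1, o2: partial attention outputs, shape [tokens, heads, head_dim]
    # s1, s2: per-(token, head) log-sum-exp of the scores, shape [tokens, heads]
    m = torch.maximum(s1, s2)        # stabilize the exponentials
    w1 = torch.exp(s1 - m)
    w2 = torch.exp(s2 - m)
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
    s = m + torch.log(w1 + w2)       # merged log-sum-exp
    return o, s

# Sanity check: merging a state with itself leaves the output unchanged.
o = torch.randn(4, 8, 64)
s = torch.randn(4, 8)
o_merged, _ = merge_attention_states(o, s, o, s)
assert torch.allclose(o_merged, o, atol=1e-6)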
@@ -179,7 +178,8 @@ class RadixAttention(nn.Module):
+        if isinstance(decode_wrapper, list):
+            decode_wrapper = decode_wrapper[1]

-        if not self.reuse:
         if k is not None:
             assert v is not None
             self.store_kv_cache(k, v, input_metadata)

         o = decode_wrapper.forward(
python/sglang/srt/model_executor/forward_batch_info.py

@@ -194,6 +194,7 @@ class InputMetadata:
         if (
             forward_mode != ForwardMode.DECODE
             and int(torch.sum(ret.seq_lens)) > 4096
+            and model_runner.sliding_window_size is None
         ):
             flashinfer_use_ragged = True
         ret.init_flashinfer_handlers(
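In other words, the ragged (non-paged) prefill kernel is now only considered for models without a sliding window. A sketch of the gating with hypothetical values:

# Hypothetical values for a long prefill batch on a sliding-window model (e.g. gemma-2).
forward_mode_is_decode = False
total_tokens = 8192            # int(torch.sum(ret.seq_lens)) in the real code
sliding_window_size = 4096     # model_runner.sliding_window_size; None for most models

flashinfer_use_ragged = (
    not forward_mode_is_decode
    and total_tokens > 4096
    and sliding_window_size is None  # the condition added by this commit
)
assert flashinfer_use_ragged is False  # window models always stay on the paged path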
@@ -322,22 +323,25 @@ def update_flashinfer_indices(
                 1,
             )
     else:
         # window attention use paged only
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
         for wrapper_id in range(2):
-            if flashinfer_use_ragged and wrapper_id == 1:
-                # full attention use ragged+paged
-                paged_kernel_lens = prefix_lens
+            if wrapper_id == 0:
+                if forward_mode == ForwardMode.DECODE:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens,
+                        torch.tensor(model_runner.sliding_window_size)
+                        + seq_lens
+                        - prefix_lens,
+                    )
             else:
                 # window attention use paged only
                 paged_kernel_lens = seq_lens

-            if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
-                paged_kernel_lens = torch.minimum(
-                    paged_kernel_lens, torch.tensor(model_runner.sliding_window_size)
-                )
-                kv_start_idx = seq_lens - paged_kernel_lens
-            else:
-                kv_start_idx = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            kv_start_idx = seq_lens - paged_kernel_lens

             kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
             kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
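To make the new index math concrete, here is a small worked example with made-up numbers: one request with 7000 total tokens, of which 5000 are a cached prefix, on a model with a 4096-token window. Wrapper 0 (window layers) only needs the KV inside the window of the newest query tokens, wrapper 1 (full-attention layers) needs everything, and kv_start_idx records how many leading tokens of each sequence the paged kernel should skip.

import torch

sliding_window_size = 4096
seq_lens = torch.tensor([7000])     # total tokens per request
prefix_lens = torch.tensor([5000])  # cached prefix length (radix cache / chunked prefill)

# Wrapper 0, extend: keep the window plus all newly extended tokens.
extend_window = torch.minimum(
    seq_lens, torch.tensor(sliding_window_size) + seq_lens - prefix_lens
)
# Wrapper 0, decode: only the last (window + 1) tokens matter.
decode_window = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
# Wrapper 1, full attention: all tokens.
full = seq_lens

print(extend_window, seq_lens - extend_window)  # tensor([6096]) tensor([904])
print(decode_window, seq_lens - decode_window)  # tensor([4097]) tensor([2903])
print(full, seq_lens - full)                    # tensor([7000]) tensor([0])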
@@ -376,17 +380,6 @@ def update_flashinfer_indices(
             )
             qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-            if flashinfer_use_ragged and wrapper_id == 1:
-                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                )
-
-            # cached part
             model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
             model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
                 qo_indptr,
python/sglang/srt/model_executor/model_runner.py

@@ -334,11 +334,7 @@ class ModelRunner:
                 dtype=torch.uint8,
                 device="cuda",
             )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
+            self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = []
             self.flashinfer_decode_wrapper = []
             for i in range(2):
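The resulting wrapper layout for a sliding-window model can be sketched with placeholder objects (the real code constructs FlashInfer paged prefill and decode wrappers in this loop): no ragged wrapper, plus a two-element list for paged prefill and for decode, where assigning index 0 to window layers and index 1 to full-attention layers is the convention used in radix_attention.py above.

# Placeholder for a FlashInfer paged wrapper; illustrative only.
class PagedWrapperStub:
    def __init__(self, role: str):
        self.role = role

flashinfer_prefill_wrapper_ragged = None  # window models never use the ragged kernel
flashinfer_prefill_wrapper_paged = []
flashinfer_decode_wrapper = []
for i in range(2):
    role = "window" if i == 0 else "full"
    flashinfer_prefill_wrapper_paged.append(PagedWrapperStub(role))
    flashinfer_decode_wrapper.append(PagedWrapperStub(role))

assert [w.role for w in flashinfer_prefill_wrapper_paged] == ["window", "full"]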
python/sglang/srt/models/gemma2.py

@@ -213,7 +213,7 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            sliding_window_size=get_window_size(config) if use_sliding_window else -1,
+            sliding_window_size=get_window_size(config) if use_sliding_window else None,
             logit_cap=self.config.attn_logit_softcapping,
         )
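Gemma-2 interleaves local sliding-window layers with global-attention layers, so the per-layer argument alternates between an integer window and None. A hypothetical sketch of that wiring is below; the even/odd split and the 4096 default are assumptions for illustration, while the real code reads the window from the model config via get_window_size(config).

from typing import Optional

def layer_window_size(layer_idx: int, window: int = 4096) -> Optional[int]:
    # Assumption: even layers are local (sliding window), odd layers are global,
    # matching the alternating pattern Gemma-2 uses.
    use_sliding_window = layer_idx % 2 == 0
    return window if use_sliding_window else None

assert layer_window_size(0) == 4096  # window layer: RadixAttention keeps 4096
assert layer_window_size(1) is None  # global layer: RadixAttention stores -1 internally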
python/sglang/srt/server_args.py

@@ -450,16 +450,8 @@ class ServerArgs:
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"

         if "gemma-2" in self.model_path.lower():
-            logger.info(
-                f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
-            )
-            # FIXME: compatibility with radix attention
-            self.disable_radix_cache = True
-            # FIXME: compatibility with jump forward
-            self.disable_regex_jump_forward = True
+            logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False
-            # FIXME: compatibility with chunked prefill
-            self.chunked_prefill_size = None


 @dataclasses.dataclass
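The net effect on server configuration: detecting a gemma-2 model path now only forces FlashInfer on, instead of also disabling the radix cache, regex jump-forward, and chunked prefill. A standalone paraphrase of the new check, with a hypothetical free function in place of the ServerArgs method:

import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)

def adjust_for_gemma2(args, model_path: str):
    # Post-commit behavior: only flashinfer is forced on; radix cache,
    # regex jump-forward, and chunked prefill keep their defaults.
    if "gemma-2" in model_path.lower():
        logger.info("When using sliding window in gemma-2, turn on flashinfer.")
        args.disable_flashinfer = False
    return args

args = adjust_for_gemma2(SimpleNamespace(disable_flashinfer=True), "google/gemma-2-9b-it")
assert args.disable_flashinfer is False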