sglang · Commit 714f3e63 (Unverified)

feat: support flashinfer mla with prefix cache (#3643)

Authored Feb 18, 2025 by Yineng Zhang; committed by GitHub on Feb 18, 2025
Parent: c38f3aed

Showing 4 changed files with 108 additions and 32 deletions (+108, -32)
python/sglang/srt/layers/attention/flashinfer_backend.py   +101  -30
python/sglang/srt/managers/schedule_batch.py                  +1   -0
python/sglang/srt/model_executor/model_runner.py              +1   -0
python/sglang/srt/models/deepseek_v2.py                       +5   -2
python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -54,7 +54,9 @@ class DecodeMetadata:
 @dataclass
 class PrefillMetadata:
-    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    prefill_wrappers: List[
+        Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]
     use_ragged: bool
     extend_no_prefix: bool
@@ -160,16 +162,36 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             if not skip_prefill:
-                self.prefill_wrappers_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        backend="fa2",
+                if (
+                    self.enable_flashinfer_mla
+                    and not global_server_args_dict["disable_radix_cache"]
+                ):
+                    # use mla paged prefill
+                    self.prefill_wrappers_paged.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            backend="fa2",
+                        )
+                    )
+                    self.prefill_wrappers_verify.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    self.prefill_wrappers_paged.append(
+                        BatchPrefillWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            backend="fa2",
+                        )
+                    )
+                    self.prefill_wrappers_verify.append(
+                        BatchPrefillWithPagedKVCacheWrapper(
+                            self.workspace_buffer, "NHD"
+                        )
                     )
-                )
-                self.prefill_wrappers_verify.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                )
             if self.enable_flashinfer_mla:
                 self.decode_wrappers.append(
                     BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
@@ -237,7 +259,10 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens

-            if self.is_multimodal:
+            if self.is_multimodal or (
+                self.enable_flashinfer_mla
+                and not global_server_args_dict["disable_radix_cache"]
+            ):
                 use_ragged = False
                 extend_no_prefix = False
             else:
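In other words, whenever the FlashInfer MLA path may reuse a radix-cache prefix, the ragged (non-paged) prefill is switched off, since the ragged kernel only sees the newly extended tokens and cannot attend to prefix tokens already stored in the paged KV pool. A condensed restatement of the condition added above, as a hypothetical helper (the name is illustrative and not part of the diff):

def disable_ragged_prefill(is_multimodal: bool,
                           enable_flashinfer_mla: bool,
                           disable_radix_cache: bool) -> bool:
    # Ragged prefill cannot read prefix tokens from the paged KV pool, so it is
    # skipped whenever the MLA path may reuse a radix-cache prefix.
    return is_multimodal or (enable_flashinfer_mla and not disable_radix_cache)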
@@ -419,23 +444,43 @@ class FlashInferAttnBackend(AttentionBackend):
             logits_soft_cap = layer.logit_cap

-            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
-            )
-
-            o = o1
-
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+            if global_server_args_dict["disable_radix_cache"]:
+                # use mla ragged prefill
+                o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
+                )
+
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            else:
+                # use mla paged prefill
+                prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                    self._get_wrapper_idx(layer)
+                ]
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v
+                        )
+
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+                o = prefill_wrapper_paged.run(
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                    k_buf[:, :, : layer.v_head_dim],
+                    k_buf[:, :, layer.v_head_dim :],
+                )

             return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
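In the paged MLA branch above, the query and the pooled key buffer each carry the compressed latent ("nope") component and the rotary ("rope") component concatenated along the last dimension, and they are split at layer.v_head_dim before being handed to the wrapper. A minimal sketch of that layout, assuming DeepSeek-V2-style sizes (512 + 64 = 576, matching the constants in the plan() call later in this file); all tensor names and sizes here are illustrative:

import torch

# Illustrative sizes only: 512 is the compressed-KV (latent) width and 64 the
# rotary width, so head_dim = 576 and v_head_dim = 512 in this sketch.
num_tokens, num_heads = 8, 16
v_head_dim, rope_dim = 512, 64
head_dim = v_head_dim + rope_dim             # 576

qall = torch.randn(num_tokens, num_heads, head_dim)
k_buf = torch.randn(1024, 1, head_dim)       # paged pool: one latent "head" per slot

q_nope, q_rope = qall[:, :, :v_head_dim], qall[:, :, v_head_dim:]
ckv, k_rope = k_buf[:, :, :v_head_dim], k_buf[:, :, v_head_dim:]
assert q_nope.shape[-1] == v_head_dim and q_rope.shape[-1] == rope_dim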
@@ -800,7 +845,9 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
+        prefill_wrappers: List[
+            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -814,7 +861,9 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
-        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
+        prefill_wrappers: List[
+            Union[BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
@@ -923,7 +972,9 @@ class FlashInferIndicesUpdaterPrefill:
     def call_begin_forward(
         self,
         wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
-        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
+        wrapper_paged: Union[
+            BatchPrefillWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -1004,6 +1055,26 @@ class FlashInferIndicesUpdaterPrefill:
                 custom_mask=custom_mask,
                 non_blocking=True,
             )
+        elif (
+            global_config.enable_flashinfer_mla
+            and not global_server_args_dict["disable_radix_cache"]
+        ):
+            # mla paged prefill
+            kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]
+            wrapper_paged.plan(
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_len_arr,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                True,
+                1 / math.sqrt(192),
+                self.data_type,
+                self.data_type,
+            )


 class FlashInferMultiStepDraftBackend:
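The constants in the plan() call above line up with DeepSeek-V2's MLA geometry: 512 is the compressed-KV head dimension, 64 the rotary head dimension, the page size is 1 (one token per page), causal masking is on, and the softmax scale is 1/sqrt(192), presumably because the pre-absorption query/key head dimension is 128 + 64 = 192. Below is a minimal, self-contained sketch of planning and running such a wrapper on toy data; it assumes the BatchMLAPagedAttentionWrapper API exactly as it is used in this diff, and the import path, sizes, and cache layout are assumptions for illustration, not part of the commit:

import math
import torch
from flashinfer.mla import BatchMLAPagedAttentionWrapper  # assumed import path

device, dtype = "cuda", torch.bfloat16
num_heads, num_tokens, kv_len = 16, 8, 32    # toy sizes, a single request

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = BatchMLAPagedAttentionWrapper(workspace, backend="fa2")

# With page_size = 1, every cached token occupies its own page.
qo_indptr = torch.tensor([0, num_tokens], dtype=torch.int32, device=device)
kv_indptr = torch.tensor([0, kv_len], dtype=torch.int32, device=device)
kv_indices = torch.arange(kv_len, dtype=torch.int32, device=device)
kv_len_arr = kv_indptr[1:] - kv_indptr[:-1]

wrapper.plan(
    qo_indptr, kv_indptr, kv_indices, kv_len_arr,
    num_heads, 512, 64, 1, True, 1 / math.sqrt(192), dtype, dtype,
)

q_nope = torch.randn(num_tokens, num_heads, 512, dtype=dtype, device=device)
q_rope = torch.randn(num_tokens, num_heads, 64, dtype=dtype, device=device)
ckv = torch.randn(kv_len, 1, 512, dtype=dtype, device=device)   # compressed KV pages
kpe = torch.randn(kv_len, 1, 64, dtype=dtype, device=device)    # rotary key pages
o = wrapper.run(q_nope, q_rope, ckv, kpe)    # [num_tokens, num_heads, 512]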
python/sglang/srt/managers/schedule_batch.py
@@ -66,6 +66,7 @@ global_server_args_dict = {
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "device": ServerArgs.device,
     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
+    "disable_radix_cache": ServerArgs.disable_radix_cache,
 }

 logger = logging.getLogger(__name__)
python/sglang/srt/model_executor/model_runner.py
@@ -177,6 +177,7 @@ class ModelRunner:
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "disable_radix_cache": server_args.disable_radix_cache,
             }
         )
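These two one-line hunks simply plumb ServerArgs.disable_radix_cache into global_server_args_dict, next to enable_flashinfer_mla, so the attention backend and the DeepSeek model code can branch on it at runtime (on the launch command line these fields correspond to the --enable-flashinfer-mla and --disable-radix-cache flags). A hypothetical consumer of the new key, mirroring the checks added elsewhere in this commit:

from sglang.srt.managers.schedule_batch import global_server_args_dict

# The new MLA paged-prefill path is taken only when FlashInfer MLA is enabled
# and the radix (prefix) cache has not been disabled.
use_mla_paged_prefill = (
    global_server_args_dict["enable_flashinfer_mla"]
    and not global_server_args_dict["disable_radix_cache"]
)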
python/sglang/srt/models/deepseek_v2.py
@@ -511,8 +511,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if global_server_args_dict["enable_flashinfer_mla"]:
-            if forward_batch.forward_mode.is_extend():
-                return self.forward_normal(positions, hidden_states, forward_batch)
+            if global_server_args_dict["disable_radix_cache"]:
+                if forward_batch.forward_mode.is_extend():
+                    return self.forward_normal(positions, hidden_states, forward_batch)
+                else:
+                    return self.forward_absorb(positions, hidden_states, forward_batch)
             else:
                 return self.forward_absorb(positions, hidden_states, forward_batch)
         else:
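With this change, enabling FlashInfer MLA no longer forces the un-absorbed forward_normal path for every extend pass: forward_normal (ragged MLA prefill) is used only when the radix cache is disabled, and otherwise both extend and decode go through forward_absorb so that cached prefix tokens can be reused. A condensed, hypothetical restatement of the branch added above (the function name is illustrative):

def pick_mla_forward(disable_radix_cache: bool, is_extend: bool) -> str:
    # Mirrors the `enable_flashinfer_mla` branch of DeepseekV2AttentionMLA.forward.
    if disable_radix_cache and is_extend:
        return "forward_normal"   # ragged MLA prefill, no prefix reuse
    return "forward_absorb"       # paged MLA kernel, prefix cache reusable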