sglang / Commits / e30c273b

Commit e30c273b (unverified), authored May 09, 2025 by xu-yfei, committed by GitHub May 08, 2025

opt flashinfer mla cat (#5822)

Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>

Parent: 0ab3f437
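This commit lets the DeepSeek MLA path hand the no-rope and rope halves of q and k to the FlashInfer attention backend as separate tensors (new optional q_rope/k_rope arguments), so the backend can write them into the MLA kv cache and feed the kernels directly instead of the model concatenating them with torch.cat on every layer. A minimal sketch of why the caller-side concatenation is redundant; the shapes below are illustrative only and not taken from the diff:

import torch

# Illustrative MLA-style split: a "nope" part plus a smaller rotary ("rope") part.
num_tokens, num_heads, v_head_dim, rope_dim = 4, 16, 512, 64
q_nope = torch.randn(num_tokens, num_heads, v_head_dim)
q_rope = torch.randn(num_tokens, num_heads, rope_dim)

# Old caller pattern: materialize a fused tensor before calling the backend.
q_fused = torch.cat([q_nope, q_rope], dim=-1)  # allocates and copies

# New caller pattern (this commit): keep the halves separate; slicing the fused
# tensor recovers exactly the same data, so nothing is lost by not fusing.
assert torch.equal(q_fused[:, :, :v_head_dim], q_nope)
assert torch.equal(q_fused[:, :, v_head_dim:], q_rope)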
Showing 2 changed files with 60 additions and 14 deletions:

  python/sglang/srt/layers/attention/flashinfer_mla_backend.py  (+59, -13)
  python/sglang/srt/models/deepseek_v2.py                        (+1, -1)
python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -339,22 +339,38 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )

         if self.forward_metadata.use_ragged:
             # ragged prefill
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
...
@@ -365,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-                qall[:, :, : layer.v_head_dim],
-                qall[:, :, layer.v_head_dim :],
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
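In the paged-prefill branch above, when the caller does not pass q_rope the backend falls back to slicing the fused tensor itself. A small stand-alone sketch of that split; the function name and dimensions are illustrative, not from the repository:

import torch
from typing import Optional, Tuple

def split_q(
    q: torch.Tensor,
    q_rope: Optional[torch.Tensor],
    tp_q_head_num: int,
    head_dim: int,
    v_head_dim: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Mirrors the fallback in the paged-prefill branch: if the rope half was not
    # passed separately, recover both halves by viewing/slicing (no copies).
    if q_rope is None:
        qall = q.view(-1, tp_q_head_num, head_dim)
        return qall[:, :, :v_head_dim], qall[:, :, v_head_dim:]
    return (
        q.view(-1, tp_q_head_num, v_head_dim),
        q_rope.view(-1, tp_q_head_num, head_dim - v_head_dim),
    )

# Illustrative dimensions only.
q_fused = torch.randn(4, 16 * 576)
q_nope, q_pe = split_q(q_fused, None, tp_q_head_num=16, head_dim=576, v_head_dim=512)
assert q_nope.shape == (4, 16, 512) and q_pe.shape == (4, 16, 64)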
@@ -382,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
...
@@ -389,23 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )

         # Reshape inputs
-        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

+        o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
         o = decode_wrapper.run(
-            reshaped_q[:, :, : layer.v_head_dim],
-            reshaped_q[:, :, layer.v_head_dim :],
+            q_nope,
+            q_rope,
             k_buffer[:, :, : layer.v_head_dim],
             k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
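The decode path now also preallocates the output with q_nope.new_empty(q_nope.shape) and hands it to the wrapper via out=, so the kernel writes into that buffer instead of allocating a fresh tensor. The same pattern shown with a plain torch op as a minimal illustration (the actual kernel call here is FlashInfer's decode_wrapper.run):

import torch

q_nope = torch.randn(8, 16, 512)
weight = torch.randn(512, 512)

# new_empty keeps the dtype/device of the source tensor but leaves the memory
# uninitialized; the op invoked with out= is expected to overwrite every element.
o = q_nope.new_empty(q_nope.shape)
torch.matmul(q_nope, weight, out=o)  # writes into the preallocated buffer
assert o.shape == q_nope.shape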
python/sglang/srt/models/deepseek_v2.py
@@ -777,7 +777,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
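On the model side the only change is the dispatch condition: the flashinfer backend now takes the same path as fa3 and receives q_pe/k_pe through the q_rope/k_rope keywords rather than pre-concatenated q/k. This is sound because an attention score over a fused [nope | rope] head dimension is just the sum of the two partial dot products, so a kernel can consume the halves separately. A tiny check of that identity; the dimensions 512/64 are illustrative:

import torch

torch.manual_seed(0)
q_nope, q_rope = torch.randn(512), torch.randn(64)
k_nope, k_rope = torch.randn(512), torch.randn(64)

# Score computed on fused vectors vs. computed per half and summed.
fused = torch.dot(torch.cat([q_nope, q_rope]), torch.cat([k_nope, k_rope]))
split = torch.dot(q_nope, k_nope) + torch.dot(q_rope, k_rope)
assert torch.allclose(fused, split)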