sglang / Commits / c49c1d92

Unverified commit c49c1d92, authored Jun 14, 2025 by fzyzcjy, committed via GitHub Jun 13, 2025
Remove 200us slow concat kernel (part 2: srt) (#7020)
Parent: 0f1dfa1e

Showing 2 changed files with 39 additions and 11 deletions (+39 -11)
python/sglang/srt/layers/attention/cutlass_mla_backend.py   +34  -10
python/sglang/srt/models/deepseek_v2.py                      +5   -1
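Per the commit title, the point of the change is to stop concatenating the non-positional (q_nope) and rotary (q_rope) halves of the query before the MLA decode kernel, a `torch.cat` that reportedly cost ~200 µs per step. As a rough standalone illustration (not code from this commit; shapes are invented), the packed form pays for an extra copy that the split form avoids:

```python
import torch

# Invented shapes for illustration only.
bs, heads, v_head_dim, rope_dim = 32, 16, 512, 64
q_nope = torch.randn(bs, heads, v_head_dim)
q_rope = torch.randn(bs, heads, rope_dim)

# Old style: fuse the halves into one tensor before the kernel call.
# torch.cat launches a copy kernel and materializes new memory every step.
q_packed = torch.cat([q_nope, q_rope], dim=-1)

# New style: hand q_nope and q_rope to the kernel as separate arguments,
# so no concat kernel runs. The halves are exactly the packed slices:
assert torch.equal(q_packed[..., :v_head_dim], q_nope)
assert torch.equal(q_packed[..., v_head_dim:], q_rope)
```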
python/sglang/srt/layers/attention/cutlass_mla_backend.py
```diff
@@ -233,25 +233,49 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
         bs = forward_batch.batch_size
-        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-
-        reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        # Reshape inputs
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim)
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
+        q_nope = q_nope.to(self.q_data_type)
+        q_rope = q_rope.to(self.q_data_type)
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

         o = cutlass_mla_decode(
-            q_nope_and_q_pe=reshape_q.to(self.q_data_type),
+            q_nope=q_nope,
+            q_pe=q_rope,
             kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
```
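The `else` branch above (the path where the caller still passes a packed `q`) splits it with views rather than copies; a minimal standalone sketch with invented sizes shows the slices share `q`'s storage:

```python
import torch

# Invented sizes for illustration only.
tp_q_head_num, head_dim, v_head_dim = 16, 576, 512

q = torch.randn(8, tp_q_head_num * head_dim)
reshaped_q = q.view(-1, tp_q_head_num, head_dim)
q_nope = reshaped_q[:, :, :v_head_dim]
q_rope = reshaped_q[:, :, v_head_dim:]

# Both slices are strided views over q's buffer: no copy kernel runs.
assert q_nope.data_ptr() == reshaped_q.data_ptr()
assert not q_nope.is_contiguous()
```

Note that the `.to(self.q_data_type)` casts in the real code can still copy when the dtype differs; the saving the commit claims is the explicit concat kernel.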
python/sglang/srt/models/deepseek_v2.py

```diff
@@ -1013,7 +1013,11 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_core(
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
-        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
+        if (
+            self.attention_backend == "fa3"
+            or self.attention_backend == "flashinfer"
+            or self.attention_backend == "cutlass_mla"
+        ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
```
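The model-side change just widens the gate: `cutlass_mla` now takes the same `attn_mqa` path as `fa3` and `flashinfer`, receiving the rope halves as keyword arguments instead of pre-concatenated inputs. An equivalent membership-test reading of the new condition (a hypothetical rewrite, not what the commit ships):

```python
# Hypothetical rewrite of the new gate in forward_absorb_core; behaviorally
# identical to the chained `or` condition in the diff above.
if self.attention_backend in ("fa3", "flashinfer", "cutlass_mla"):
    attn_output = self.attn_mqa(
        q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
    )
```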