Commit 1ba137e9
Authored Sep 17, 2025 by Shu Wang; committed by GitHub on Sep 17, 2025
Parent: de28f8e7

Enable trtllm mla prefix extend (#10526)

Showing 1 changed file with 40 additions and 5 deletions:
python/sglang/srt/layers/attention/trtllm_mla_backend.py (+40, -5)
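As orientation before the diff: the second hunk routes a prefix-extend batch to FlashInfer's trtllm_ragged_attention_deepseek kernel only when the batch is attending the chunked prefix KV cache and the caller asked for the log-sum-exp; every other prefix-extend batch still falls back to the parent FlashInferMLAAttnBackend.forward_extend. A minimal sketch of that gating rule follows; the standalone helper is hypothetical and only mirrors the condition from the diff, it is not part of the commit.

from typing import Optional

# Hypothetical helper mirroring the new dispatch condition in forward_extend.
# The attribute names (attn_attend_prefix_cache, mha_return_lse) come from the
# diff; the function itself is illustrative only.
def uses_trtllm_ragged_prefix_path(
    attn_attend_prefix_cache: Optional[bool], mha_return_lse: bool
) -> bool:
    # New branch: chunked-prefix MHA via trtllm_ragged_attention_deepseek.
    # Any other prefix-extend batch keeps the super().forward_extend(...) path.
    return attn_attend_prefix_cache is not None and mha_return_lse

assert uses_trtllm_ragged_prefix_path(True, True)
assert not uses_trtllm_ragged_prefix_path(None, True)   # not attending the prefix cache
assert not uses_trtllm_ragged_prefix_path(True, False)  # no LSE requested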
python/sglang/srt/layers/attention/trtllm_mla_backend.py

@@ -553,7 +553,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
-    ):
+    ) -> torch.Tensor:
         if (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
@@ -591,10 +591,45 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 return_lse=forward_batch.mha_return_lse,
             )
         else:
-            # replace with trtllm ragged attention once accuracy is resolved.
-            output = super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
+            if not (
+                forward_batch.attn_attend_prefix_cache is not None
+                and forward_batch.mha_return_lse
+            ):
+                output = super().forward_extend(
+                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+                )
+            else:
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert q_rope is None
+                assert k_rope is None
+                chunk_idx = forward_batch.prefix_chunk_idx
+
+                q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
+                v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
+                output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+                output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                    query=q,
+                    key=k,
+                    value=v,
+                    workspace_buffer=self.workspace_buffer,
+                    seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+                    max_q_len=self.forward_prefill_metadata.max_seq_len,
+                    max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    bmm1_scale=layer.scaling,
+                    bmm2_scale=1.0,
+                    o_sf_scale=-1.0,
+                    batch_size=forward_batch.batch_size,
+                    window_left=-1,
+                    cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                    cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    enable_pdl=False,
+                    is_causal=False,
+                    return_lse=True,
+                    out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+                )
         return output
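The added else branch flattens the ragged Q/K/V into a (tokens, heads, head_dim) layout, casts K and V to the query dtype, and pre-allocates the output buffer that is passed to the kernel via out=. Below is a self-contained sketch of just that reshaping; the head counts and dimensions are made-up example values (in the backend they come from the RadixAttention layer), and the kernel call itself is omitted.

import torch

# Assumed example sizes, not values taken from the diff.
tp_q_head_num = tp_k_head_num = 16
head_dim, v_head_dim = 192, 128
num_q_tokens, num_kv_tokens = 8, 64

q = torch.randn(num_q_tokens, tp_q_head_num * head_dim).to(torch.bfloat16)
k = torch.randn(num_kv_tokens, tp_k_head_num * head_dim)    # float32
v = torch.randn(num_kv_tokens, tp_k_head_num * v_head_dim)  # float32

# Same reshapes and casts as the new branch: ragged (tokens, heads, dim)
# tensors, with K and V converted to the query dtype.
q = q.view(-1, tp_q_head_num, head_dim)
k = k.view(-1, tp_k_head_num, head_dim).to(q.dtype)
v = v.view(-1, tp_k_head_num, v_head_dim).to(q.dtype)

# Output buffer handed to trtllm_ragged_attention_deepseek via out=:
# one row per query token, tp_q_head_num heads, v_head_dim wide.
output_shape = (q.shape[0], tp_q_head_num, v_head_dim)
out = torch.zeros(*output_shape, dtype=q.dtype, device=q.device)
print(out.shape)  # torch.Size([8, 16, 128])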