sglang commit 70c0c1f9 (Unverified)
Authored Sep 11, 2025 by eigen; committed by GitHub on Sep 11, 2025
fix: trtllm-gen attention take zero-init workspace (#10330)
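In short: the decode metadata previously carried its own workspace tensor, allocated with torch.empty (i.e. uninitialized memory), and that buffer was handed to the trtllm-gen MLA decode kernel via metadata.workspace. This commit drops the per-metadata workspace and points the kernel at the backend's shared self.workspace_buffer instead, so the kernel takes the zero-initialized workspace the title refers to.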
Parent: 760b788a
Showing 1 changed file with 2 additions and 7 deletions.

python/sglang/srt/layers/attention/trtllm_mla_backend.py (+2, -7)
@@ -58,7 +58,6 @@ class TRTLLMMLAPrefillMetadata:
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
-    workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
     max_seq_len: Optional[int] = None
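For reference, after this hunk the decode metadata reduces to the two remaining fields. A minimal self-contained sketch of the result; the @dataclass decorator and the imports are assumed from the field-with-default style, since neither is visible in the hunk:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TRTLLMMLADecodeMetadata:
    """Metadata for TRTLLM MLA decode operations."""

    # Per-sequence KV block indices and the max sequence length remain;
    # a workspace tensor is no longer stored per batch of metadata.
    block_kv_indices: Optional[torch.Tensor] = None
    max_seq_len: Optional[int] = None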
@@ -187,9 +186,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.decode_cuda_graph_workspace = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
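The deleted allocation used torch.empty, which returns uninitialized memory. A hedged sketch of the distinction that motivates the fix (the size is illustrative, not from the source):

import torch

size = 1024  # illustrative workspace size in bytes

# torch.empty only reserves memory; its contents are whatever happened
# to be in that allocation before. A kernel that assumes a
# zero-initialized workspace can read garbage from such a buffer.
dirty = torch.empty(size, dtype=torch.int8)

# torch.zeros reserves and clears the memory, so every byte starts at 0.
clean = torch.zeros(size, dtype=torch.int8)

assert (clean == 0).all()  # guaranteed
# (dirty == 0).all() holds only by accident.

With this allocation gone, the CUDA-graph state no longer owns a separate workspace at all; the kernel call site in the last hunk switches to the backend-wide buffer instead.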
@@ -240,7 +236,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val = int(seq_lens.max().item())
             metadata = TRTLLMMLADecodeMetadata(
-                self.decode_cuda_graph_workspace,
                 block_kv_indices,
                 max_seq_len_val,
             )
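Note that TRTLLMMLADecodeMetadata is constructed positionally here, so deleting the leading workspace argument is required to match the dataclass after the first hunk: block_kv_indices now binds to the first field and max_seq_len_val to the second. The non-capture path below gets the same adjustment.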
@@ -339,7 +334,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val = int(max_seq)
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
-                self.workspace_buffer, block_kv_indices, max_seq_len_val
+                block_kv_indices, max_seq_len_val
             )
             forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
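This is the non-CUDA-graph decode path receiving the same treatment: the metadata previously carried self.workspace_buffer along, while the kernel call site (next hunk) now reads the buffer off the backend directly, leaving the metadata a pure description of the batch.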
@@ -513,7 +508,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache,
-            workspace_buffer=metadata.workspace,
+            workspace_buffer=self.workspace_buffer,
             qk_nope_head_dim=self.qk_nope_head_dim,
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
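After this hunk, workspace_buffer always comes from the backend (self.workspace_buffer) rather than from per-batch metadata, closing the path by which the torch.empty allocation above could reach the kernel. A hedged sketch of the pattern the commit converges on; the actual allocation site of self.workspace_buffer is not part of this diff, so treating it as zero-initialized is an assumption taken from the commit title:

import torch


class BackendSketch:
    """Illustrative only: one backend-owned, zero-initialized workspace."""

    def __init__(self, workspace_size: int, device: str = "cpu"):
        # Assumption (from the commit title, not this diff): the shared
        # buffer is zeroed once at construction and reused by every
        # decode call, instead of a fresh torch.empty per metadata.
        self.workspace_buffer = torch.zeros(
            workspace_size, dtype=torch.int8, device=device
        )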