Unverified commit 311de47b, authored Sep 17, 2025 by fzyzcjy, committed by GitHub on Sep 16, 2025.
[2/2] Speed up trtllm_mla attention backend (#10474)
parent 373080ea
Showing 1 changed file with 10 additions and 2 deletions.

python/sglang/srt/layers/attention/trtllm_mla_backend.py (+10, -2)
@@ -22,7 +22,7 @@ from sglang.srt.layers.attention.utils import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available
 
 if is_flashinfer_available():
     import flashinfer
@@ -32,6 +32,11 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInfo
 
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import concat_mla_absorb_q
+
 # Constants
 DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
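This hunk uses the usual pattern for optional native kernels: probe the platform once at import time, import the fused op only where it can exist, and cache the result in a module-level flag for call sites to branch on. A minimal sketch of that pattern, not taken from the repository (the local is_cuda stand-in is illustrative; only concat_mla_absorb_q and the _is_cuda flag come from the diff):

import torch

def is_cuda() -> bool:
    # Stand-in for sglang.srt.utils.is_cuda; here we simply ask PyTorch.
    return torch.cuda.is_available()

_is_cuda = is_cuda()
if _is_cuda:
    # sgl_kernel's fused op is only present in CUDA builds, so the import is
    # skipped elsewhere and this module still imports cleanly.
    from sgl_kernel import concat_mla_absorb_q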
@@ -482,7 +487,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+            if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
+                query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
+            else:
+                query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
         else:
             # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
             query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
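This last hunk is where the speedup comes from: in the non-FP8 path, the 512-wide latent ("nope") part and the 64-wide rotary ("rope") part of each query head were previously glued together with torch.cat on every call. The commit routes exactly that shape to the fused sgl_kernel.concat_mla_absorb_q and keeps torch.cat as the fallback for any other layout or non-CUDA build. A small sketch of the shapes involved (the token and head counts are illustrative; only the 512/64 split and the kernel name come from the diff, and the fused path itself needs a CUDA build of sgl_kernel):

import torch

tokens, heads = 8, 16        # illustrative token count and per-rank head count
kv_lora_rank = 512           # q_nope.shape[-1] checked by the new fast path
qk_rope_head_dim = 64        # q_rope_reshaped.shape[-1] checked by the new fast path

q_nope = torch.randn(tokens, heads, kv_lora_rank)
q_rope = torch.randn(tokens, heads, qk_rope_head_dim)

# Fallback kept by the commit: plain concatenation along the head dimension.
query = torch.cat([q_nope, q_rope], dim=-1)
assert query.shape == (tokens, heads, kv_lora_rank + qk_rope_head_dim)  # (..., 576)

# On CUDA with exactly this 512 + 64 split, the backend now calls
# sgl_kernel.concat_mla_absorb_q(q_nope, q_rope) instead, which is expected to
# produce the same 576-wide query tensor in one fused kernel rather than via
# the allocate-and-copy that torch.cat performs each step.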