sglang · commit 00aec6ad (unverified)

Apply dsv3_fused_a_gemm kernel (#7635)

Authored Jul 02, 2025 by Ke Bao; committed via GitHub Jul 01, 2025.
Parent: 1a08358a

Showing 1 changed file with 18 additions and 2 deletions:

python/sglang/srt/models/deepseek_v2.py (+18, -2)
@@ -96,6 +96,7 @@ from sglang.srt.utils import (
     bind_or_assign,
     cpu_has_amx_support,
     get_bool_env_var,
+    get_device_sm,
     get_int_env_var,
     is_cpu,
     is_cuda,
@@ -112,7 +113,7 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+    from sgl_kernel import awq_dequantize, bmm_fp8, dsv3_fused_a_gemm, merge_state_v2
 elif _is_cpu and _is_cpu_amx_available:
     pass
 else:
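Editor's aside (not part of the diff): dsv3_fused_a_gemm is imported from sgl_kernel only under the _is_cuda branch, so it is available only in CUDA builds of sgl-kernel. A hedged way to probe for it outside that module-level guard, assuming nothing beyond the import shown above:

    # Illustrative availability check; names mirror the guarded import in the hunk above.
    try:
        from sgl_kernel import dsv3_fused_a_gemm
        HAS_DSV3_FUSED_A_GEMM = True
    except ImportError:  # non-CUDA build or an sgl-kernel version without the kernel
        HAS_DSV3_FUSED_A_GEMM = False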
@@ -875,6 +876,15 @@ class DeepseekV2AttentionMLA(nn.Module):
             weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
         )
 
+        self.use_min_latency_fused_a_gemm = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+            and _is_cuda
+            and get_device_sm() >= 90
+        )
+
         self.qkv_proj_with_rope_is_int8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
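Editor's note on the magic numbers (not part of the diff): the shape guard pins the fast path to a DeepSeek-V3-sized fused A projection. Assuming the standard DeepSeek-V3 config (q_lora_rank = 1536, kv_lora_rank = 512, qk_rope_head_dim = 64, hidden_size = 7168), the output dimension of fused_qkv_a_proj_with_mqa works out as follows, matching the 2112 and 7168 checks above:

    # Sketch with assumed DeepSeek-V3 config values (not read from this commit).
    q_lora_rank, kv_lora_rank, qk_rope_head_dim, hidden_size = 1536, 512, 64, 7168
    fused_out_dim = q_lora_rank + kv_lora_rank + qk_rope_head_dim  # 1536 + 512 + 64 = 2112
    assert (fused_out_dim, hidden_size) == (2112, 7168)  # mirrors the weight.shape[0/1] guard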
@@ -1114,7 +1124,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]
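For readers skimming the hunk above, a minimal sketch of the new dispatch, assuming a bf16 [2112, 7168] weight and simplifying the existing projection to a plain GEMM (the real module is a parallel linear layer returning a tuple; fused_a_projection is a hypothetical helper, not code from the commit):

    import torch
    import torch.nn.functional as F

    def fused_a_projection(hidden_states: torch.Tensor,
                           weight: torch.Tensor,
                           use_min_latency_fused_a_gemm: bool) -> torch.Tensor:
        # Tiny decode batches (<= 16 tokens) take the dsv3_fused_a_gemm fast path
        # when the flag computed in __init__ allows it (bf16 weight, CUDA, SM >= 90).
        if hidden_states.shape[0] <= 16 and use_min_latency_fused_a_gemm:
            from sgl_kernel import dsv3_fused_a_gemm  # CUDA-only kernel, as in the diff
            return dsv3_fused_a_gemm(hidden_states, weight.T)
        # Fallback mirrors the pre-existing path, simplified here to a dense GEMM.
        return F.linear(hidden_states, weight)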