sglang · Commits · 93cec433

Support new DeepGEMM (#7172)

Unverified commit 93cec433, authored Jun 14, 2025 by fzyzcjy, committed by GitHub Jun 13, 2025
Parent: ba589b88

Showing 8 changed files with 59 additions and 19 deletions
python/sglang/srt/layers/moe/ep_moe/layer.py (+8 -1)
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py (+2 -0)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (+1 -4)
python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py (+7 -1)
python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py (+18 -8)
python/sglang/srt/layers/quantization/fp8_kernel.py (+20 -3)
python/sglang/srt/layers/quantization/fp8_utils.py (+1 -0)
python/sglang/srt/models/deepseek_v2.py (+2 -2)
python/sglang/srt/layers/moe/ep_moe/layer.py  (view file @ 93cec433)

@@ -1231,6 +1231,7 @@ class DeepEPMoE(EPMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
@@ -1238,7 +1239,13 @@ class DeepEPMoE(EPMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16
...
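Note on the branch above: when deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 is set, the activation scales are handed to DeepGEMM untouched, and only the legacy path still reshapes them with get_col_major_tma_aligned_tensor. As background on what a "UE8M0" scale is, here is a small self-contained sketch of the idea, i.e. an 8-bit, exponent-only (power-of-two) scale; the float32-style bias of 127 and the round-up direction are assumptions about the format, not taken from this commit:

import torch

# Illustration only: encode a float scale as a UE8M0 byte (a power-of-two
# scale stored as an unsigned 8-bit exponent). Rounding *up* keeps the
# rescaled FP8 payload from overflowing.
def float_scale_to_ue8m0(scale: torch.Tensor) -> torch.Tensor:
    exponent = torch.ceil(torch.log2(scale))                # next power of two
    return (exponent + 127).clamp(0, 255).to(torch.uint8)   # assumed bias of 127

print(float_scale_to_ue8m0(torch.tensor([0.4, 1.0, 3.0])))
# tensor([126, 127, 129], dtype=torch.uint8)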
python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py  (view file @ 93cec433)

@@ -584,6 +584,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
+                round_scale=deep_gemm_wrapper.DEEPGEMM_V202506,
+                use_ue8m0=deep_gemm_wrapper.DEEPGEMM_V202506,
             )
         )
         return packed_recv_hidden, packed_recv_count, event, hook
...
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  (view file @ 93cec433)

@@ -12,6 +12,7 @@ import torch
 import triton
 import triton.language as tl
+from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,

@@ -518,10 +519,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
-def ceil_div(a, b):
-    return (a + b - 1) // b
 @triton.jit
 def moe_align_block_size_stage1(
     topk_ids_ptr,
...
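The kernel file's private ceil_div helper is removed in favour of the shared sglang.math_utils.ceil_div imported above; presumably the shared helper keeps the same semantics as the deleted two-liner. A quick self-contained restatement with a couple of sanity checks:

# Same formula as the helper deleted from fused_moe.py; the shared
# sglang.math_utils.ceil_div is presumably equivalent.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

assert ceil_div(7168, 128) == 56
assert ceil_div(129, 128) == 2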
python/sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py  (view file @ 93cec433)

@@ -21,6 +21,12 @@ def _compute_enable_deep_gemm():
 ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()

-DEEPGEMM_V202506 = False
+try:
+    from deep_gemm import fp8_gemm_nt
+
+    # They have not given a name to this breaking change
+    DEEPGEMM_V202506 = True
+except ImportError:
+    DEEPGEMM_V202506 = False

 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506
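Since DeepGEMM exposes no version string, the wrapper infers the API generation by probing for the renamed kernel entry point: fp8_gemm_nt exists only in the 2025-06 API. The same probe, written as a standalone function (the symbol names are from the diff; the helper itself is illustrative, not part of the commit):

import importlib.util

# Illustrative standalone version of the probe in configurer.py: report
# whether the installed deep_gemm exposes the renamed 2025-06 entry point.
# Returns False when deep_gemm is not installed at all.
def has_new_deepgemm_api() -> bool:
    if importlib.util.find_spec("deep_gemm") is None:
        return False
    import deep_gemm
    return hasattr(deep_gemm, "fp8_gemm_nt")

DEEPGEMM_V202506 = has_new_deepgemm_api()
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_V202506  # UE8M0 scales are tied to the new API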
python/sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py  (view file @ 93cec433)

@@ -16,6 +16,16 @@ logger = logging.getLogger(__name__)
 if ENABLE_JIT_DEEPGEMM:
     import deep_gemm
-    from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
+
+    if DEEPGEMM_V202506:
+        from deep_gemm import fp8_gemm_nt as _gemm_nt_f8f8bf16_raw
+        from deep_gemm import (
+            fp8_m_grouped_gemm_nt_masked as _grouped_gemm_nt_f8f8bf16_masked_raw,
+        )
+        from deep_gemm import (
+            m_grouped_fp8_gemm_nt_contiguous as _grouped_gemm_nt_f8f8bf16_contig_raw,
+        )
+    else:
+        from deep_gemm import gemm_fp8_fp8_bf16_nt as _gemm_nt_f8f8bf16_raw
     from deep_gemm import get_col_major_tma_aligned_tensor
     from deep_gemm import (
...
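Both API generations end up bound to the same internal _raw aliases, so the rest of the wrapper keeps calling one stable name regardless of which DeepGEMM is installed. A hedged sketch of that resolution step as a plain function (fp8_gemm_nt and gemm_fp8_fp8_bf16_nt are the names shown in the diff; the helper itself is illustrative, not part of the commit):

# Illustrative only: pick whichever dense FP8 GEMM entry point the
# installed DeepGEMM exposes and return it under one stable name,
# mirroring the _gemm_nt_f8f8bf16_raw alias above.
def resolve_gemm_nt(deep_gemm_module):
    if hasattr(deep_gemm_module, "fp8_gemm_nt"):    # 2025-06 API
        return deep_gemm_module.fp8_gemm_nt
    return deep_gemm_module.gemm_fp8_fp8_bf16_nt    # older API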
python/sglang/srt/layers/quantization/fp8_kernel.py  (view file @ 93cec433)

@@ -765,7 +765,15 @@ def prepare_block_fp8_matmul_inputs(
     assert A.shape[-1] == B.shape[-1]
     assert A.shape[:-1] == As.shape[:-1]
     assert A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    if As.dtype == torch.float:
+        assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+    elif Bs.dtype == torch.int:
+        assert (
+            triton.cdiv(triton.cdiv(A.shape[-1], block_k), 4) == As.shape[-1]
+        ), f"{A.shape=} {As.shape=} {block_size=}"
+    else:
+        raise NotImplementedError
     M = A.numel() // A.shape[-1]

@@ -773,8 +781,17 @@ def prepare_block_fp8_matmul_inputs(
     assert B.is_contiguous()
     assert Bs.ndim == 2
     N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
+    if Bs.dtype == torch.float:
+        assert triton.cdiv(N, block_n) == Bs.shape[0]
+        assert triton.cdiv(K, block_k) == Bs.shape[1]
+    elif Bs.dtype == torch.int:
+        assert N == Bs.shape[0], f"{B.shape=} {Bs.shape=} {block_size=}"
+        assert (
+            triton.cdiv(triton.cdiv(K, block_k), 4) == Bs.shape[1]
+        ), f"{B.shape=} {Bs.shape=} {block_size=}"
+    else:
+        raise NotImplementedError

     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
...
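The new elif branches accept the packed UE8M0 scale layout: the scale tensor arrives as an integer tensor with the per-group 8-bit exponents packed four to an element along the K dimension, which is where the extra division by 4 in the asserts comes from. A self-contained sketch of the shape arithmetic these asserts encode (the sizes below are made up for illustration; the packing factor of 4 is taken from the asserts above):

# Shape arithmetic behind the asserts in prepare_block_fp8_matmul_inputs.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

N, K = 4096, 7168                # illustrative weight shape
block_n, block_k = 128, 128      # illustrative block size

# float32 scales: one scale per (block_n x block_k) tile of B
float_bs_shape = (ceil_div(N, block_n), ceil_div(K, block_k))        # (32, 56)

# packed UE8M0 scales (integer dtype): one row per output row,
# with the K-groups packed four per element
packed_bs_shape = (N, ceil_div(ceil_div(K, block_k), 4))             # (4096, 14)

print(float_bs_shape, packed_bs_shape)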
python/sglang/srt/layers/quantization/fp8_utils.py  (view file @ 93cec433)

@@ -238,6 +238,7 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
         block_size[1],
         column_major_scales=True,
         scale_tma_aligned=True,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
     )
     if get_bool_env_var("SGLANG_W8A8_DEEPGEMM_SANITY_CHECK_UE8M0"):
...
python/sglang/srt/models/deepseek_v2.py  (view file @ 93cec433)

@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization import deep_gemm_wrapper

@@ -1932,7 +1932,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
             self_attn.use_deep_gemm_bmm = True

-        if False:  # TODO (pr-chain)
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
             self._weight_requant_ue8m0()

     def _weight_requant_ue8m0(self):
...
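With the guard flipped from the placeholder if False, the weights are requantised to UE8M0 scales whenever the installed DeepGEMM expects them. The body of _weight_requant_ue8m0 is not part of this diff; as background, here is a self-contained sketch of what requantising a block-scaled FP8 tensor to power-of-two scales generally involves (an assumption about the technique, not the method's actual code):

import torch

# Background sketch only; _weight_requant_ue8m0's real implementation is
# not shown in this commit. Forcing each scale up to the next power of two
# and compensating in the payload keeps the dequantised product unchanged
# while making the scale representable as a UE8M0 exponent.
def requant_to_pow2_scales(w: torch.Tensor, s: torch.Tensor):
    s_pow2 = torch.exp2(torch.ceil(torch.log2(s)))   # round scale up to 2**k
    return w * (s / s_pow2), s_pow2                  # |payload| can only shrink

w = torch.tensor([100.0, -80.0])   # stand-in for an FP8 payload, shown as float
s = torch.tensor([0.3, 0.3])
w2, s2 = requant_to_pow2_scales(w, s)
assert torch.allclose(w * s, w2 * s2)   # dequantised values are unchanged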