sglang · Commits · b4c9f38a

Commit b4c9f38a (unverified), authored Aug 08, 2025 by Kaixi Hou; committed by GitHub, Aug 08, 2025

[NVIDIA] Fix missing `get_col_major_tma_aligned_tensor` for Blackwell deepgemm in EpMoE (#8955)
Parent: 11325474

Showing 1 changed file with 49 additions and 4 deletions:

python/sglang/srt/layers/moe/ep_moe/layer.py (+49, -4)
@@ -55,6 +55,22 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
 
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
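
Reviewer note: a minimal reference sketch of what the rounding-up cast is meant to guarantee (my own illustration, not part of the commit). The ue8m0 format keeps only a power-of-two exponent, and rounding that exponent up keeps the stored scale from falling below the original float32 scale. The helper name `e8m0_round_up_reference` is hypothetical, and the sketch skips the packed uint8-to-int32 view and the transpose layout used by the real `_cast_to_e8m0_with_rounding_up`.

import torch

def e8m0_round_up_reference(x: torch.Tensor) -> torch.Tensor:
    # Keep only a power-of-two scale, rounding the exponent up so the
    # quantized scale is never smaller than the original value.
    return torch.pow(2.0, torch.ceil(torch.log2(x.to(torch.float32))))

scales = torch.tensor([0.75, 1.0, 1.5, 3.0])
print(e8m0_round_up_reference(scales))  # tensor([1., 1., 2., 4.])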
@@ -204,10 +220,22 @@ class EPMoE(FusedMoE):
 
         dispose_tensor(hidden_states)
 
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
         # GroupGemm-0
         gateup_input_fp8 = (
             gateup_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
         )
         num_groups, m, k = gateup_input_fp8[0].size()
         n = self.w13_weight.size(1)
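
One plausible reading of the new alignment assertion (my interpretation, not stated in the commit): `_cast_to_e8m0_with_rounding_up` reinterprets a uint8 exponent tensor as int32, and PyTorch only permits that dtype view when the innermost dimension is a multiple of the 4:1 element-size ratio; with the transposes involved, both scale dimensions end up needing that alignment. A small standalone check:

import torch

# Works: the innermost dimension (4) is a multiple of 4, so four uint8
# exponents pack into one int32 element.
packed = torch.zeros(2, 8, 4, dtype=torch.uint8).view(torch.int32)
print(packed.shape)  # torch.Size([2, 8, 1])

# Fails: an innermost dimension of 3 cannot be reinterpreted as int32.
try:
    torch.zeros(2, 8, 3, dtype=torch.uint8).view(torch.int32)
except RuntimeError as err:
    print("view failed:", err)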
@@ -215,7 +243,12 @@ class EPMoE(FusedMoE):
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
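
For context on the new `recipe=(1, 128, 128)` argument: under DeepGEMM's block-scaling convention this should correspond to one scale per token per 128-wide K block for activations and one scale per 128x128 block for weights. The shapes below are illustrative values of my own choosing, not taken from this diff.

# Illustrative scale shapes implied by a (1, 128, 128) recipe; the sizes are
# made up for the example.
m, k, n = 256, 7168, 4096
act_scale_shape = (m, k // 128)            # one scale per (1 token x 128 channels)
weight_scale_shape = (n // 128, k // 128)  # one scale per (128 x 128) weight block
print(act_scale_shape, weight_scale_shape)  # (256, 56) (32, 56)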
@@ -246,6 +279,7 @@ class EPMoE(FusedMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
 
@@ -253,13 +287,24 @@ class EPMoE(FusedMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8