Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0edc0cd5
Unverified
Commit
0edc0cd5
authored
Aug 09, 2025
by
Jee Jee Li
Committed by
GitHub
Aug 09, 2025
Browse files
[Bugfix] Fix CI moe kernel failure (#22556)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
7920e9b1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
142 additions
and
64 deletions
+142
-64
tests/kernels/moe/test_gpt_oss_triton_kernels.py
tests/kernels/moe/test_gpt_oss_triton_kernels.py
+142
-64
No files found.
tests/kernels/moe/test_gpt_oss_triton_kernels.py
View file @
0edc0cd5
...
@@ -5,6 +5,15 @@ from dataclasses import dataclass, fields
...
@@ -5,6 +5,15 @@ from dataclasses import dataclass, fields
import
pytest
import
pytest
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm.utils
import
has_triton_kernels
if
not
has_triton_kernels
():
pytest
.
skip
(
"triton_kernels not found, skipping all related tests"
,
allow_module_level
=
True
,
)
import
triton_kernels.swiglu
import
triton_kernels.swiglu
from
triton_kernels.matmul_ogs
import
FlexCtx
,
PrecisionConfig
from
triton_kernels.matmul_ogs
import
FlexCtx
,
PrecisionConfig
from
triton_kernels.numerics
import
InFlexData
from
triton_kernels.numerics
import
InFlexData
...
@@ -65,7 +74,7 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
...
@@ -65,7 +74,7 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
dtype_dict
=
{
dtype_dict
=
{
"bf16"
:
torch
.
bfloat16
,
"bf16"
:
torch
.
bfloat16
,
"fp8_e4m3"
:
torch
.
float8_e4m3fn
,
"fp8_e4m3"
:
torch
.
float8_e4m3fn
,
"fp8_e5m2"
:
torch
.
float8_e5m2
"fp8_e5m2"
:
torch
.
float8_e5m2
,
}
}
x
=
x
.
to
(
dtype_dict
[
a_dtype
]).
to
(
torch
.
bfloat16
)
x
=
x
.
to
(
dtype_dict
[
a_dtype
]).
to
(
torch
.
bfloat16
)
...
@@ -97,12 +106,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
...
@@ -97,12 +106,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
x_pad
=
w1_bottom_pad
x_pad
=
w1_bottom_pad
w1_tri
=
F
.
pad
(
w1_tri
,
(
0
,
w1_right_pad
,
0
,
w1_bottom_pad
,
0
,
0
),
w1_tri
=
F
.
pad
(
w1_tri
,
(
0
,
w1_right_pad
,
0
,
w1_bottom_pad
,
0
,
0
),
mode
=
"constant"
,
mode
=
"constant"
,
value
=
0
)
value
=
0
,
w2_tri
=
F
.
pad
(
w2_tri
,
(
0
,
w2_right_pad
,
0
,
w2_bottom_pad
,
0
,
0
),
)
w2_tri
=
F
.
pad
(
w2_tri
,
(
0
,
w2_right_pad
,
0
,
w2_bottom_pad
,
0
,
0
),
mode
=
"constant"
,
mode
=
"constant"
,
value
=
0
)
value
=
0
,
)
w1_bias_tri
=
F
.
pad
(
w1_bias_tri
,
(
0
,
w1_right_pad
,
0
,
0
),
w1_bias_tri
=
F
.
pad
(
w1_bias_tri
,
(
0
,
w1_right_pad
,
0
,
0
),
mode
=
"constant"
,
mode
=
"constant"
,
...
@@ -127,13 +142,19 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
...
@@ -127,13 +142,19 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
w1_tri
=
convert_layout
(
wrap_torch_tensor
(
w1_tri
,
FP4
),
w_layout
,
w1_tri
=
convert_layout
(
wrap_torch_tensor
(
w1_tri
,
FP4
),
w_layout
,
**
w_layout_opts
)
**
w_layout_opts
)
w1_scale_tri
=
convert_layout
(
wrap_torch_tensor
(
w1_scale_tri
),
w1_scale_tri
=
convert_layout
(
w_scale_layout
,
**
w_scale_layout_opts
)
wrap_torch_tensor
(
w1_scale_tri
),
w_scale_layout
,
**
w_scale_layout_opts
,
)
w2_tri
=
convert_layout
(
wrap_torch_tensor
(
w2_tri
,
FP4
),
w_layout
,
w2_tri
=
convert_layout
(
wrap_torch_tensor
(
w2_tri
,
FP4
),
w_layout
,
**
w_layout_opts
)
**
w_layout_opts
)
w2_scale_tri
=
convert_layout
(
wrap_torch_tensor
(
w2_scale_tri
),
w2_scale_tri
=
convert_layout
(
w_scale_layout
,
**
w_scale_layout_opts
)
wrap_torch_tensor
(
w2_scale_tri
),
w_scale_layout
,
**
w_scale_layout_opts
,
)
pc1
=
PrecisionConfig
(
weight_scale
=
w1_scale_tri
,
pc1
=
PrecisionConfig
(
weight_scale
=
w1_scale_tri
,
flex_ctx
=
FlexCtx
(
rhs_data
=
InFlexData
()))
flex_ctx
=
FlexCtx
(
rhs_data
=
InFlexData
()))
...
@@ -149,8 +170,22 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
...
@@ -149,8 +170,22 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
w1
=
w1
.
transpose
(
-
1
,
-
2
).
contiguous
()
w1
=
w1
.
transpose
(
-
1
,
-
2
).
contiguous
()
w2
=
w2
.
transpose
(
-
1
,
-
2
).
contiguous
()
w2
=
w2
.
transpose
(
-
1
,
-
2
).
contiguous
()
return
(
x
,
w1
,
w1_bias
,
w2
,
w2_bias
,
exp_data
,
x_tri
,
w1_tri
,
w2_tri
,
return
(
exp_data_tri
,
w1_bias_tri
,
w2_bias_tri
,
pc1
,
pc2
)
x
,
w1
,
w1_bias
,
w2
,
w2_bias
,
exp_data
,
x_tri
,
w1_tri
,
w2_tri
,
exp_data_tri
,
w1_bias_tri
,
w2_bias_tri
,
pc1
,
pc2
,
)
@
dataclass
@
dataclass
...
@@ -190,7 +225,8 @@ def oai_moe_forward(
...
@@ -190,7 +225,8 @@ def oai_moe_forward(
w2
:
torch
.
Tensor
,
# (E, K, N)
w2
:
torch
.
Tensor
,
# (E, K, N)
w2_bias
:
torch
.
Tensor
,
# (E, N)
w2_bias
:
torch
.
Tensor
,
# (E, N)
gating_output
:
torch
.
Tensor
,
# (M, E)
gating_output
:
torch
.
Tensor
,
# (M, E)
topk
:
int
):
topk
:
int
,
):
# model.py 309:330, assuming gating and norm
# model.py 309:330, assuming gating and norm
t
=
hidden_states
t
=
hidden_states
experts
=
torch
.
topk
(
gating_output
,
k
=
topk
,
dim
=-
1
,
sorted
=
True
)
experts
=
torch
.
topk
(
gating_output
,
k
=
topk
,
dim
=-
1
,
sorted
=
True
)
...
@@ -240,10 +276,22 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
...
@@ -240,10 +276,22 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
N
=
ModelConfig
.
intermediate_size
//
tp
N
=
ModelConfig
.
intermediate_size
//
tp
topk
=
ModelConfig
.
experts_per_token
topk
=
ModelConfig
.
experts_per_token
x
,
w1
,
w1_bias
,
w2
,
w2_bias
,
exp_data
,
\
(
x_tri
,
w1_tri
,
w2_tri
,
exp_data_tri
,
w1_bias_tri
,
\
x
,
w2_bias_tri
,
pc1
,
pc2
=
init_compute_data
(
w1
,
M
,
K
,
N
,
E
,
a_dtype
,
w_dtype
,
num_warps
=
8
)
w1_bias
,
w2
,
w2_bias
,
exp_data
,
x_tri
,
w1_tri
,
w2_tri
,
exp_data_tri
,
w1_bias_tri
,
w2_bias_tri
,
pc1
,
pc2
,
)
=
init_compute_data
(
M
,
K
,
N
,
E
,
a_dtype
,
w_dtype
,
num_warps
=
8
)
out_triton_monolithic
=
triton_kernel_moe_forward
(
out_triton_monolithic
=
triton_kernel_moe_forward
(
hidden_states
=
x_tri
,
hidden_states
=
x_tri
,
...
@@ -255,33 +303,46 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
...
@@ -255,33 +303,46 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
w1_bias
=
w1_bias_tri
,
w1_bias
=
w1_bias_tri
,
w2_bias
=
w2_bias_tri
,
w2_bias
=
w2_bias_tri
,
w1_precision
=
pc1
,
w1_precision
=
pc1
,
w2_precision
=
pc2
)
w2_precision
=
pc2
,
)
out_triton_monolithic
=
out_triton_monolithic
[...,
:
K
]
out_triton_monolithic
=
out_triton_monolithic
[...,
:
K
]
out_ref
=
oai_moe_forward
(
hidden_states
=
x
,
out_ref
=
oai_moe_forward
(
hidden_states
=
x
,
w1
=
w1
,
w1
=
w1
,
w1_bias
=
w1_bias
,
w1_bias
=
w1_bias
,
w2
=
w2
,
w2
=
w2
,
w2_bias
=
w2_bias
,
w2_bias
=
w2_bias
,
gating_output
=
exp_data
,
gating_output
=
exp_data
,
topk
=
topk
)
topk
=
topk
,
)
assert_close
(
ref
=
out_ref
,
assert_close
(
ref
=
out_ref
,
tri
=
out_triton_monolithic
,
tri
=
out_triton_monolithic
,
maxtol
=
0.025
,
maxtol
=
0.025
,
rmstol
=
0.005
)
rmstol
=
0.005
)
def
batched_moe
(
a
:
torch
.
Tensor
,
w1
,
w2
,
gating_output
:
torch
.
Tensor
,
def
batched_moe
(
topk
:
int
,
renormalize
:
bool
,
w1_bias
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
w2_bias
:
torch
.
Tensor
,
w1_precision
:
PrecisionConfig
,
w1
,
w2_precision
:
PrecisionConfig
)
->
torch
.
Tensor
:
w2
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
w1_bias
:
torch
.
Tensor
,
w2_bias
:
torch
.
Tensor
,
w1_precision
:
PrecisionConfig
,
w2_precision
:
PrecisionConfig
,
)
->
torch
.
Tensor
:
max_num_tokens
=
round_up
(
a
.
shape
[
0
],
64
)
max_num_tokens
=
round_up
(
a
.
shape
[
0
],
64
)
fused_experts
=
FusedMoEModularKernel
(
fused_experts
=
FusedMoEModularKernel
(
BatchedPrepareAndFinalize
(
max_num_tokens
,
BatchedPrepareAndFinalize
(
max_num_tokens
,
num_dispatchers
=
1
,
num_dispatchers
=
1
,
num_local_experts
=
w1
.
shape
[
0
],
num_local_experts
=
w1
.
shape
[
0
],
rank
=
0
),
rank
=
0
,
),
BatchedOAITritonExperts
(
BatchedOAITritonExperts
(
None
,
None
,
max_num_tokens
=
max_num_tokens
,
max_num_tokens
=
max_num_tokens
,
...
@@ -327,12 +388,25 @@ def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
...
@@ -327,12 +388,25 @@ def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
N
=
ModelConfig
.
intermediate_size
N
=
ModelConfig
.
intermediate_size
topk
=
ModelConfig
.
experts_per_token
topk
=
ModelConfig
.
experts_per_token
x
,
w1
,
w1_bias
,
w2
,
w2_bias
,
exp_data
,
\
(
x_tri
,
w1_tri
,
w2_tri
,
exp_data_tri
,
w1_bias_tri
,
\
x
,
w2_bias_tri
,
pc1
,
pc2
=
init_compute_data
(
w1
,
M
,
K
,
N
,
E
,
a_dtype
,
w_dtype
,
num_warps
=
4
)
w1_bias
,
w2
,
out_tri
=
batched_moe
(
a
=
x_tri
,
w2_bias
,
exp_data
,
x_tri
,
w1_tri
,
w2_tri
,
exp_data_tri
,
w1_bias_tri
,
w2_bias_tri
,
pc1
,
pc2
,
)
=
init_compute_data
(
M
,
K
,
N
,
E
,
a_dtype
,
w_dtype
,
num_warps
=
4
)
out_tri
=
batched_moe
(
a
=
x_tri
,
w1
=
w1_tri
,
w1
=
w1_tri
,
w2
=
w2_tri
,
w2
=
w2_tri
,
gating_output
=
exp_data_tri
,
gating_output
=
exp_data_tri
,
...
@@ -341,16 +415,19 @@ def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
...
@@ -341,16 +415,19 @@ def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
w1_bias
=
w1_bias_tri
,
w1_bias
=
w1_bias_tri
,
w2_bias
=
w2_bias_tri
,
w2_bias
=
w2_bias_tri
,
w1_precision
=
pc1
,
w1_precision
=
pc1
,
w2_precision
=
pc2
)
w2_precision
=
pc2
,
)
out_tri
=
out_tri
[...,
:
K
]
out_tri
=
out_tri
[...,
:
K
]
out_ref
=
oai_moe_forward
(
hidden_states
=
x
,
out_ref
=
oai_moe_forward
(
hidden_states
=
x
,
w1
=
w1
,
w1
=
w1
,
w1_bias
=
w1_bias
,
w1_bias
=
w1_bias
,
w2
=
w2
,
w2
=
w2
,
w2_bias
=
w2_bias
,
w2_bias
=
w2_bias
,
gating_output
=
exp_data
,
gating_output
=
exp_data
,
topk
=
topk
)
topk
=
topk
,
)
assert_close
(
ref
=
out_ref
,
tri
=
out_tri
,
maxtol
=
0.025
,
rmstol
=
0.005
)
assert_close
(
ref
=
out_ref
,
tri
=
out_tri
,
maxtol
=
0.025
,
rmstol
=
0.005
)
...
@@ -370,6 +447,7 @@ def test_unit_shuffle():
...
@@ -370,6 +447,7 @@ def test_unit_shuffle():
out
=
triton_kernels
.
swiglu
.
swiglu_torch
(
out
=
triton_kernels
.
swiglu
.
swiglu_torch
(
out
,
out
,
alpha
=
1.702
,
alpha
=
1.702
,
precision_config
=
triton_kernels
.
swiglu
.
PrecisionConfig
(
limit
=
1.0
))
precision_config
=
triton_kernels
.
swiglu
.
PrecisionConfig
(
limit
=
1.0
),
)
assert_close
(
ref
=
out_ref
,
tri
=
out
)
assert_close
(
ref
=
out_ref
,
tri
=
out
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment