Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d0532bf3
Unverified
Commit
d0532bf3
authored
Mar 20, 2026
by
Xin Yang
Committed by
GitHub
Mar 20, 2026
Browse files
[Perf] Eliminate redundant SparseMatrix creation in gpt_oss_triton_kernels (#37683)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
fb4e8bf4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
4 deletions
+73
-4
tests/kernels/moe/test_gpt_oss_triton_kernels.py
tests/kernels/moe/test_gpt_oss_triton_kernels.py
+44
-0
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+29
-4
No files found.
tests/kernels/moe/test_gpt_oss_triton_kernels.py
View file @
d0532bf3
...
@@ -21,12 +21,16 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
...
@@ -21,12 +21,16 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
from
triton_kernels.tensor
import
FP4
,
convert_layout
,
wrap_torch_tensor
from
triton_kernels.tensor
import
FP4
,
convert_layout
,
wrap_torch_tensor
from
triton_kernels.tensor_details
import
layout
from
triton_kernels.tensor_details
import
layout
from
triton_kernels.testing
import
assert_close
from
triton_kernels.testing
import
assert_close
from
triton_kernels.topk
import
topk
as
topk_fn
from
vllm.model_executor.layers.fused_moe.config
import
mxfp4_w4a16_moe_quant_config
from
vllm.model_executor.layers.fused_moe.config
import
mxfp4_w4a16_moe_quant_config
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
legacy_routing
,
make_routing_data
,
triton_kernel_moe_forward
,
triton_kernel_moe_forward
,
)
)
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
set_random_seed
from
.utils
import
shuffle_weight
from
.utils
import
shuffle_weight
...
@@ -355,3 +359,43 @@ def test_unit_shuffle():
...
@@ -355,3 +359,43 @@ def test_unit_shuffle():
)
)
assert_close
(
ref
=
out_ref
,
tri
=
out
)
assert_close
(
ref
=
out_ref
,
tri
=
out
)
@pytest.mark.parametrize("num_tokens", [2, 8, 64])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_legacy_routing(
    num_tokens: int,
    num_experts: int,
    topk: int,
    renormalize: bool,
    dtype: torch.dtype,
):
    """Compare ``legacy_routing`` against ``make_routing_data``.

    The reference path runs the top-k selection explicitly (softmax applied
    before top-k when ``renormalize`` is False, inside top-k otherwise) and
    feeds the resulting ids/weights to ``make_routing_data``. The path under
    test hands the raw gating output to ``legacy_routing``. Only the gather
    and scatter index tensors are compared, and they must match exactly
    (both tolerances are 0); the returned routing-data objects are not
    checked here.
    """
    set_random_seed(0)
    gating_output = torch.randn(
        num_tokens, num_experts, device="cuda", dtype=dtype
    )

    sm_first = not renormalize

    # Reference: explicit top-k followed by make_routing_data.
    ref_logits = torch.softmax(gating_output, dim=-1) if sm_first else gating_output
    ref_sparse = topk_fn(ref_logits, topk, apply_softmax=not sm_first)
    _routing_ref, gather_ref, scatter_ref = make_routing_data(
        ref_sparse.indx.to(torch.long),
        ref_sparse.vals,
        num_experts,
    )

    # Under test: legacy_routing starting from the raw gating output.
    _routing, gather_out, scatter_out = legacy_routing(
        gating_output, topk, sm_first=sm_first
    )

    # Same comparison order as before: gather src/dst, then scatter src/dst.
    for expected, actual in ((gather_ref, gather_out), (scatter_ref, scatter_out)):
        for field in ("src_indx", "dst_indx"):
            assert_close(
                ref=getattr(expected, field),
                tri=getattr(actual, field),
                maxtol=0,
                rmstol=0,
            )
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
d0532bf3
...
@@ -142,6 +142,33 @@ def legacy_routing_from_bitmatrix(
...
@@ -142,6 +142,33 @@ def legacy_routing_from_bitmatrix(
return
routing_data
,
gather_idx
,
scatter_idx
return
routing_data
,
gather_idx
,
scatter_idx
def legacy_routing_from_sparsematrix(
    sparse_logits: "SparseMatrix",
    n_expts_tot: int,
    n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
    """
    Build routing data directly from an existing SparseMatrix.

    Args:
        sparse_logits: SparseMatrix whose ``mask_metadata`` (row/column
            sorted indices and ``col_sum``) and ``vals`` are read.
        n_expts_tot: Passed through to ``RoutingData`` unchanged
            (presumably the total number of experts; confirm with caller).
        n_expts_act: Passed through to ``RoutingData`` unchanged
            (presumably the number of experts active per token; confirm
            with caller).

    Returns:
        A ``(RoutingData, GatherIndx, ScatterIndx)`` tuple.
    """
    meta = sparse_logits.mask_metadata
    row_sorted = meta.row_sorted_indx  # the "dispatch" index
    col_sorted = meta.col_sorted_indx  # the "combine" index

    ragged_meta = make_ragged_tensor_metadata(meta.col_sum, row_sorted.shape[0])

    # Reorder the flattened values with the column-sorted index.
    flat_vals = sparse_logits.vals.flatten()
    gate_scal = flat_vals[col_sorted]

    return (
        RoutingData(
            gate_scal,
            ragged_meta.block_sizes,
            n_expts_tot,
            n_expts_act,
            ragged_meta,
        ),
        # Gather and scatter take the same two indices in opposite order.
        GatherIndx(col_sorted, row_sorted),
        ScatterIndx(row_sorted, col_sorted),
    )
def
legacy_routing
(
def
legacy_routing
(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
n_expts_act
:
int
,
n_expts_act
:
int
,
...
@@ -158,10 +185,8 @@ def legacy_routing(
...
@@ -158,10 +185,8 @@ def legacy_routing(
if
sm_first
:
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
sparse_logits
=
topk
(
logits
,
n_expts_act
,
apply_softmax
=
not
sm_first
)
sparse_logits
=
topk
(
logits
,
n_expts_act
,
apply_softmax
=
not
sm_first
)
return
legacy_routing_from_bitmatrix
(
return
legacy_routing_from_sparsematrix
(
sparse_logits
.
mask
,
sparse_logits
,
sparse_logits
.
vals
,
sparse_logits
.
indx
,
logits
.
shape
[
-
1
],
logits
.
shape
[
-
1
],
n_expts_act
,
n_expts_act
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment