Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2df2c85b
Unverified
Commit
2df2c85b
authored
Apr 06, 2026
by
Andreas Karatzas
Committed by
GitHub
Apr 07, 2026
Browse files
[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
62095e82
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
84 additions
and
216 deletions
+84
-216
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml
+1
-1
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
..._oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
+2
-2
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
...oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
+1
-1
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml
..._oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml
+1
-1
tests/kernels/moe/test_gpt_oss_triton_kernels.py
tests/kernels/moe/test_gpt_oss_triton_kernels.py
+12
-57
tests/kernels/quantization/test_mxfp4_triton_ep.py
tests/kernels/quantization/test_mxfp4_triton_ep.py
+18
-41
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+49
-113
No files found.
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml
View file @
2df2c85b
...
@@ -3,4 +3,4 @@
...
@@ -3,4 +3,4 @@
model_name
:
openai/gpt-oss-20b
model_name
:
openai/gpt-oss-20b
metric_threshold
:
0.568
metric_threshold
:
0.568
reasoning_effort
:
low
reasoning_effort
:
low
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN"
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--tensor-parallel-size
2"
\ No newline at end of file
\ No newline at end of file
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
View file @
2df2c85b
...
@@ -3,6 +3,6 @@
...
@@ -3,6 +3,6 @@
model_name
:
amd/gpt-oss-20b-w-mxfp4-a-bf16
model_name
:
amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold
:
0.568
metric_threshold
:
0.568
reasoning_effort
:
low
reasoning_effort
:
low
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--moe-backend
aiter"
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--moe-backend
aiter
--tokenizer
openai/gpt-oss-20b
--tensor-parallel-size
2
"
env
:
env
:
VLLM_ROCM_USE_AITER
:
"
1"
VLLM_ROCM_USE_AITER
:
"
1"
\ No newline at end of file
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
View file @
2df2c85b
...
@@ -3,4 +3,4 @@
...
@@ -3,4 +3,4 @@
model_name
:
amd/gpt-oss-20b-w-mxfp4-a-bf16
model_name
:
amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold
:
0.568
metric_threshold
:
0.568
reasoning_effort
:
low
reasoning_effort
:
low
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--moe-backend
triton"
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--moe-backend
triton
--tokenizer
openai/gpt-oss-20b
--tensor-parallel-size
2"
\ No newline at end of file
\ No newline at end of file
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml
View file @
2df2c85b
...
@@ -3,6 +3,6 @@
...
@@ -3,6 +3,6 @@
model_name
:
amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
model_name
:
amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
metric_threshold
:
0.568
metric_threshold
:
0.568
reasoning_effort
:
low
reasoning_effort
:
low
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN"
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--tensor-parallel-size
2
"
env
:
env
:
VLLM_ROCM_USE_AITER
:
"
1"
VLLM_ROCM_USE_AITER
:
"
1"
\ No newline at end of file
tests/kernels/moe/test_gpt_oss_triton_kernels.py
View file @
2df2c85b
...
@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
...
@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
from
triton_kernels.tensor
import
FP4
,
convert_layout
,
wrap_torch_tensor
from
triton_kernels.tensor
import
FP4
,
convert_layout
,
wrap_torch_tensor
from
triton_kernels.tensor_details
import
layout
from
triton_kernels.tensor_details
import
layout
from
triton_kernels.testing
import
assert_close
from
triton_kernels.testing
import
assert_close
from
triton_kernels.topk
import
topk
as
topk_fn
from
vllm.model_executor.layers.fused_moe.config
import
mxfp4_w4a16_moe_quant_config
from
vllm.model_executor.layers.fused_moe.config
import
mxfp4_w4a16_moe_quant_config
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
legacy_routing
,
make_routing_data
,
triton_kernel_moe_forward
,
triton_kernel_moe_forward
,
)
)
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
set_random_seed
from
.utils
import
shuffle_weight
from
.utils
import
shuffle_weight
...
@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
...
@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
if
w_dtype
!=
"mx4"
:
if
w_dtype
!=
"mx4"
:
pytest
.
skip
(
"NYI"
)
pytest
.
skip
(
"NYI"
)
else
:
# quantize to mx4
else
:
# quantize to mx4
# careful on the padding here, the activation padding need to be
# Padding alignment depends on the platform. On CDNA4 the scale
# multiple of 64, the actual engine is not implemented
# swizzle requires SCALE_K % 8 == 0 (K % 256) and
w1_bottom_pad
=
round_up
(
w1_tri
.
shape
[
1
],
64
)
-
w1_tri
.
shape
[
1
]
# SCALE_N % 32 == 0 (2*N % 512), matching the production
w1_right_pad
=
round_up
(
w1_tri
.
shape
[
2
],
128
)
-
w1_tri
.
shape
[
2
]
# alignment in mxfp4_round_up_hidden_size_and_intermediate_size.
# On CUDA (Hopper) the scale layout pads internally, so the
# original 64/128 alignment is sufficient.
if
current_platform
.
is_rocm
():
k_align
,
n2_align
=
256
,
512
else
:
k_align
,
n2_align
=
64
,
128
w1_bottom_pad
=
round_up
(
w1_tri
.
shape
[
1
],
k_align
)
-
w1_tri
.
shape
[
1
]
w1_right_pad
=
round_up
(
w1_tri
.
shape
[
2
],
n2_align
)
-
w1_tri
.
shape
[
2
]
w2_bottom_pad
=
w1_right_pad
//
2
w2_bottom_pad
=
w1_right_pad
//
2
w2_right_pad
=
w1_bottom_pad
w2_right_pad
=
w1_bottom_pad
...
@@ -367,52 +371,3 @@ def test_unit_shuffle():
...
@@ -367,52 +371,3 @@ def test_unit_shuffle():
)
)
assert_close
(
ref
=
out_ref
,
tri
=
out
)
assert_close
(
ref
=
out_ref
,
tri
=
out
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
2
,
8
,
64
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"renormalize"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
def
test_legacy_routing
(
num_tokens
:
int
,
num_experts
:
int
,
topk
:
int
,
renormalize
:
bool
,
dtype
:
torch
.
dtype
):
set_random_seed
(
0
)
gating_output
=
torch
.
randn
(
num_tokens
,
num_experts
,
device
=
"cuda"
,
dtype
=
dtype
)
sm_first
=
not
renormalize
logits
=
gating_output
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
topk_result
=
topk_fn
(
logits
,
topk
,
apply_softmax
=
not
sm_first
)
# topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
if
isinstance
(
topk_result
,
tuple
):
topk_weights
,
topk_ids_raw
,
bitmatrix
=
topk_result
from
triton_kernels.routing
import
routing_from_bitmatrix
routing_data_ref
,
gather_indx_ref
,
scatter_indx_ref
=
routing_from_bitmatrix
(
bitmatrix
,
topk_weights
,
topk_ids_raw
,
num_experts
,
topk
)
else
:
topk_ids
=
topk_result
.
indx
.
to
(
torch
.
long
)
topk_weights
=
topk_result
.
vals
routing_data_ref
,
gather_indx_ref
,
scatter_indx_ref
=
make_routing_data
(
topk_ids
,
topk_weights
,
num_experts
)
routing_data
,
gather_indx
,
scatter_indx
=
legacy_routing
(
gating_output
,
topk
,
sm_first
=
sm_first
)
assert_close
(
ref
=
gather_indx_ref
.
src_indx
,
tri
=
gather_indx
.
src_indx
,
maxtol
=
0
,
rmstol
=
0
)
assert_close
(
ref
=
gather_indx_ref
.
dst_indx
,
tri
=
gather_indx
.
dst_indx
,
maxtol
=
0
,
rmstol
=
0
)
assert_close
(
ref
=
scatter_indx_ref
.
src_indx
,
tri
=
scatter_indx
.
src_indx
,
maxtol
=
0
,
rmstol
=
0
)
assert_close
(
ref
=
scatter_indx_ref
.
dst_indx
,
tri
=
scatter_indx
.
dst_indx
,
maxtol
=
0
,
rmstol
=
0
)
tests/kernels/quantization/test_mxfp4_triton_ep.py
View file @
2df2c85b
...
@@ -4,12 +4,9 @@
...
@@ -4,12 +4,9 @@
Tests that triton_kernel_moe_forward correctly applies expert_map
Tests that triton_kernel_moe_forward correctly applies expert_map
remapping when expert parallelism (EP) is enabled.
remapping when expert parallelism (EP) is enabled.
Previously, legacy_routing was always used and it produced routing data
Both EP and non-EP paths use topk + make_routing_data. When expert_map
with global expert IDs that didn't correspond to local weight indices,
is provided, global expert IDs are remapped to local IDs before building
causing illegal memory access with EP. The fix splits routing: when
routing structures.
expert_map is provided, topk selection is performed first, expert_map is
applied to remap global→local IDs, and make_routing_data builds routing
structures from the local IDs.
"""
"""
from
unittest.mock
import
MagicMock
,
patch
from
unittest.mock
import
MagicMock
,
patch
...
@@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap:
...
@@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap:
@
pytest
.
mark
.
parametrize
(
"expert_map_present"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"expert_map_present"
,
[
False
,
True
])
def
test_routing_path_selection
(
self
,
expert_map_present
):
def
test_routing_path_selection
(
self
,
expert_map_present
):
"""Verify that th
e
EP
-aware routing path is taken when expert_map
"""Verify that
bo
th EP
and non-EP paths use topk + make_routing_data,
is present, and the legacy_routing path is taken otherwise
."""
and that expert_map remapping is applied when present
."""
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
# This is a structural test: we mock the routing functions to
# verify the correct path is exercised.
mock_expert_map
=
(
mock_expert_map
=
(
torch
.
tensor
([
0
,
-
1
,
1
,
-
1
],
device
=
device
)
if
expert_map_present
else
None
torch
.
tensor
([
0
,
-
1
,
1
,
-
1
],
device
=
device
)
if
expert_map_present
else
None
)
)
with
(
with
(
patch
(
"vllm.model_executor.layers.fused_moe."
"gpt_oss_triton_kernels_moe.legacy_routing"
)
as
mock_legacy
,
patch
(
"triton_kernels.topk.topk"
)
as
mock_topk
,
patch
(
"triton_kernels.topk.topk"
)
as
mock_topk
,
patch
(
patch
(
"vllm.model_executor.layers.fused_moe."
"vllm.model_executor.layers.fused_moe."
...
@@ -53,27 +44,19 @@ class TestTritonMoeForwardExpertMap:
...
@@ -53,27 +44,19 @@ class TestTritonMoeForwardExpertMap:
triton_kernel_moe_forward
,
triton_kernel_moe_forward
,
)
)
# Set up return values
mock_routing_data
=
MagicMock
()
mock_routing_data
=
MagicMock
()
mock_gather
=
MagicMock
()
mock_gather
=
MagicMock
()
mock_scatter
=
MagicMock
()
mock_scatter
=
MagicMock
()
if
expert_map_present
:
sparse_result
=
MagicMock
()
sparse_result
=
MagicMock
()
sparse_result
.
indx
=
torch
.
tensor
([[
0
,
2
]],
dtype
=
torch
.
int32
)
sparse_result
.
indx
=
torch
.
tensor
([[
0
,
2
]],
dtype
=
torch
.
int32
)
sparse_result
.
vals
=
torch
.
tensor
([[
0.6
,
0.4
]])
sparse_result
.
vals
=
torch
.
tensor
([[
0.6
,
0.4
]])
mock_topk
.
return_value
=
sparse_result
mock_topk
.
return_value
=
sparse_result
mock_make_routing
.
return_value
=
(
mock_make_routing
.
return_value
=
(
mock_routing_data
,
mock_routing_data
,
mock_gather
,
mock_gather
,
mock_scatter
,
mock_scatter
,
)
)
else
:
mock_legacy
.
return_value
=
(
mock_routing_data
,
mock_gather
,
mock_scatter
,
)
mock_fused_experts
.
return_value
=
torch
.
zeros
((
1
,
8
),
device
=
device
)
mock_fused_experts
.
return_value
=
torch
.
zeros
((
1
,
8
),
device
=
device
)
...
@@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap:
...
@@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap:
expert_map
=
mock_expert_map
,
expert_map
=
mock_expert_map
,
)
)
# Both paths use topk + make_routing_data
mock_topk
.
assert_called_once
()
mock_make_routing
.
assert_called_once
()
if
expert_map_present
:
if
expert_map_present
:
# EP path: should use topk + make_routing_data, NOT
# legacy_routing
mock_topk
.
assert_called_once
()
mock_make_routing
.
assert_called_once
()
mock_legacy
.
assert_not_called
()
# expert_map should be None in the fused_experts call
# expert_map should be None in the fused_experts call
# (already applied)
# (already applied)
call_kwargs
=
mock_fused_experts
.
call_args
call_kwargs
=
mock_fused_experts
.
call_args
assert
call_kwargs
[
1
].
get
(
"expert_map"
)
is
None
or
(
assert
call_kwargs
[
1
].
get
(
"expert_map"
)
is
None
or
(
len
(
call_kwargs
[
0
])
>
0
len
(
call_kwargs
[
0
])
>
0
)
)
else
:
# Non-EP path: should use legacy_routing
mock_legacy
.
assert_called_once
()
mock_topk
.
assert_not_called
()
mock_make_routing
.
assert_not_called
()
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
2df2c85b
...
@@ -47,7 +47,6 @@ if has_triton_kernels():
...
@@ -47,7 +47,6 @@ if has_triton_kernels():
BIT
,
BIT
,
Bitmatrix
,
Bitmatrix
,
)
)
from
triton_kernels.topk
import
topk
try
:
try
:
from
triton_kernels.tensor
import
(
from
triton_kernels.tensor
import
(
...
@@ -89,6 +88,7 @@ def pack_bitmatrix(
...
@@ -89,6 +88,7 @@ def pack_bitmatrix(
offsets
=
offsets_m
[:,
None
]
*
n_expts_act
+
offsets_k
[
None
,
:]
offsets
=
offsets_m
[:,
None
]
*
n_expts_act
+
offsets_k
[
None
,
:]
mask
=
(
offsets_m
<
n_rows
)[:,
None
]
&
(
offsets_k
<
n_expts_act
)[
None
,
:]
mask
=
(
offsets_m
<
n_rows
)[:,
None
]
&
(
offsets_k
<
n_expts_act
)[
None
,
:]
indices
=
tl
.
load
(
topk_ids
+
offsets
,
mask
=
mask
,
other
=-
1
)
indices
=
tl
.
load
(
topk_ids
+
offsets
,
mask
=
mask
,
other
=-
1
)
valid
=
indices
>=
0
div
=
indices
//
32
div
=
indices
//
32
rem
=
indices
%
32
rem
=
indices
%
32
one
=
tl
.
cast
(
1
,
tl
.
uint32
)
one
=
tl
.
cast
(
1
,
tl
.
uint32
)
...
@@ -99,8 +99,13 @@ def pack_bitmatrix(
...
@@ -99,8 +99,13 @@ def pack_bitmatrix(
offs
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
//
32
)
+
i
*
(
BLOCK_SIZE_K
//
32
)
offs
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
//
32
)
+
i
*
(
BLOCK_SIZE_K
//
32
)
# All topks that need to go into this column has the correct bit set.
# All topks that need to go into this column has the correct bit set.
# Other bits are 0. x is a 2D tensor.
# Other bits are 0. x is a 2D tensor.
# Guard with `valid` to prevent negative indices from producing
# spurious bits (on HIP, -1 // 32 == 0 and 1 << (-1 % 32) sets
# bit 31).
x
=
tl
.
where
(
x
=
tl
.
where
(
div
[:,
:,
None
]
==
offs
[
None
,
None
,
:],
(
one
<<
rem
)[:,
:,
None
],
0
valid
[:,
:,
None
]
&
(
div
[:,
:,
None
]
==
offs
[
None
,
None
,
:]),
(
one
<<
rem
)[:,
:,
None
],
0
,
)
)
# Reduce x to get a single int32_t bitpack.
# Reduce x to get a single int32_t bitpack.
y
=
tl
.
reduce_or
(
x
,
axis
=
1
)
y
=
tl
.
reduce_or
(
x
,
axis
=
1
)
...
@@ -108,93 +113,6 @@ def pack_bitmatrix(
...
@@ -108,93 +113,6 @@ def pack_bitmatrix(
tl
.
store
(
bitmatrix_ptrs
,
y
,
mask
=
offsets_m
[:,
None
]
<
n_rows
)
tl
.
store
(
bitmatrix_ptrs
,
y
,
mask
=
offsets_m
[:,
None
]
<
n_rows
)
def
legacy_routing_from_bitmatrix
(
bitmatrix
:
"Bitmatrix"
,
expt_scal
:
torch
.
Tensor
,
expt_indx
:
torch
.
Tensor
,
n_expts_tot
:
int
,
n_expts_act
:
int
,
)
->
tuple
[
"RoutingData"
,
"GatherIndx"
,
"ScatterIndx"
]:
"""
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
Creates routing data from a bitmatrix representation.
"""
if
use_legacy_triton_kernels
:
from
triton_kernels.routing
import
routing_from_bitmatrix
return
routing_from_bitmatrix
(
bitmatrix
,
expt_scal
,
expt_indx
,
n_expts_tot
,
n_expts_act
)
sparse_logits
=
SparseMatrix
(
indx
=
expt_indx
,
vals
=
expt_scal
,
mask
=
bitmatrix
)
dispatch_indx
=
sparse_logits
.
mask_metadata
.
row_sorted_indx
combine_indx
=
sparse_logits
.
mask_metadata
.
col_sorted_indx
ragged_batch_metadata
=
make_ragged_tensor_metadata
(
sparse_logits
.
mask_metadata
.
col_sum
,
dispatch_indx
.
shape
[
0
],
)
gate_scal
=
sparse_logits
.
vals
.
flatten
()[
combine_indx
]
routing_data
=
RoutingData
(
gate_scal
,
ragged_batch_metadata
.
block_sizes
,
n_expts_tot
,
n_expts_act
,
ragged_batch_metadata
,
)
gather_idx
=
GatherIndx
(
combine_indx
,
dispatch_indx
)
scatter_idx
=
ScatterIndx
(
dispatch_indx
,
combine_indx
)
return
routing_data
,
gather_idx
,
scatter_idx
def
legacy_routing_from_sparsematrix
(
sparse_logits
:
"SparseMatrix"
,
n_expts_tot
:
int
,
n_expts_act
:
int
,
)
->
tuple
[
"RoutingData"
,
"GatherIndx"
,
"ScatterIndx"
]:
"""
Creates routing data from a SparseMatrix representation.
"""
dispatch_indx
=
sparse_logits
.
mask_metadata
.
row_sorted_indx
combine_indx
=
sparse_logits
.
mask_metadata
.
col_sorted_indx
ragged_batch_metadata
=
make_ragged_tensor_metadata
(
sparse_logits
.
mask_metadata
.
col_sum
,
dispatch_indx
.
shape
[
0
],
)
gate_scal
=
sparse_logits
.
vals
.
flatten
()[
combine_indx
]
routing_data
=
RoutingData
(
gate_scal
,
ragged_batch_metadata
.
block_sizes
,
n_expts_tot
,
n_expts_act
,
ragged_batch_metadata
,
)
gather_idx
=
GatherIndx
(
combine_indx
,
dispatch_indx
)
scatter_idx
=
ScatterIndx
(
dispatch_indx
,
combine_indx
)
return
routing_data
,
gather_idx
,
scatter_idx
def
legacy_routing
(
logits
:
torch
.
Tensor
,
n_expts_act
:
int
,
sm_first
:
bool
=
False
,
)
->
tuple
[
"RoutingData"
,
"GatherIndx"
,
"ScatterIndx"
]:
"""
Replacement for the removed triton_kernels.routing.routing function.
Computes routing data from gating logits.
"""
if
use_legacy_triton_kernels
:
from
triton_kernels.routing
import
routing
return
routing
(
logits
,
n_expts_act
,
sm_first
=
sm_first
)
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
sparse_logits
=
topk
(
logits
,
n_expts_act
,
apply_softmax
=
not
sm_first
)
return
legacy_routing_from_sparsematrix
(
sparse_logits
,
logits
.
shape
[
-
1
],
n_expts_act
,
)
def
triton_kernel_moe_forward
(
def
triton_kernel_moe_forward
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
,
# Tensor or triton_kernels.Tensor
w1
,
# Tensor or triton_kernels.Tensor
...
@@ -241,26 +159,22 @@ def triton_kernel_moe_forward(
...
@@ -241,26 +159,22 @@ def triton_kernel_moe_forward(
unpadded_K_w2
=
unpadded_K_w2
,
unpadded_K_w2
=
unpadded_K_w2
,
)
)
from
triton_kernels.topk
import
topk
as
topk_fn
sm_first
=
not
renormalize
logits
=
gating_output
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
topk_result
=
topk_fn
(
logits
,
topk
,
apply_softmax
=
not
sm_first
)
# topk may return a tuple (vals, indx, bitmatrix) or a
# SparseMatrix depending on the triton_kernels version.
if
isinstance
(
topk_result
,
tuple
):
topk_weights
,
topk_ids_raw
,
_
=
topk_result
else
:
topk_weights
=
topk_result
.
vals
topk_ids_raw
=
topk_result
.
indx
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
# With expert parallelism, legacy_routing produces routing data
# using global expert IDs which don't correspond to local weight
# indices. Split the routing into topk selection + expert_map
# remapping + local routing data construction (matching the
# approach used by OAITritonExperts.apply).
from
triton_kernels.topk
import
topk
as
topk_fn
sm_first
=
not
renormalize
logits
=
gating_output
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
topk_result
=
topk_fn
(
logits
,
topk
,
apply_softmax
=
not
sm_first
)
# topk may return a tuple (vals, indx, bitmatrix) or a
# SparseMatrix depending on the triton_kernels version.
if
isinstance
(
topk_result
,
tuple
):
topk_weights
,
topk_ids_raw
,
_
=
topk_result
else
:
topk_weights
=
topk_result
.
vals
topk_ids_raw
=
topk_result
.
indx
# topk_ids_raw contains global expert IDs - remap to local.
# topk_ids_raw contains global expert IDs - remap to local.
topk_ids
=
expert_map
[
topk_ids_raw
.
to
(
torch
.
long
)]
topk_ids
=
expert_map
[
topk_ids_raw
.
to
(
torch
.
long
)]
local_num_experts
=
w1
.
shape
[
0
]
local_num_experts
=
w1
.
shape
[
0
]
...
@@ -271,8 +185,9 @@ def triton_kernel_moe_forward(
...
@@ -271,8 +185,9 @@ def triton_kernel_moe_forward(
effective_expert_map
=
None
effective_expert_map
=
None
effective_global_num_experts
=
local_num_experts
effective_global_num_experts
=
local_num_experts
else
:
else
:
routing_data
,
gather_idx
,
scatter_idx
=
legacy_routing
(
topk_ids
=
topk_ids_raw
.
to
(
torch
.
long
)
gating_output
,
topk
,
sm_first
=
not
renormalize
routing_data
,
gather_idx
,
scatter_idx
=
make_routing_data
(
topk_ids
,
topk_weights
,
gating_output
.
shape
[
-
1
]
)
)
effective_expert_map
=
expert_map
effective_expert_map
=
expert_map
effective_global_num_experts
=
global_num_experts
effective_global_num_experts
=
global_num_experts
...
@@ -539,10 +454,31 @@ def make_routing_data(
...
@@ -539,10 +454,31 @@ def make_routing_data(
# matmul_ogs expects invalid topk_weights to be -1s
# matmul_ogs expects invalid topk_weights to be -1s
topk_weights
=
torch
.
where
(
topk_ids
==
-
1
,
-
1.0
,
topk_weights
)
topk_weights
=
torch
.
where
(
topk_ids
==
-
1
,
-
1.0
,
topk_weights
)
routing_data
,
gather_indx
,
scatter_indx
=
legacy_routing_from_bitmatrix
(
bitmatrix
,
topk_weights
,
topk_ids
,
num_local_experts
,
num_topk
)
if
use_legacy_triton_kernels
:
from
triton_kernels.routing
import
routing_from_bitmatrix
return
routing_from_bitmatrix
(
bitmatrix
,
topk_weights
,
topk_ids
,
num_local_experts
,
num_topk
)
sparse_logits
=
SparseMatrix
(
indx
=
topk_ids
,
vals
=
topk_weights
,
mask
=
bitmatrix
)
dispatch_indx
=
sparse_logits
.
mask_metadata
.
row_sorted_indx
combine_indx
=
sparse_logits
.
mask_metadata
.
col_sorted_indx
ragged_batch_metadata
=
make_ragged_tensor_metadata
(
sparse_logits
.
mask_metadata
.
col_sum
,
dispatch_indx
.
shape
[
0
],
)
gate_scal
=
sparse_logits
.
vals
.
flatten
()[
combine_indx
]
routing_data
=
RoutingData
(
gate_scal
,
ragged_batch_metadata
.
block_sizes
,
num_local_experts
,
num_topk
,
ragged_batch_metadata
,
)
gather_indx
=
GatherIndx
(
combine_indx
,
dispatch_indx
)
scatter_indx
=
ScatterIndx
(
dispatch_indx
,
combine_indx
)
return
routing_data
,
gather_indx
,
scatter_indx
return
routing_data
,
gather_indx
,
scatter_indx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment