Unverified Commit 2df2c85b authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent 62095e82
...@@ -3,4 +3,4 @@ ...@@ -3,4 +3,4 @@
model_name: openai/gpt-oss-20b model_name: openai/gpt-oss-20b
metric_threshold: 0.568 metric_threshold: 0.568
reasoning_effort: low reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN" server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --tensor-parallel-size 2"
\ No newline at end of file \ No newline at end of file
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16 model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold: 0.568 metric_threshold: 0.568
reasoning_effort: low reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter" server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter --tokenizer openai/gpt-oss-20b --tensor-parallel-size 2"
env: env:
VLLM_ROCM_USE_AITER: "1" VLLM_ROCM_USE_AITER: "1"
\ No newline at end of file
...@@ -3,4 +3,4 @@ ...@@ -3,4 +3,4 @@
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16 model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold: 0.568 metric_threshold: 0.568
reasoning_effort: low reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton" server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton --tokenizer openai/gpt-oss-20b --tensor-parallel-size 2"
\ No newline at end of file \ No newline at end of file
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
model_name: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8 model_name: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
metric_threshold: 0.568 metric_threshold: 0.568
reasoning_effort: low reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN" server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --tensor-parallel-size 2"
env: env:
VLLM_ROCM_USE_AITER: "1" VLLM_ROCM_USE_AITER: "1"
\ No newline at end of file
...@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m ...@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close from triton_kernels.testing import assert_close
from triton_kernels.topk import topk as topk_fn
from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
legacy_routing,
make_routing_data,
triton_kernel_moe_forward, triton_kernel_moe_forward,
) )
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import set_random_seed
from .utils import shuffle_weight from .utils import shuffle_weight
...@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int): ...@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
if w_dtype != "mx4": if w_dtype != "mx4":
pytest.skip("NYI") pytest.skip("NYI")
else: # quantize to mx4 else: # quantize to mx4
# careful on the padding here, the activation padding need to be # Padding alignment depends on the platform. On CDNA4 the scale
# multiple of 64, the actual engine is not implemented # swizzle requires SCALE_K % 8 == 0 (K % 256) and
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1] # SCALE_N % 32 == 0 (2*N % 512), matching the production
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2] # alignment in mxfp4_round_up_hidden_size_and_intermediate_size.
# On CUDA (Hopper) the scale layout pads internally, so the
# original 64/128 alignment is sufficient.
if current_platform.is_rocm():
k_align, n2_align = 256, 512
else:
k_align, n2_align = 64, 128
w1_bottom_pad = round_up(w1_tri.shape[1], k_align) - w1_tri.shape[1]
w1_right_pad = round_up(w1_tri.shape[2], n2_align) - w1_tri.shape[2]
w2_bottom_pad = w1_right_pad // 2 w2_bottom_pad = w1_right_pad // 2
w2_right_pad = w1_bottom_pad w2_right_pad = w1_bottom_pad
...@@ -367,52 +371,3 @@ def test_unit_shuffle(): ...@@ -367,52 +371,3 @@ def test_unit_shuffle():
) )
assert_close(ref=out_ref, tri=out) assert_close(ref=out_ref, tri=out)
@pytest.mark.parametrize("num_tokens", [2, 8, 64])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_legacy_routing(
num_tokens: int, num_experts: int, topk: int, renormalize: bool, dtype: torch.dtype
):
set_random_seed(0)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=dtype)
sm_first = not renormalize
logits = gating_output
if sm_first:
logits = torch.softmax(logits, dim=-1)
topk_result = topk_fn(logits, topk, apply_softmax=not sm_first)
# topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
if isinstance(topk_result, tuple):
topk_weights, topk_ids_raw, bitmatrix = topk_result
from triton_kernels.routing import routing_from_bitmatrix
routing_data_ref, gather_indx_ref, scatter_indx_ref = routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids_raw, num_experts, topk
)
else:
topk_ids = topk_result.indx.to(torch.long)
topk_weights = topk_result.vals
routing_data_ref, gather_indx_ref, scatter_indx_ref = make_routing_data(
topk_ids, topk_weights, num_experts
)
routing_data, gather_indx, scatter_indx = legacy_routing(
gating_output, topk, sm_first=sm_first
)
assert_close(
ref=gather_indx_ref.src_indx, tri=gather_indx.src_indx, maxtol=0, rmstol=0
)
assert_close(
ref=gather_indx_ref.dst_indx, tri=gather_indx.dst_indx, maxtol=0, rmstol=0
)
assert_close(
ref=scatter_indx_ref.src_indx, tri=scatter_indx.src_indx, maxtol=0, rmstol=0
)
assert_close(
ref=scatter_indx_ref.dst_indx, tri=scatter_indx.dst_indx, maxtol=0, rmstol=0
)
...@@ -4,12 +4,9 @@ ...@@ -4,12 +4,9 @@
Tests that triton_kernel_moe_forward correctly applies expert_map Tests that triton_kernel_moe_forward correctly applies expert_map
remapping when expert parallelism (EP) is enabled. remapping when expert parallelism (EP) is enabled.
Previously, legacy_routing was always used and it produced routing data Both EP and non-EP paths use topk + make_routing_data. When expert_map
with global expert IDs that didn't correspond to local weight indices, is provided, global expert IDs are remapped to local IDs before building
causing illegal memory access with EP. The fix splits routing: when routing structures.
expert_map is provided, topk selection is performed first, expert_map is
applied to remap global→local IDs, and make_routing_data builds routing
structures from the local IDs.
""" """
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
...@@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap: ...@@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap:
@pytest.mark.parametrize("expert_map_present", [False, True]) @pytest.mark.parametrize("expert_map_present", [False, True])
def test_routing_path_selection(self, expert_map_present): def test_routing_path_selection(self, expert_map_present):
"""Verify that the EP-aware routing path is taken when expert_map """Verify that both EP and non-EP paths use topk + make_routing_data,
is present, and the legacy_routing path is taken otherwise.""" and that expert_map remapping is applied when present."""
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
# This is a structural test: we mock the routing functions to
# verify the correct path is exercised.
mock_expert_map = ( mock_expert_map = (
torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None torch.tensor([0, -1, 1, -1], device=device) if expert_map_present else None
) )
with ( with (
patch(
"vllm.model_executor.layers.fused_moe."
"gpt_oss_triton_kernels_moe.legacy_routing"
) as mock_legacy,
patch("triton_kernels.topk.topk") as mock_topk, patch("triton_kernels.topk.topk") as mock_topk,
patch( patch(
"vllm.model_executor.layers.fused_moe." "vllm.model_executor.layers.fused_moe."
...@@ -53,12 +44,10 @@ class TestTritonMoeForwardExpertMap: ...@@ -53,12 +44,10 @@ class TestTritonMoeForwardExpertMap:
triton_kernel_moe_forward, triton_kernel_moe_forward,
) )
# Set up return values
mock_routing_data = MagicMock() mock_routing_data = MagicMock()
mock_gather = MagicMock() mock_gather = MagicMock()
mock_scatter = MagicMock() mock_scatter = MagicMock()
if expert_map_present:
sparse_result = MagicMock() sparse_result = MagicMock()
sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32) sparse_result.indx = torch.tensor([[0, 2]], dtype=torch.int32)
sparse_result.vals = torch.tensor([[0.6, 0.4]]) sparse_result.vals = torch.tensor([[0.6, 0.4]])
...@@ -68,12 +57,6 @@ class TestTritonMoeForwardExpertMap: ...@@ -68,12 +57,6 @@ class TestTritonMoeForwardExpertMap:
mock_gather, mock_gather,
mock_scatter, mock_scatter,
) )
else:
mock_legacy.return_value = (
mock_routing_data,
mock_gather,
mock_scatter,
)
mock_fused_experts.return_value = torch.zeros((1, 8), device=device) mock_fused_experts.return_value = torch.zeros((1, 8), device=device)
...@@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap: ...@@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap:
expert_map=mock_expert_map, expert_map=mock_expert_map,
) )
if expert_map_present: # Both paths use topk + make_routing_data
# EP path: should use topk + make_routing_data, NOT
# legacy_routing
mock_topk.assert_called_once() mock_topk.assert_called_once()
mock_make_routing.assert_called_once() mock_make_routing.assert_called_once()
mock_legacy.assert_not_called()
if expert_map_present:
# expert_map should be None in the fused_experts call # expert_map should be None in the fused_experts call
# (already applied) # (already applied)
call_kwargs = mock_fused_experts.call_args call_kwargs = mock_fused_experts.call_args
assert call_kwargs[1].get("expert_map") is None or ( assert call_kwargs[1].get("expert_map") is None or (
len(call_kwargs[0]) > 0 len(call_kwargs[0]) > 0
) )
else:
# Non-EP path: should use legacy_routing
mock_legacy.assert_called_once()
mock_topk.assert_not_called()
mock_make_routing.assert_not_called()
...@@ -47,7 +47,6 @@ if has_triton_kernels(): ...@@ -47,7 +47,6 @@ if has_triton_kernels():
BIT, BIT,
Bitmatrix, Bitmatrix,
) )
from triton_kernels.topk import topk
try: try:
from triton_kernels.tensor import ( from triton_kernels.tensor import (
...@@ -89,6 +88,7 @@ def pack_bitmatrix( ...@@ -89,6 +88,7 @@ def pack_bitmatrix(
offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :] offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :]
mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :] mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :]
indices = tl.load(topk_ids + offsets, mask=mask, other=-1) indices = tl.load(topk_ids + offsets, mask=mask, other=-1)
valid = indices >= 0
div = indices // 32 div = indices // 32
rem = indices % 32 rem = indices % 32
one = tl.cast(1, tl.uint32) one = tl.cast(1, tl.uint32)
...@@ -99,8 +99,13 @@ def pack_bitmatrix( ...@@ -99,8 +99,13 @@ def pack_bitmatrix(
offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32)
# All topks that need to go into this column has the correct bit set. # All topks that need to go into this column has the correct bit set.
# Other bits are 0. x is a 2D tensor. # Other bits are 0. x is a 2D tensor.
# Guard with `valid` to prevent negative indices from producing
# spurious bits (on HIP, -1 // 32 == 0 and 1 << (-1 % 32) sets
# bit 31).
x = tl.where( x = tl.where(
div[:, :, None] == offs[None, None, :], (one << rem)[:, :, None], 0 valid[:, :, None] & (div[:, :, None] == offs[None, None, :]),
(one << rem)[:, :, None],
0,
) )
# Reduce x to get a single int32_t bitpack. # Reduce x to get a single int32_t bitpack.
y = tl.reduce_or(x, axis=1) y = tl.reduce_or(x, axis=1)
...@@ -108,93 +113,6 @@ def pack_bitmatrix( ...@@ -108,93 +113,6 @@ def pack_bitmatrix(
tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows)
def legacy_routing_from_bitmatrix(
bitmatrix: "Bitmatrix",
expt_scal: torch.Tensor,
expt_indx: torch.Tensor,
n_expts_tot: int,
n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
"""
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
Creates routing data from a bitmatrix representation.
"""
if use_legacy_triton_kernels:
from triton_kernels.routing import routing_from_bitmatrix
return routing_from_bitmatrix(
bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act
)
sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(
sparse_logits.mask_metadata.col_sum,
dispatch_indx.shape[0],
)
gate_scal = sparse_logits.vals.flatten()[combine_indx]
routing_data = RoutingData(
gate_scal,
ragged_batch_metadata.block_sizes,
n_expts_tot,
n_expts_act,
ragged_batch_metadata,
)
gather_idx = GatherIndx(combine_indx, dispatch_indx)
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
return routing_data, gather_idx, scatter_idx
def legacy_routing_from_sparsematrix(
sparse_logits: "SparseMatrix",
n_expts_tot: int,
n_expts_act: int,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
"""
Creates routing data from a SparseMatrix representation.
"""
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(
sparse_logits.mask_metadata.col_sum,
dispatch_indx.shape[0],
)
gate_scal = sparse_logits.vals.flatten()[combine_indx]
routing_data = RoutingData(
gate_scal,
ragged_batch_metadata.block_sizes,
n_expts_tot,
n_expts_act,
ragged_batch_metadata,
)
gather_idx = GatherIndx(combine_indx, dispatch_indx)
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
return routing_data, gather_idx, scatter_idx
def legacy_routing(
logits: torch.Tensor,
n_expts_act: int,
sm_first: bool = False,
) -> tuple["RoutingData", "GatherIndx", "ScatterIndx"]:
"""
Replacement for the removed triton_kernels.routing.routing function.
Computes routing data from gating logits.
"""
if use_legacy_triton_kernels:
from triton_kernels.routing import routing
return routing(logits, n_expts_act, sm_first=sm_first)
if sm_first:
logits = torch.softmax(logits, dim=-1)
sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first)
return legacy_routing_from_sparsematrix(
sparse_logits,
logits.shape[-1],
n_expts_act,
)
def triton_kernel_moe_forward( def triton_kernel_moe_forward(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1, # Tensor or triton_kernels.Tensor w1, # Tensor or triton_kernels.Tensor
...@@ -241,12 +159,6 @@ def triton_kernel_moe_forward( ...@@ -241,12 +159,6 @@ def triton_kernel_moe_forward(
unpadded_K_w2=unpadded_K_w2, unpadded_K_w2=unpadded_K_w2,
) )
if expert_map is not None:
# With expert parallelism, legacy_routing produces routing data
# using global expert IDs which don't correspond to local weight
# indices. Split the routing into topk selection + expert_map
# remapping + local routing data construction (matching the
# approach used by OAITritonExperts.apply).
from triton_kernels.topk import topk as topk_fn from triton_kernels.topk import topk as topk_fn
sm_first = not renormalize sm_first = not renormalize
...@@ -261,6 +173,8 @@ def triton_kernel_moe_forward( ...@@ -261,6 +173,8 @@ def triton_kernel_moe_forward(
else: else:
topk_weights = topk_result.vals topk_weights = topk_result.vals
topk_ids_raw = topk_result.indx topk_ids_raw = topk_result.indx
if expert_map is not None:
# topk_ids_raw contains global expert IDs - remap to local. # topk_ids_raw contains global expert IDs - remap to local.
topk_ids = expert_map[topk_ids_raw.to(torch.long)] topk_ids = expert_map[topk_ids_raw.to(torch.long)]
local_num_experts = w1.shape[0] local_num_experts = w1.shape[0]
...@@ -271,8 +185,9 @@ def triton_kernel_moe_forward( ...@@ -271,8 +185,9 @@ def triton_kernel_moe_forward(
effective_expert_map = None effective_expert_map = None
effective_global_num_experts = local_num_experts effective_global_num_experts = local_num_experts
else: else:
routing_data, gather_idx, scatter_idx = legacy_routing( topk_ids = topk_ids_raw.to(torch.long)
gating_output, topk, sm_first=not renormalize routing_data, gather_idx, scatter_idx = make_routing_data(
topk_ids, topk_weights, gating_output.shape[-1]
) )
effective_expert_map = expert_map effective_expert_map = expert_map
effective_global_num_experts = global_num_experts effective_global_num_experts = global_num_experts
...@@ -539,10 +454,31 @@ def make_routing_data( ...@@ -539,10 +454,31 @@ def make_routing_data(
# matmul_ogs expects invalid topk_weights to be -1s # matmul_ogs expects invalid topk_weights to be -1s
topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights) topk_weights = torch.where(topk_ids == -1, -1.0, topk_weights)
routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix(
if use_legacy_triton_kernels:
from triton_kernels.routing import routing_from_bitmatrix
return routing_from_bitmatrix(
bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk bitmatrix, topk_weights, topk_ids, num_local_experts, num_topk
) )
sparse_logits = SparseMatrix(indx=topk_ids, vals=topk_weights, mask=bitmatrix)
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(
sparse_logits.mask_metadata.col_sum,
dispatch_indx.shape[0],
)
gate_scal = sparse_logits.vals.flatten()[combine_indx]
routing_data = RoutingData(
gate_scal,
ragged_batch_metadata.block_sizes,
num_local_experts,
num_topk,
ragged_batch_metadata,
)
gather_indx = GatherIndx(combine_indx, dispatch_indx)
scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
return routing_data, gather_indx, scatter_indx return routing_data, gather_indx, scatter_indx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment