Unverified Commit 679c6a3e authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[Bugfix][ROCm][MoE] Fix mxfp4 oracle regressions from #37128 (#37787)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent 8bbb7c7f
...@@ -2526,6 +2526,7 @@ steps: ...@@ -2526,6 +2526,7 @@ steps:
- pytest -v -s -x lora/test_llm_with_multi_loras.py - pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py - pytest -v -s -x lora/test_olmoe_tp.py
- pytest -v -s -x lora/test_gptoss_tp.py - pytest -v -s -x lora/test_gptoss_tp.py
- pytest -v -s -x lora/test_qwen35_densemoel_lora.py
- label: Weight Loading Multiple GPU # 7.5m - label: Weight Loading Multiple GPU # 7.5m
......
...@@ -5,6 +5,7 @@ import pytest ...@@ -5,6 +5,7 @@ import pytest
import vllm import vllm
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
...@@ -69,6 +70,16 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: ...@@ -69,6 +70,16 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i]) assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason=(
"Mxfp4 LoRA on ROCm is blocked by a spawn compatibility issue. "
"The fused_moe_lora Triton kernel crashes in spawned subprocesses, "
"and vLLM forces spawn mode when HIP is initialized before "
"multiprocessing. Fixing this requires either making the LoRA "
"Triton kernel spawn-safe or pre-warming the kernel cache."
),
)
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False]) @pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
@pytest.mark.parametrize("specialize_active_lora", [True, False]) @pytest.mark.parametrize("specialize_active_lora", [True, False])
def test_gpt_oss_lora( def test_gpt_oss_lora(
......
...@@ -109,8 +109,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ...@@ -109,8 +109,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
else: # fall back to the default config else: # fall back to the default config
get_config_func = functools.partial( get_config_func = functools.partial(
try_get_optimal_moe_lora_config, try_get_optimal_moe_lora_config,
w1_shape=layer.w13_weight.size(), w1_shape=layer.w13_weight.shape,
w2_shape=layer.w2_weight.size(), w2_shape=layer.w2_weight.shape,
rank=rank, rank=rank,
top_k=top_k, top_k=top_k,
dtype=config_dtype, dtype=config_dtype,
......
...@@ -379,7 +379,11 @@ def _fused_moe_lora_kernel( ...@@ -379,7 +379,11 @@ def _fused_moe_lora_kernel(
) )
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
accumulator += tl.dot(a, b) # Cast operands to matching dtype for tl.dot. On ROCm, Triton's
# compiler may infer different types for a and b when merging
# if/else branches (TMA desc path returns fp32, tl.load returns
# the pointer's element type).
accumulator += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16))
if MUL_ROUTED_WEIGHT: if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
......
...@@ -229,6 +229,9 @@ class FusedMoEQuantConfig: ...@@ -229,6 +229,9 @@ class FusedMoEQuantConfig:
_w1: FusedMoEQuantDesc _w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc _w2: FusedMoEQuantDesc
is_nvfp4_scale_swizzled: bool = True is_nvfp4_scale_swizzled: bool = True
# CK MXFP4 (gfx950) padding info for rocm_aiter_ops.fused_moe()
hidden_pad: int = 0
intermediate_pad: int = 0
def __post_init__(self): def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, ( assert not self.per_act_token_quant or self.block_shape is None, (
......
...@@ -257,7 +257,7 @@ def triton_kernel_moe_forward( ...@@ -257,7 +257,7 @@ def triton_kernel_moe_forward(
# sparse_logits.indx contains global expert IDs – remap to local. # sparse_logits.indx contains global expert IDs – remap to local.
topk_ids = expert_map[sparse_logits.indx.to(torch.long)] topk_ids = expert_map[sparse_logits.indx.to(torch.long)]
topk_weights = sparse_logits.vals topk_weights = sparse_logits.vals
local_num_experts = w1.size(0) local_num_experts = w1.shape[0]
routing_data, gather_idx, scatter_idx = make_routing_data( routing_data, gather_idx, scatter_idx = make_routing_data(
topk_ids, topk_weights, local_num_experts topk_ids, topk_weights, local_num_experts
) )
...@@ -604,8 +604,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular): ...@@ -604,8 +604,8 @@ class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
require a specialized implementation, like MarlinExperts, they are free require a specialized implementation, like MarlinExperts, they are free
to override this function. to override this function.
""" """
assert w1.dim() == 3 and w2.dim() == 3 assert len(w1.shape) == 3 and len(w2.shape) == 3
E, _, N = w1.size() E, _, N = w1.shape
K = a1.size(-1) K = a1.size(-1)
assert a1.dim() == 2 assert a1.dim() == 2
...@@ -683,7 +683,7 @@ class OAITritonExperts(BaseOAITritonExperts): ...@@ -683,7 +683,7 @@ class OAITritonExperts(BaseOAITritonExperts):
if expert_map is not None: if expert_map is not None:
topk_ids = expert_map[topk_ids] topk_ids = expert_map[topk_ids]
local_num_experts = w1.size(0) local_num_experts = w1.shape[0]
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
...@@ -781,7 +781,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts): ...@@ -781,7 +781,7 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
if expert_map is not None: if expert_map is not None:
topk_ids = expert_map[topk_ids] topk_ids = expert_map[topk_ids]
local_num_experts = w1.size(0) local_num_experts = w1.shape[0]
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
......
...@@ -567,6 +567,13 @@ class FusedMoE(CustomOp): ...@@ -567,6 +567,13 @@ class FusedMoE(CustomOp):
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
self.quant_method: FusedMoEMethodBase = _get_quant_method() self.quant_method: FusedMoEMethodBase = _get_quant_method()
# Quant methods (e.g. Mxfp4MoEMethod) may round up hidden_dim
# and intermediate_size in moe_config during __init__. Sync
# self.hidden_size so downstream consumers (e.g. LoRA) see the
# padded value.
if self.moe_config.hidden_dim != self.hidden_size:
self.hidden_size = self.moe_config.hidden_dim
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike(): if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
raise NotImplementedError( raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA and ROCm for now" "is_act_and_mul=False is supported only for CUDA and ROCm for now"
...@@ -586,7 +593,7 @@ class FusedMoE(CustomOp): ...@@ -586,7 +593,7 @@ class FusedMoE(CustomOp):
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
"hidden_size": hidden_size, "hidden_size": self.hidden_size,
"unpadded_hidden_size": unpadded_hidden_size, "unpadded_hidden_size": unpadded_hidden_size,
"intermediate_size_per_partition": self.intermediate_size_per_partition, "intermediate_size_per_partition": self.intermediate_size_per_partition,
"params_dtype": params_dtype, "params_dtype": params_dtype,
......
...@@ -768,8 +768,8 @@ class FusedMoEExpertsModular(FusedMoEExperts): ...@@ -768,8 +768,8 @@ class FusedMoEExpertsModular(FusedMoEExperts):
require a specialized implementation, like MarlinExperts, they are free require a specialized implementation, like MarlinExperts, they are free
to override this function. to override this function.
""" """
assert w1.dim() == 3 and w2.dim() == 3 assert len(w1.shape) == 3 and len(w2.shape) == 3
E, N, _ = w1.size() E, N, _ = w1.shape
K = a1.size(-1) K = a1.size(-1)
if a1.dim() == 2: if a1.dim() == 2:
...@@ -1349,7 +1349,7 @@ class FusedMoEKernelModularImpl: ...@@ -1349,7 +1349,7 @@ class FusedMoEKernelModularImpl:
else: else:
output = torch.empty_like(hidden_states) output = torch.empty_like(hidden_states)
local_num_experts = w1.size(0) local_num_experts = w1.shape[0]
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
......
...@@ -212,7 +212,11 @@ def select_mxfp4_moe_backend( ...@@ -212,7 +212,11 @@ def select_mxfp4_moe_backend(
# LoRA: separate experts backend path # LoRA: separate experts backend path
if config.is_lora_enabled: if config.is_lora_enabled:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
raise NotImplementedError("Mxfp4 LoRA only supported on CUDA Platform.") # ROCm: Triton mxfp4 LoRA hits GPU memory faults due to
# triton_kernels.tensor.Tensor / HIP read-only page issues
# during weight swizzle and LoRA forward. Needs work from
# the triton_kernels/aiter side.
raise NotImplementedError("Mxfp4 LoRA is currently only supported on CUDA.")
if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported: if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported:
logger.info_once("Using Triton backend for mxfp4 lora") logger.info_once("Using Triton backend for mxfp4 lora")
return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls( return Mxfp4MoeBackend.TRITON_UNFUSED, backend_to_kernel_cls(
...@@ -775,6 +779,8 @@ def make_mxfp4_moe_quant_config( ...@@ -775,6 +779,8 @@ def make_mxfp4_moe_quant_config(
w2_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"],
w1_bias: torch.Tensor | None = None, w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
"""Create a FusedMoEQuantConfig for the given MXFP4 backend.""" """Create a FusedMoEQuantConfig for the given MXFP4 backend."""
if mxfp4_backend in ( if mxfp4_backend in (
...@@ -796,12 +802,16 @@ def make_mxfp4_moe_quant_config( ...@@ -796,12 +802,16 @@ def make_mxfp4_moe_quant_config(
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.CK, Mxfp4MoeBackend.CK,
): ):
return mxfp4_w4a16_moe_quant_config( config = mxfp4_w4a16_moe_quant_config(
w1_bias=w1_bias, w1_bias=w1_bias,
w2_bias=w2_bias, w2_bias=w2_bias,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
) )
if mxfp4_backend == Mxfp4MoeBackend.CK:
config.hidden_pad = hidden_pad
config.intermediate_pad = intermediate_pad
return config
else: else:
return ocp_mx_moe_quant_config( return ocp_mx_moe_quant_config(
quant_dtype="mxfp4", quant_dtype="mxfp4",
......
...@@ -292,6 +292,8 @@ def rocm_aiter_fused_experts( ...@@ -292,6 +292,8 @@ def rocm_aiter_fused_experts(
doweight_stage1=apply_router_weight_on_input, doweight_stage1=apply_router_weight_on_input,
num_local_tokens=num_local_tokens, num_local_tokens=num_local_tokens,
output_dtype=output_dtype, output_dtype=output_dtype,
hidden_pad=quant_config.hidden_pad,
intermediate_pad=quant_config.intermediate_pad,
bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None, bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None, bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
) )
...@@ -332,7 +334,15 @@ class AiterExperts(mk.FusedMoEExpertsModular): ...@@ -332,7 +334,15 @@ class AiterExperts(mk.FusedMoEExpertsModular):
(kFp8StaticChannelSym, kFp8DynamicTokenSym), (kFp8StaticChannelSym, kFp8DynamicTokenSym),
(kMxfp4Static, None), (kMxfp4Static, None),
] ]
return (weight_key, activation_key) in SUPPORTED_W_A if (weight_key, activation_key) not in SUPPORTED_W_A:
return False
# CK MXFP4 MoE kernels are only supported on gfx950.
if weight_key == kMxfp4Static:
from vllm.platforms.rocm import on_gfx950
if not on_gfx950():
return False
return True
@staticmethod @staticmethod
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
......
...@@ -158,6 +158,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -158,6 +158,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad intermediate_size_per_partition_after_pad
) )
# CK (gfx950) padding info for rocm_aiter_ops.fused_moe()
self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
self.intermediate_pad = (
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
)
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.zeros( torch.zeros(
...@@ -362,6 +368,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -362,6 +368,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w2_scale=w2_scale, w2_scale=w2_scale,
w1_bias=w1_bias, w1_bias=w1_bias,
w2_bias=w2_bias, w2_bias=w2_bias,
hidden_pad=self.hidden_pad,
intermediate_pad=self.intermediate_pad,
) )
def select_gemm_impl( def select_gemm_impl(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment