Unverified Commit 103f0de5 authored by Bowen Bao, committed by GitHub
Browse files

[ROCm][Quantization][1/N] Refactor quark_moe w_mxfp4 w/ oracle (#38774)


Signed-off-by: Bowen Bao <bowenbao@amd.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 32e0c0bf
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend aiter"
env:
  VLLM_ROCM_USE_AITER: "1"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold: 0.568
reasoning_effort: low
server_args: "--attention-backend ROCM_AITER_UNIFIED_ATTN --moe-backend triton"
\ No newline at end of file
# GFX950 model configurations for GPQA evaluation # GFX950 model configurations for GPQA evaluation
# Tests different environment variable combinations # Tests different environment variable combinations
gpt-oss-20b-rocm-baseline.yaml gpt-oss-20b-rocm-baseline.yaml
gpt-oss-20b-rocm-mxfp4-fp8.yaml gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
\ No newline at end of file gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml
...@@ -54,8 +54,8 @@ class Mxfp4MoeBackend(Enum): ...@@ -54,8 +54,8 @@ class Mxfp4MoeBackend(Enum):
# Marlin # Marlin
BATCHED_MARLIN = "BATCHED_MARLIN" BATCHED_MARLIN = "BATCHED_MARLIN"
MARLIN = "MARLIN" MARLIN = "MARLIN"
# ROCm AITER (CK) # ROCm AITER
CK = "CK" AITER = "AITER"
# Triton # Triton
TRITON = "TRITON" TRITON = "TRITON"
TRITON_UNFUSED = "TRITON_UNFUSED" TRITON_UNFUSED = "TRITON_UNFUSED"
...@@ -130,7 +130,7 @@ def backend_to_kernel_cls( ...@@ -130,7 +130,7 @@ def backend_to_kernel_cls(
return [BatchedMarlinExperts] return [BatchedMarlinExperts]
elif backend == Mxfp4MoeBackend.CK: elif backend == Mxfp4MoeBackend.AITER:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts, AiterExperts,
) )
...@@ -155,7 +155,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend: ...@@ -155,7 +155,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
"flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, "flashinfer_cutlass_afp8": Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
"triton": Mxfp4MoeBackend.TRITON, "triton": Mxfp4MoeBackend.TRITON,
"marlin": Mxfp4MoeBackend.MARLIN, "marlin": Mxfp4MoeBackend.MARLIN,
"ck": Mxfp4MoeBackend.CK, "aiter": Mxfp4MoeBackend.AITER,
"xpu": Mxfp4MoeBackend.XPU, "xpu": Mxfp4MoeBackend.XPU,
} }
if backend := mapping.get(runner_backend): if backend := mapping.get(runner_backend):
...@@ -173,7 +173,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: ...@@ -173,7 +173,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
""" """
_AVAILABLE_BACKENDS = [ _AVAILABLE_BACKENDS = [
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.CK, Mxfp4MoeBackend.AITER,
Mxfp4MoeBackend.TRITON, Mxfp4MoeBackend.TRITON,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.TRITON_UNFUSED, Mxfp4MoeBackend.TRITON_UNFUSED,
...@@ -656,7 +656,7 @@ def convert_to_mxfp4_moe_kernel_format( ...@@ -656,7 +656,7 @@ def convert_to_mxfp4_moe_kernel_format(
w2_bias, w2_bias,
) )
elif mxfp4_backend == Mxfp4MoeBackend.CK: elif mxfp4_backend == Mxfp4MoeBackend.AITER:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
if w13_bias is not None: if w13_bias is not None:
...@@ -794,7 +794,7 @@ def make_mxfp4_moe_quant_config( ...@@ -794,7 +794,7 @@ def make_mxfp4_moe_quant_config(
Mxfp4MoeBackend.TRITON_UNFUSED, Mxfp4MoeBackend.TRITON_UNFUSED,
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16, Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_BF16,
Mxfp4MoeBackend.CK, Mxfp4MoeBackend.AITER,
): ):
return mxfp4_w4a16_moe_quant_config( return mxfp4_w4a16_moe_quant_config(
w1_bias=w1_bias, w1_bias=w1_bias,
......
...@@ -5,6 +5,7 @@ from typing import Any ...@@ -5,6 +5,7 @@ from typing import Any
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
...@@ -27,7 +28,11 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -27,7 +28,11 @@ from vllm.model_executor.layers.fused_moe.config import (
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS,
Mxfp4MoeBackend, Mxfp4MoeBackend,
convert_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
mxfp4_round_up_hidden_size_and_intermediate_size, mxfp4_round_up_hidden_size_and_intermediate_size,
select_mxfp4_moe_backend, select_mxfp4_moe_backend,
) )
...@@ -47,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -47,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize, per_tensor_dequantize,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
...@@ -699,9 +704,16 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -699,9 +704,16 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
f"Please check that the combination is supported in OCP_MX_Scheme." f"Please check that the combination is supported in OCP_MX_Scheme."
) )
self.mxfp4_backend: Mxfp4MoeBackend | None = None self.mxfp4_backend: Mxfp4MoeBackend = Mxfp4MoeBackend.NONE
self.experts_cls: type[mk.FusedMoEExperts] | None = None
self.moe_kernel: mk.FusedMoEKernel | None = None
# Used for triton kernel precision configs
self.w13_precision_config = None
self.w2_precision_config = None
if self.ocp_mx_scheme == "w_mxfp4": if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend, _ = select_mxfp4_moe_backend(moe) self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe)
elif self.ocp_mx_scheme.startswith("w_mxfp4"): elif self.ocp_mx_scheme.startswith("w_mxfp4"):
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes. # TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
self.mxfp4_backend = Mxfp4MoeBackend.NONE self.mxfp4_backend = Mxfp4MoeBackend.NONE
...@@ -738,9 +750,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -738,9 +750,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
not current_platform.supports_mx() not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4") or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and ( ) and (
self.mxfp4_backend is None self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
or self.mxfp4_backend is Mxfp4MoeBackend.NONE
or not self.use_rocm_aiter_moe
) )
if self.emulate: if self.emulate:
...@@ -944,11 +954,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -944,11 +954,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w2_input_scale, requires_grad=False w2_input_scale, requires_grad=False
) )
# secondly, process mxfp weights # For w_mxfp4, use oracle functions
if (
self.ocp_mx_scheme == "w_mxfp4"
and self.mxfp4_backend != Mxfp4MoeBackend.NONE
):
self._setup_kernel_via_oracle(layer)
return
# TODO(bowenbao): gradually migrate to oracles.
# secondly, process mxfp weights for other schemes
if self.emulate: if self.emulate:
# Build quant config for emulation path
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
return return
# Existing AITER path for w_mxfp4_a_mxfp4 and other schemes
from aiter.utility.fp4_utils import e8m0_shuffle from aiter.utility.fp4_utils import e8m0_shuffle
# Pre-shuffle weight scales # Pre-shuffle weight scales
...@@ -980,11 +1002,87 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -980,11 +1002,87 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
layer.w13_weight.is_shuffled = True layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True layer.w2_weight.is_shuffled = True
# Build quant config for AITER path
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
torch.accelerator.empty_cache() torch.accelerator.empty_cache()
def _setup_kernel_via_oracle(self, layer: FusedMoE):
    """Convert the layer's w_mxfp4 weights and build the oracle MoE kernel.

    Passes the layer's weights, weight scales and optional biases through
    ``convert_to_mxfp4_moe_kernel_format`` for ``self.mxfp4_backend``,
    stores the converted values back, refreshes ``self.moe_quant_config``
    and, when both a quant config and an experts class are available,
    creates ``self.moe_kernel``.

    Args:
        layer: The fused MoE layer whose ``w13_weight``, ``w2_weight``,
            ``w13_weight_scale`` and ``w2_weight_scale`` attributes (and
            optional ``w13_bias`` / ``w2_bias``) are converted in place.
    """
    converted = convert_to_mxfp4_moe_kernel_format(
        mxfp4_backend=self.mxfp4_backend,
        layer=layer,
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        w13_weight_scale=layer.w13_weight_scale,
        w2_weight_scale=layer.w2_weight_scale,
        w13_bias=getattr(layer, "w13_bias", None),
        w2_bias=getattr(layer, "w2_bias", None),
    )
    (
        kernel_w13,
        kernel_w2,
        kernel_w13_scale,
        kernel_w2_scale,
        kernel_w13_bias,
        kernel_w2_bias,
    ) = converted

    if self.mxfp4_backend in TRITON_BACKENDS:
        # For TRITON backends the converted weights are wrapped tensors
        # from triton_kernels that don't support .detach(), so they are
        # assigned directly instead of going through replace_parameter.
        layer.w13_weight = kernel_w13
        layer.w2_weight = kernel_w2
        # The converted scales are kept on the method object as the
        # triton kernel precision configs rather than on the layer.
        self.w13_precision_config = kernel_w13_scale
        self.w2_precision_config = kernel_w2_scale
    else:
        for param_name, new_value in (
            ("w13_weight", kernel_w13),
            ("w2_weight", kernel_w2),
            ("w13_weight_scale", kernel_w13_scale),
            ("w2_weight_scale", kernel_w2_scale),
        ):
            replace_parameter(layer, param_name, new_value)

    # Biases are only written back when both of them are present.
    if not (kernel_w13_bias is None or kernel_w2_bias is None):
        replace_parameter(layer, "w13_bias", kernel_w13_bias)
        replace_parameter(layer, "w2_bias", kernel_w2_bias)

    # The quant config is built after the converted tensors are in place.
    self.moe_quant_config = self.get_fused_moe_quant_config(layer)
    if self.moe_quant_config is None or self.experts_cls is None:
        # Without both a quant config and an experts class no kernel is
        # built; self.moe_kernel keeps whatever value it already had.
        return

    self.moe_kernel = make_mxfp4_moe_kernel(
        moe_quant_config=self.moe_quant_config,
        moe_config=self.moe,
        mxfp4_backend=self.mxfp4_backend,
        experts_cls=self.experts_cls,
        routing_tables=layer._maybe_init_expert_routing_tables(),
        shared_experts=layer.shared_experts,
    )
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
# For w_mxfp4 with oracle backend, use oracle function
if (
self.ocp_mx_scheme == "w_mxfp4"
and self.mxfp4_backend != Mxfp4MoeBackend.NONE
):
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
if self.mxfp4_backend in TRITON_BACKENDS:
w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config
return make_mxfp4_moe_quant_config(
mxfp4_backend=self.mxfp4_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
# Existing code for other schemes
# TODO(bowenbao): kept for emulation fallback, to be refactored into
# dedicated emulation backend.
if self.ocp_mx_scheme == "w_mxfp4": if self.ocp_mx_scheme == "w_mxfp4":
return mxfp4_w4a16_moe_quant_config( return mxfp4_w4a16_moe_quant_config(
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
...@@ -1020,6 +1118,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -1020,6 +1118,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
block_shape=None, block_shape=None,
) )
@property
def is_monolithic(self) -> bool:
    """Report the ``is_monolithic`` flag of the oracle-built kernel.

    Returns:
        ``self.moe_kernel.is_monolithic`` when a kernel has been built,
        otherwise ``False``.
    """
    kernel = self.moe_kernel
    return False if kernel is None else kernel.is_monolithic
def apply( def apply(
self, self,
layer: FusedMoE, layer: FusedMoE,
...@@ -1028,6 +1132,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -1028,6 +1132,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
# For w_mxfp4 with oracle kernel
if self.moe_kernel is not None:
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
shared_experts_input=shared_experts_input,
)
# Existing code for emulation/AITER paths
if not self.emulate: if not self.emulate:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts, rocm_aiter_fused_experts,
...@@ -1061,6 +1181,25 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -1061,6 +1181,25 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
def apply_monolithic(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    router_logits: torch.Tensor,
) -> torch.Tensor:
    """Run the MoE layer through the kernel's monolithic entry point.

    Args:
        layer: Fused MoE layer supplying the weights (``w13_weight``,
            ``w2_weight``) and the activation / expert settings that are
            forwarded to the kernel.
        x: Hidden states, forwarded as ``hidden_states``.
        router_logits: Router logits, forwarded unchanged.

    Returns:
        Whatever ``self.moe_kernel.apply_monolithic`` returns.

    Raises:
        AssertionError: If ``self.is_monolithic`` is falsy or no kernel
            has been built.
    """
    assert self.is_monolithic
    assert self.moe_kernel is not None
    kernel = self.moe_kernel

    # Settings taken from the layer, gathered in the same order in which
    # they are handed to the kernel.
    layer_kwargs = {
        "activation": layer.activation,
        "global_num_experts": layer.global_num_experts,
        "expert_map": layer.expert_map,
        "apply_router_weight_on_input": layer.apply_router_weight_on_input,
    }
    return kernel.apply_monolithic(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        router_logits=router_logits,
        **layer_kwargs,
    )
class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod): class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment