Unverified Commit d622e27d authored by fxmarty-amd's avatar fxmarty-amd Committed by GitHub
Browse files

[NVFP4] NVFP4 MOE emulation fallback for H100/MI300/MI350, standardize...


[NVFP4] NVFP4 MOE emulation fallback for H100/MI300/MI350, standardize `TritonExperts` usage for OCP MX emulation (#35737)
Signed-off-by: default avatarFelix Marty <Felix.Marty@amd.com>
Signed-off-by: default avatarfxmarty-amd <felmarty@amd.com>
Co-authored-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent 5f76b3fb
......@@ -2,3 +2,5 @@ DeepSeek-R1-TP_MI325.yaml
DeepSeek-R1-DP_MI325.yaml
DeepSeek-V3.2-TP_MI325.yaml
DeepSeek-V3.2-DP_MI325.yaml
Qwen3-30B-A3B-NVFP4.yaml
Qwen3.5-35B-A3B-MXFP4-TP2.yaml
\ No newline at end of file
......@@ -120,3 +120,26 @@ def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch):
with vllm_runner(model, enforce_eager=eager) as llm:
output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)
assert output[0][1] == "1 2 3 4 5 6"
@pytest.mark.parametrize(
    "model",
    [
        "nvidia/Qwen3-30B-A3B-NVFP4",
        "RedHatAI/Qwen3-30B-A3B-NVFP4",
    ],
)
@pytest.mark.parametrize("backend", ["emulation"])
@pytest.mark.skipif(
    not current_platform.is_rocm(),
    reason="NVFP4 MOE emulation is only useful on AMD Instinct MI3xx",
)
def test_nvfp4_moe(vllm_runner, model, backend, monkeypatch):
    """Smoke-test the NVFP4 MoE emulation backend on ROCm.

    The model is truncated to two hidden layers and loaded with
    `load_format="dummy"`; the generated text is intentionally not asserted,
    the test only checks that generation runs end to end.
    """
    # Route the (non-MoE) NVFP4 GEMMs through the same backend as the MoE.
    monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend)

    engine_options = {
        "moe_backend": backend,
        "load_format": "dummy",
        "hf_overrides": {"num_hidden_layers": 2},
    }
    prompts = ["1 2 3 4 5"]

    with vllm_runner(model, **engine_options) as llm:
        llm.generate_greedy(prompts, max_tokens=2)
......@@ -115,6 +115,7 @@ MoEBackend = Literal[
"flashinfer_cutedsl",
"marlin",
"aiter",
"emulation",
]
......@@ -142,7 +143,10 @@ class KernelConfig:
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)
- "marlin": Use Marlin kernels (weight-only quantization)
- "aiter": Use AMD AITer kernels (ROCm only)"""
- "aiter": Use AMD AITer kernels (ROCm only)
- "emulation": Use BF16/FP16 GEMM, dequantizing weights and
running quantize-dequantize (QDQ) on activations.
"""
@field_validator("moe_backend", mode="before")
@classmethod
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
NVFP4 quantization emulation for MoE.
This file implements NVFP4 emulation for NVFP4 MOE in case the hardware used does not
natively support NVFP4 MOE.
Weights are dequantized on the fly during each forward pass. We fall back to
calling `TritonExperts` using BF16, and a fake NVFP4 quantize-dequantize (QDQ)
is applied to the activations `a13` and `a2`.
"""
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
dequantize_to_dtype,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
logger = init_logger(__name__)
class Nvfp4QuantizationEmulationTritonExperts(TritonExperts):
    """
    TritonExperts variant that emulates NVFP4 MoE experts.

    Intended for NVFP4 checkpoints running on devices without native
    support for this dtype: the packed weights are dequantized on every
    call and the regular Triton experts path is then used.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)

        logger.warning_once(
            "Using Nvfp4QuantizationEmulationTritonExperts MOE backend. This will"
            " dequantize weights on the fly and may be slower than native"
            " quantized MOE. Consider using a device with native quantization"
            " support (e.g. Nvidia Blackwell) for better performance."
        )

        # `TritonExperts.apply` is handed already-dequantized weights (see
        # `apply` below), so the block scales are stashed on this instance
        # and cleared on the quant config.
        config = self.quant_config
        self.w1_scale_val = config.w1_scale
        self.w2_scale_val = config.w2_scale
        config._w1.scale = None
        config._w2.scale = None

        self.quantization_emulation = True

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        """Activation quantization dtype; always `"nvfp4"` for this class."""
        return "nvfp4"

    @property
    def expects_unquantized_inputs(self) -> bool:
        """Always True: the activation QDQ is performed inside `apply`."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only static NVFP4 weights paired with dynamic NVFP4 activations."""
        expected_scheme = (kNvfp4Static, kNvfp4Dynamic)
        return (weight_key, activation_key) == expected_scheme

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """
        Run the emulated NVFP4 MoE computation.

        Both expert weights are dequantized to `hidden_states.dtype`, the
        input activations go through an NVFP4 quantize-dequantize round trip,
        and the result is forwarded to `TritonExperts.apply`. The incoming
        `a1q_scale` and `a2_scale` arguments are not used: `None` and the
        config's `a2_gscale` are passed on instead.
        """
        # NVFP4 weights arrive packed, two values per uint8:
        #   w1: [num_experts, 2*intermediate_size, hidden_size//2]
        #   w2: [num_experts, hidden_size, intermediate_size//2]
        assert w1.dtype == torch.uint8
        assert w2.dtype == torch.uint8

        compute_dtype = hidden_states.dtype

        def unpack(packed, block_scales, global_scale):
            # Packed NVFP4 -> fp16/bf16, with non-swizzled block scales.
            return dequantize_to_dtype(
                tensor_fp4=packed,
                tensor_sf=block_scales,
                global_scale=global_scale,
                dtype=compute_dtype,
                block_size=16,
                swizzle=False,
            )

        w1_dequant = unpack(w1, self.w1_scale_val, self.quant_config.g1_alphas)
        w2_dequant = unpack(w2, self.w2_scale_val, self.quant_config.g2_alphas)

        # Fake-quantize the first-GEMM input here; the scale returned by the
        # helper is discarded.
        hidden_states, _ = moe_kernel_quantize_input(
            A=hidden_states,
            A_scale=self.quant_config.a1_gscale,
            quant_dtype="nvfp4",
            per_act_token_quant=False,
            quantization_emulation=True,
        )

        # The second-GEMM input is quantized/dequantized later, through
        # `moe_kernel_quantize_input` inside `TritonExperts.apply`.
        super().apply(
            output=output,
            hidden_states=hidden_states,
            w1=w1_dequant,
            w2=w2_dequant,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            a1q_scale=None,
            a2_scale=self.quant_config.a2_gscale,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
OCP MX quantization emulation for MoE.
This file implements OCP MX (MXFP4/MXFP6) emulation for MoE in case the
hardware used does not natively support OCP MX MoE.
Weights are dequantized on the fly during each forward pass. We fall back to
calling `TritonExperts` using BF16, and a fake OCP MX quantize-dequantize (QDQ)
is applied to the activations via `moe_kernel_quantize_input`.
"""
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_Scheme,
)
logger = init_logger(__name__)
class OCP_MXQuantizationEmulationTritonExperts(TritonExperts):
    """
    Extension of TritonExperts to support emulated OCP MX MoE experts.

    It may be used for OCP MX (MXFP4/MXFP6) models when the device does not
    have native support for these dtypes. Weights are dequantized on every
    call to `apply`, and activations go through quantize-dequantize (QDQ).
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(moe_config, quant_config)
        logger.warning_once(
            "Using OCP_MXQuantizationEmulationTritonExperts MOE backend. This"
            " will dequantize weights on the fly and may be slower than native"
            " quantized MOE. Consider using a device with native OCP MX"
            " quantization support for better performance."
        )

        self.ocp_mx_scheme = quant_config.ocp_mx_scheme
        assert self.ocp_mx_scheme is not None, (
            "ocp_mx_scheme must be set in quant_config for"
            " OCP_MXQuantizationEmulationTritonExperts"
        )

        # `TritonExperts.apply` expects pre-dequantized weights,
        # which we handle in `apply` below: keep the weight scales on this
        # instance and clear them on the quant config.
        self.w1_scale_val = self.quant_config.w1_scale
        self.w2_scale_val = self.quant_config.w2_scale
        self.quant_config._w1.scale = None
        self.quant_config._w2.scale = None

        self.quantization_emulation = True

        # Activation quantization dtype reported through `quant_dtype`.
        self._quant_dtype: str | None
        if self.ocp_mx_scheme in {
            OCP_MX_Scheme.w_mxfp4_a_mxfp4,
        }:
            # Weight has to be dequantized for mxfp4 emulation.
            self._quant_dtype = "mxfp4"
        elif self.ocp_mx_scheme in [
            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
            OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
            OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2,
            OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3,
        ]:
            # NOTE(review): the `moe_kernel_quantize_input` branches visible
            # in this change match "mxfp6_e3m2" / "mxfp6_e2m3", not "mxfp6" -
            # confirm that "mxfp6" is handled as intended downstream.
            self._quant_dtype = "mxfp6"
        elif self.ocp_mx_scheme in [
            OCP_MX_Scheme.w_mxfp4_a_fp8,
            OCP_MX_Scheme.w_mxfp6_e3m2_a_fp8,
        ]:
            # TODO: double check this one
            self._quant_dtype = "mxfp8"
        else:
            # Any other scheme (e.g. weight-only) does not name an activation
            # quantization dtype here. Set the attribute explicitly so that
            # `quant_dtype` returns None instead of raising AttributeError.
            # Unsupported weight formats are still rejected by
            # `_dequantize_weights`.
            self._quant_dtype = None

    @property
    def quant_dtype(self) -> torch.dtype | str | None:
        """Activation quantization dtype chosen in `__init__`, or None."""
        return self._quant_dtype

    @property
    def expects_unquantized_inputs(self) -> bool:
        """Always True: the activation QDQ is performed inside `apply`."""
        return True

    @staticmethod
    def _supports_quant_scheme(
        weight_key,
        activation_key,
    ) -> bool:
        # This class is used for emulation only - the oracle selects it
        # directly rather than via quant scheme matching.
        return True

    def _dequantize_weights(
        self,
        w: torch.Tensor,
        w_scale: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Dequantize weights based on the OCP MX scheme.

        The weight format is taken from the prefix of `self.ocp_mx_scheme`.

        Raises:
            NotImplementedError: if the scheme's weight format is not
                MXFP4, MXFP6 e3m2 or MXFP6 e2m3.
        """
        if self.ocp_mx_scheme.startswith("w_mxfp4"):  # type: ignore[union-attr]
            return dequant_mxfp4(w, w_scale, dtype)
        elif self.ocp_mx_scheme.startswith("w_mxfp6_e3m2"):  # type: ignore[union-attr]
            return dequant_mxfp6(w, w_scale, quant_dtype="fp6_e3m2", float_dtype=dtype)
        elif self.ocp_mx_scheme.startswith("w_mxfp6_e2m3"):  # type: ignore[union-attr]
            return dequant_mxfp6(w, w_scale, quant_dtype="fp6_e2m3", float_dtype=dtype)
        else:
            raise NotImplementedError(f"Unsupported ocp_mx_scheme={self.ocp_mx_scheme}")

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """
        Apply emulated quantized MoE computation.

        This dequantizes the weights on the fly and calls TritonExperts.apply
        with activation quantization support. The incoming `a1q_scale` and
        `a2_scale` arguments are not used; `None` is forwarded for both.
        """
        assert w1.dtype == torch.uint8
        assert w2.dtype == torch.uint8

        # Dequantize w1 and w2 from packed OCP MX format to bf16/fp16
        w1_dequant = self._dequantize_weights(
            w1, self.w1_scale_val, hidden_states.dtype
        )
        w2_dequant = self._dequantize_weights(
            w2, self.w2_scale_val, hidden_states.dtype
        )

        # Apply activation QDQ on the first-GEMM input if needed by the
        # OCP MX scheme; the returned scale is discarded.
        hidden_states, _ = moe_kernel_quantize_input(
            A=hidden_states,
            A_scale=None,
            quant_dtype=self.quant_config.quant_dtype,
            per_act_token_quant=False,
            ocp_mx_scheme=self.ocp_mx_scheme,
            quantization_emulation=True,
        )

        # Quantization/dequantization of the second-GEMM input is deferred to
        # `moe_kernel_quantize_input` in TritonExperts.apply.
        super().apply(
            output=output,
            hidden_states=hidden_states,
            w1=w1_dequant,
            w2=w2_dequant,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            a1q_scale=None,
            a2_scale=None,
            workspace13=workspace13,
            workspace2=workspace2,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import flashinfer
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
......@@ -188,6 +188,8 @@ class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModula
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
import flashinfer
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert a1q_scale is not None
assert self.quant_config.w1_scale is not None
......@@ -306,6 +308,8 @@ class TrtLlmNvFp4ExpertsMonolithic(
routed_scaling_factor: float | None = None,
topk_group: int | None = None,
) -> torch.Tensor:
import flashinfer
assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
assert a1q_scale is not None
assert self.quant_config.w1_scale is not None
......
......@@ -36,8 +36,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
disable_inplace,
moe_kernel_quantize_input,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
......@@ -1708,22 +1706,18 @@ def fused_experts_impl(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
if ocp_mx_scheme is not None:
raise NotImplementedError(
f"Using ocp_mx_scheme={ocp_mx_scheme} in functional fused_experts call is "
"deprecated. Please use OCP_MXQuantizationEmulationTritonExperts."
)
# Convert string activation to enum for internal use
activation_enum = MoEActivation.from_str(activation)
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif ocp_mx_scheme is not None:
if ocp_mx_scheme.startswith("w_mxfp4"):
# 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
elif ocp_mx_scheme.startswith("w_mxfp6"):
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
"hidden size mismatch"
)
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
else:
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
......@@ -1748,7 +1742,6 @@ def fused_experts_impl(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
ocp_mx_scheme=ocp_mx_scheme,
dtype=hidden_states.dtype,
)
......@@ -1757,7 +1750,7 @@ def fused_experts_impl(
quant_dtype = _get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
ocp_mx_scheme=ocp_mx_scheme,
ocp_mx_scheme=None,
)
get_config_func = functools.partial(
......@@ -1802,44 +1795,12 @@ def fused_experts_impl(
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP mx fused MOE kernel,
# this dequantization step should not be done.
if ocp_mx_scheme.startswith("w_mxfp4"):
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w2_scale = None
elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w2_scale = None
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
qhidden_states, a1q_scale = moe_kernel_quantize_input(
A=hidden_states,
A_scale=a1_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = _prepare_expert_assignment(
......@@ -1889,7 +1850,6 @@ def fused_experts_impl(
quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
)
if expert_map is not None:
......@@ -1935,6 +1895,9 @@ class TritonExperts(mk.FusedMoEExpertsModular):
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
):
# Whether quantized MOE runs natively, or through
# higher-precision + activation QDQ.
self.quantization_emulation = False
super().__init__(moe_config, quant_config)
@staticmethod
......@@ -2144,6 +2107,7 @@ class TritonExperts(mk.FusedMoEExpertsModular):
self.quant_dtype,
self.per_act_token_quant,
self.block_shape,
quantization_emulation=self.quantization_emulation,
)
invoke_fused_moe_triton_kernel(
......
......@@ -62,6 +62,8 @@ class Mxfp4MoeBackend(Enum):
TRITON_UNFUSED = "TRITON_UNFUSED"
# XPU
XPU = "XPU"
# Emulation
EMULATION = "EMULATION"
# Backends that share the same TRTLLM weight format
......@@ -143,6 +145,13 @@ def backend_to_kernel_cls(
return [XPUExpertsMXFp4]
elif backend == Mxfp4MoeBackend.EMULATION:
from vllm.model_executor.layers.fused_moe.experts.ocp_mx_emulation_moe import (
OCP_MXQuantizationEmulationTritonExperts,
)
return [OCP_MXQuantizationEmulationTritonExperts]
else:
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")
......@@ -158,6 +167,7 @@ def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
"marlin": Mxfp4MoeBackend.MARLIN,
"aiter": Mxfp4MoeBackend.AITER,
"xpu": Mxfp4MoeBackend.XPU,
"emulation": Mxfp4MoeBackend.EMULATION,
}
if backend := mapping.get(runner_backend):
return backend
......@@ -181,6 +191,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
Mxfp4MoeBackend.XPU,
Mxfp4MoeBackend.EMULATION,
]
return _AVAILABLE_BACKENDS
......@@ -768,6 +779,17 @@ def convert_gpt_oss_weight_to_mxfp4_moe_kernel_format(
w13_bias,
w2_bias,
)
elif mxfp4_backend == Mxfp4MoeBackend.EMULATION:
# No additional transformation needed for emulation backend,
# weights are dequantized on the fly in the experts class.
return (
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
else:
raise ValueError(
f"Unsupported mxfp4_backend: {mxfp4_backend}: "
......
......@@ -45,6 +45,7 @@ class NvFp4MoeBackend(Enum):
FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"
VLLM_CUTLASS = "VLLM_CUTLASS"
MARLIN = "MARLIN"
EMULATION = "EMULATION"
FLASHINFER_NVFP4_MOE_BACKENDS = [
......@@ -118,6 +119,12 @@ def backend_to_kernel_cls(
)
return [MarlinExperts]
elif backend == NvFp4MoeBackend.EMULATION:
from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import (
Nvfp4QuantizationEmulationTritonExperts,
)
return [Nvfp4QuantizationEmulationTritonExperts]
else:
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
......@@ -130,6 +137,7 @@ def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
"flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
"flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
"marlin": NvFp4MoeBackend.MARLIN,
"emulation": NvFp4MoeBackend.EMULATION,
}
if backend := mapping.get(runner_backend):
return backend
......@@ -157,6 +165,7 @@ def select_nvfp4_moe_backend(
NvFp4MoeBackend.FLASHINFER_CUTLASS,
NvFp4MoeBackend.VLLM_CUTLASS,
NvFp4MoeBackend.MARLIN,
NvFp4MoeBackend.EMULATION,
]
# NOTE(rob): this is kind of a hack. We need to peek into
......@@ -372,6 +381,30 @@ def convert_to_nvfp4_moe_kernel_format(
w2_scale_2=w2_scale_2,
is_act_and_mul=is_act_and_mul,
)
elif nvfp4_backend == NvFp4MoeBackend.EMULATION:
if a13_scale is None or a2_scale is None:
raise ValueError(
"Activation global scales should not be None, got"
f" a13_scale={a13_scale}, a2_scale={a2_scale}"
)
if torch.unique(a13_scale).numel() != 1 or torch.unique(a2_scale).numel() != 1:
logger.warning_once(
"In NVFP4 linear, the activation global scale for inputs are different"
" for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using"
" a13_scale = a13_scale.max() and a2_scale = a2_scale.max()."
)
# 1. We take the max following e.g. quantization/utils/flashinfer_fp4_moe.py.
# 2. moe_kernel_quantize_input -> ref_nvfp4_quant_dequant
# use the inverse scale directly (large global scale).
# NOTE: Before this point, `a13_scale` and `a2_scale` are such that:
# `FP8_MAX = activation[expert_id].abs().max() * global_scale[expert_id]`,
# and `global_scale[expert_id]` are small (~1e-4).
# Taking the largest global scale likely results in overflowing the FP8 range
# for other experts - other selection strategies may be used.
a13_scale = 1.0 / a13_scale.max().to(torch.float32)
a2_scale = 1.0 / a2_scale.max().to(torch.float32)
else:
raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")
......@@ -403,6 +436,15 @@ def make_nvfp4_moe_quant_config(
w1_scale=w13_scale,
w2_scale=w2_scale,
)
elif backend == NvFp4MoeBackend.EMULATION:
return nvfp4_moe_quant_config(
g1_alphas=w13_scale_2,
g2_alphas=w2_scale_2,
a1_gscale=a13_scale,
a2_gscale=a2_scale,
w1_scale=w13_scale,
w2_scale=w2_scale,
)
# Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas.
# The expert's process_weights_after_loading will fuse activation
......
......@@ -22,6 +22,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
ref_nvfp4_quant_dequant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize,
)
......@@ -253,6 +256,7 @@ def moe_kernel_quantize_input(
block_shape: list[int] | None = None,
is_fp4_scale_swizzled: bool = True,
ocp_mx_scheme: str | None = None,
quantization_emulation: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation
if ocp_mx_scheme is not None:
......@@ -274,16 +278,41 @@ def moe_kernel_quantize_input(
# activation quantization below.
if quant_dtype == current_platform.fp8_dtype():
if quantization_emulation:
raise NotImplementedError(
f"moe_kernel_quantize_input does not support quant_dtype={quant_dtype}"
" MOE quantization emulation. Please open an issue."
)
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
if quantization_emulation:
raise NotImplementedError(
"moe_kernel_quantize_input does not support quant_dtype=torch.int8"
" MOE quantization emulation. Please open an issue."
)
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "nvfp4":
return _nvfp4_quantize(A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled)
if not quantization_emulation:
return _nvfp4_quantize(
A, A_scale, is_sf_swizzled_layout=is_fp4_scale_swizzled
)
else:
return ref_nvfp4_quant_dequant(A, A_scale, block_size=16)
elif quant_dtype == "mxfp4":
if not quantization_emulation:
raise NotImplementedError(
"moe_kernel_quantize_input should not be used for native"
" quant_dtype='mxfp4' MOE. Please open an issue."
)
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp8":
# TODO: `quant_dtype == "mxfp8"` is ambiguous,
# should be fp8_e4m3. OCP MX also defines `fp8_e5m2`.
if quantization_emulation:
raise NotImplementedError(
"moe_kernel_quantize_input does not support quant_dtype='mxfp8' MOE "
"quantization emulation. Please open an issue."
)
return _mxfp8_e4m3_quantize(
A,
A_scale,
......@@ -292,8 +321,20 @@ def moe_kernel_quantize_input(
is_sf_swizzled_layout=is_fp4_scale_swizzled,
)
elif quant_dtype == "mxfp6_e3m2":
if not quantization_emulation:
raise NotImplementedError(
"moe_kernel_quantize_input should not be used for native "
" quant_dtype='mxfp6_e3m2'MOE. Please open an issue."
)
return _mxfp6_e3m2_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp6_e2m3":
if not quantization_emulation:
raise NotImplementedError(
"moe_kernel_quantize_input should not be used for native"
" quant_dtype='mxfp6_e2m3' MOE. Please open an issue."
)
return _mxfp6_e2m3_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
return A, A_scale
......
......@@ -30,6 +30,7 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_m
from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
TRITON_BACKENDS,
Mxfp4MoeBackend,
backend_to_kernel_cls,
convert_gpt_oss_weight_to_mxfp4_moe_kernel_format,
make_mxfp4_moe_kernel,
make_mxfp4_moe_quant_config,
......@@ -986,6 +987,8 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
f"Please check that the combination is supported in OCP_MX_Scheme."
)
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes,
# use kernel abstraction for all OCP MX MOE implementations.
self.mxfp4_backend: Mxfp4MoeBackend = Mxfp4MoeBackend.NONE
self.experts_cls: type[mk.FusedMoEExperts] | None = None
self.moe_kernel: mk.FusedMoEKernel | None = None
......@@ -994,12 +997,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.w13_precision_config = None
self.w2_precision_config = None
if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe)
elif self.ocp_mx_scheme.startswith("w_mxfp4"):
# TODO(bowenbao): refactor and introduce backends for other OCP MX schemes.
self.mxfp4_backend = Mxfp4MoeBackend.NONE
if self.input_quant is not None:
self.static_input_scales = not self.input_quant.get("is_dynamic")
else:
......@@ -1035,6 +1032,18 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
)
if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend, self.experts_cls = select_gpt_oss_mxfp4_moe_backend(moe)
if self.emulate:
# We use the same code path between MXFP4/MXFP6 emulation.
self.mxfp4_backend = Mxfp4MoeBackend.EMULATION
# TODO: Remove `self.mxfp4_backend != Mxfp4MoeBackend.NONE` and make it so that
# all MXFP4 backends use the kernel abstraction.
if self.mxfp4_backend != Mxfp4MoeBackend.NONE:
self.experts_cls = backend_to_kernel_cls(self.mxfp4_backend)[0]
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
......@@ -1063,7 +1072,12 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
act_dtype=act_dtype,
moe_parallel_config=moe_parallel_config,
)
if self.mxfp4_backend is not None:
# In case quantization emulation backend is used, there is no need to apply
# MXFP4-specific padding logic as the compute happens in higher precision.
if (
self.mxfp4_backend is not None
and self.mxfp4_backend != Mxfp4MoeBackend.EMULATION
):
hidden_size, intermediate_size_per_partition = (
mxfp4_round_up_hidden_size_and_intermediate_size(
self.mxfp4_backend, hidden_size, intermediate_size_per_partition
......@@ -1237,7 +1251,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
# For w_mxfp4, use oracle functions
if (
if self.emulate or (
self.ocp_mx_scheme == "w_mxfp4"
and self.mxfp4_backend != Mxfp4MoeBackend.NONE
):
......@@ -1245,13 +1259,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
return
# TODO(bowenbao): gradually migrate to oracles.
# secondly, process mxfp weights for other schemes
if self.emulate:
# Build quant config for emulation path
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
torch.accelerator.empty_cache()
return
# Existing AITER path for w_mxfp4_a_mxfp4 and other schemes
from aiter.utility.fp4_utils import e8m0_shuffle
......@@ -1345,9 +1352,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
# For w_mxfp4 with oracle backend, use oracle function
if (
self.ocp_mx_scheme == "w_mxfp4"
and self.mxfp4_backend != Mxfp4MoeBackend.NONE
if self.ocp_mx_scheme == "w_mxfp4" and self.mxfp4_backend not in (
Mxfp4MoeBackend.NONE,
Mxfp4MoeBackend.EMULATION,
):
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
......@@ -1362,9 +1369,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w2_bias=getattr(layer, "w2_bias", None),
)
# Existing code for other schemes
# TODO(bowenbao): kept for emulation fallback, to be refactored into
# dedicated emulation backend.
# Emulation and other schemes
if self.ocp_mx_scheme == "w_mxfp4":
return mxfp4_w4a16_moe_quant_config(
w1_scale=layer.w13_weight_scale,
......@@ -1414,7 +1419,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
# For w_mxfp4 with oracle kernel
# For oracle kernel or emulation kernel
if self.moe_kernel is not None:
return self.moe_kernel.apply(
hidden_states=x,
......@@ -1429,8 +1434,8 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
shared_experts_input=shared_experts_input,
)
# Existing code for emulation/AITER paths
if not self.emulate:
# AITER path
# TODO: Refactor this to use modular MOE kernel as well.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts,
)
......@@ -1446,22 +1451,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
moe_config=layer.moe_config,
expert_map=layer.expert_map,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
def apply_monolithic(
self,
......
......@@ -53,26 +53,52 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
def dequantize_to_dtype(
tensor_fp4: torch.Tensor,
tensor_sf: torch.Tensor,
global_scale: torch.Tensor | float,
global_scale: torch.Tensor,
dtype: torch.dtype,
block_size: int = 16,
swizzle: bool | None = True,
):
"""Dequantize the fp4 tensor back to high precision."""
"""Dequantize the fp4 tensor back to high precision.
Supports both 2D and 3D inputs:
- 2D: [m, packed_k] -> [m, k]
- 3D: [dim0, m, packed_k] -> [dim0, m, k]
"""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
# We handle 3D tensors reshaping them to 2D.
is_3d = tensor_fp4.ndim == 3
if is_3d:
dim0, m, packed_k = tensor_fp4.shape
tensor_fp4 = tensor_fp4.reshape(-1, packed_k)
tensor_sf = tensor_sf.reshape(-1, tensor_sf.shape[-1])
global_scale = global_scale[:, None, None]
else:
m, packed_k = tensor_fp4.shape
k = packed_k * 2
tensor_f32 = break_fp4_bytes(tensor_fp4, torch.float32)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_f32 = tensor_f32.reshape(-1, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
if swizzle:
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
tensor_sf = convert_swizzled_to_linear( # noqa: E501
tensor_sf, tensor_f32.size(0), k, block_size
)
if is_3d:
tensor_sf = tensor_sf.reshape(dim0, m, k // block_size)
tensor_sf_dtype = tensor_sf.to(torch.float32) * global_scale
if is_3d:
tensor_f32 = tensor_f32.reshape(dim0, m, -1, block_size)
# scale the tensor
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
out = tensor_f32 * tensor_sf_dtype.unsqueeze(-1)
out = out.reshape(*out.shape[:-2], -1)
return out.to(dtype)
......@@ -117,6 +143,28 @@ def ref_nvfp4_quant(x, global_scale, block_size):
return cast_to_fp4(clipped_x), scale.squeeze(-1)
def ref_nvfp4_quant_dequant(
    x: torch.Tensor, global_scale: torch.Tensor, block_size: int
) -> tuple[torch.Tensor, None]:
    """
    Round-trip `x` through NVFP4: quantize it, then dequantize it again.

    `x` must be 2D (its shape is unpacked into rows and columns), and
    `global_scale` is expected to have a single element.

    Returns the dequantized tensor, cast back to the dtype of `x`, together
    with `None` in the scale position.
    """
    num_rows, num_cols = x.shape
    orig_dtype = x.dtype

    # Quantize: FP4 values plus one scale per block of `block_size` columns.
    fp4_values, block_scales = ref_nvfp4_quant(x, global_scale, block_size)

    # Dequantize: regroup the values per block and undo the global scale on
    # the block scales before multiplying.
    grouped = fp4_values.reshape(num_rows, num_cols // block_size, block_size)
    effective_scales = block_scales.unsqueeze(-1) / global_scale
    dequantized = grouped * effective_scales

    return dequantized.reshape(num_rows, num_cols).to(orig_dtype), None
def run_nvfp4_emulations(
x: torch.Tensor,
input_global_scale: torch.Tensor,
......@@ -125,18 +173,10 @@ def run_nvfp4_emulations(
weight_global_scale: torch.Tensor,
swizzle: bool | None = True,
):
group_size = 16
x_m, x_k = x.shape
output_dtype = x.dtype
group_size = 16
# quantize input to (FP4 and interleaved block scale)
x_fp4, x_blockscale = ref_nvfp4_quant(x, input_global_scale, group_size)
# dequantize input
x_fp4 = x_fp4.reshape(x_m, x_k // group_size, group_size)
x_blockscale = x_blockscale.unsqueeze(-1) / input_global_scale
x_dq = (x_fp4 * x_blockscale).reshape(x_m, x_k).to(output_dtype)
del x_fp4, x_blockscale
x_dq, _ = ref_nvfp4_quant_dequant(x, input_global_scale, block_size=group_size)
# dequantize weight
w_fp4 = weight.data.view(torch.uint8)
......@@ -151,5 +191,4 @@ def run_nvfp4_emulations(
# matmul
out = torch.matmul(x_dq, w_dq.t())
del w_dq, x_dq
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment