Unverified commit b129136c, authored by xuebwang-amd and committed by GitHub
Browse files

[ROCm][Quantization] GPT_OSS in amd-quark format model loading and emulations (#29008)


Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 599e4335
...@@ -22,7 +22,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor ...@@ -22,7 +22,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
triton_kernel_moe_forward, triton_kernel_moe_forward,
) )
...@@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init): ...@@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
pc2, pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
quant_config = FusedMoEQuantConfig.make( if a_dtype == "bf16" and w_dtype == "mx4":
w1_bias=w1_bias_tri, quant_config = mxfp4_w4a16_moe_quant_config(
w2_bias=w2_bias_tri, w1_scale=pc1,
w1_scale=pc1, w2_scale=pc2,
w2_scale=pc2, w1_bias=w1_bias_tri,
) w2_bias=w2_bias_tri,
)
else:
raise NotImplementedError(
f"Quantization configuration for activation={a_dtype} and weight={w_dtype} "
f"has not been implemented."
)
out_triton_monolithic = triton_kernel_moe_forward( out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri, hidden_states=x_tri,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test attention quantization of gpt-oss model. """
The qkv_proj and o_proj in self_attention can be either quantized or excluded. End-to-end accuracy test for GPT-OSS model quantization.
Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`. Config:
Task: gsm8k_platinum
Filter: flexible-extract
n-shot: 5
Metric: exact_match
Run: pytest tests/models/quantization/test_gpt_oss.py
""" """
import importlib import importlib
...@@ -16,11 +21,18 @@ import lm_eval ...@@ -16,11 +21,18 @@ import lm_eval
import pytest import pytest
from packaging import version from packaging import version
MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"] MODEL_ACCURACIES = {
# Full quantization: attention linears and MoE linears
"amd/gpt-oss-20b-WFP8-AFP8-KVFP8": 0.89,
# MoE linears only quantization
"amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8": 0.89,
# MoE linears only quantization
# "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-MXFP4-KV-FP8": 0.90,
}
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark") importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99") ) >= version.parse("0.9.0")
def has_huggingface_access(repo): def has_huggingface_access(repo):
...@@ -32,7 +44,7 @@ def has_huggingface_access(repo): ...@@ -32,7 +44,7 @@ def has_huggingface_access(repo):
HF_HUB_AMD_ORG_ACCESS = all( HF_HUB_AMD_ORG_ACCESS = all(
[has_huggingface_access(model_name) for model_name in MODEL_NAMES] [has_huggingface_access(model_name) for model_name in MODEL_ACCURACIES]
) )
...@@ -46,14 +58,19 @@ class ModelCase: ...@@ -46,14 +58,19 @@ class ModelCase:
class EvaluationConfig: class EvaluationConfig:
model_name: str model_name: str
def get_model_args(self) -> str: def get_model_args(self, tp_size: int):
return ( return {
f"pretrained={self.model_name}," "pretrained": self.model_name,
"tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False" "chat_template_args": {"reasoning_effort": "low"},
) "enable_thinking": True,
"think_end_token": "200008",
"tensor_parallel_size": tp_size,
EXPECTED_ACCURACIES = {"arc_challenge": 0.20} "dtype": "auto",
"gpu_memory_utilization": 0.95,
"trust_remote_code": False,
"enable_prefix_caching": False,
"enforce_eager": False,
}
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
...@@ -61,19 +78,32 @@ EXPECTED_ACCURACIES = {"arc_challenge": 0.20} ...@@ -61,19 +78,32 @@ EXPECTED_ACCURACIES = {"arc_challenge": 0.20}
not HF_HUB_AMD_ORG_ACCESS, not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.", reason="Read access to huggingface.co/amd is required for this test.",
) )
@pytest.mark.parametrize("model_name", MODEL_NAMES) @pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items()) @pytest.mark.parametrize("model_name, expected_accuracy", MODEL_ACCURACIES.items())
def test_gpt_oss_attention_quantization( def test_gpt_oss_attention_quantization(
model_name: str, task_name: str, expected_accuracy: float model_name: str, tp_size: int, expected_accuracy: float
): ):
measured_accuracy = lm_eval.simple_evaluate( model_args = EvaluationConfig(model_name).get_model_args(tp_size)
extra_run_kwargs = {
"gen_kwargs": {"max_gen_toks": 8000},
"apply_chat_template": True,
"fewshot_as_multiturn": True,
"num_fewshot": 5,
}
lm_eval_out = lm_eval.simple_evaluate(
model="vllm", model="vllm",
model_args=EvaluationConfig(model_name).get_model_args(), model_args=model_args,
tasks=task_name, tasks="gsm8k_platinum",
batch_size="auto", batch_size="auto",
)["results"][task_name]["acc,none"] **extra_run_kwargs,
)
measured_accuracy = float(
lm_eval_out["results"]["gsm8k_platinum"]["exact_match,flexible-extract"]
)
rtol = 0.05 rtol = 0.02
assert ( assert (
measured_accuracy - rtol < expected_accuracy measured_accuracy - rtol < expected_accuracy
and measured_accuracy + rtol > expected_accuracy and measured_accuracy + rtol > expected_accuracy
......
...@@ -386,6 +386,10 @@ class FusedMoEQuantConfig: ...@@ -386,6 +386,10 @@ class FusedMoEQuantConfig:
def use_nvfp4_w4a4(self) -> bool: def use_nvfp4_w4a4(self) -> bool:
return self.quant_dtype == "nvfp4" return self.quant_dtype == "nvfp4"
@property
def use_mxfp4_w4a8(self) -> bool:
    """Tell whether this config pairs fp8 activations with mxfp4 weights.

    Returns:
        True only when the ``_a1`` descriptor's dtype is the string
        ``"fp8"`` and the ``_w1`` descriptor's dtype is the string
        ``"mxfp4"``; False otherwise.
    """
    activation_dtype = self._a1.dtype
    weight_dtype = self._w1.dtype
    return activation_dtype == "fp8" and weight_dtype == "mxfp4"
def config_name(self, dtype: torch.dtype) -> str | None: def config_name(self, dtype: torch.dtype) -> str | None:
""" """
Return a string used to construct the filename that contains the Return a string used to construct the filename that contains the
...@@ -532,6 +536,8 @@ def fp8_w8a8_moe_quant_config( ...@@ -532,6 +536,8 @@ def fp8_w8a8_moe_quant_config(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
per_out_ch_quant: bool = False, per_out_ch_quant: bool = False,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
...@@ -549,6 +555,8 @@ def fp8_w8a8_moe_quant_config( ...@@ -549,6 +555,8 @@ def fp8_w8a8_moe_quant_config(
g1_alphas=g1_alphas, g1_alphas=g1_alphas,
w2_scale=w2_scale, w2_scale=w2_scale,
g2_alphas=g2_alphas, g2_alphas=g2_alphas,
w1_bias=w1_bias,
w2_bias=w2_bias,
a1_scale=a1_scale, a1_scale=a1_scale,
a1_gscale=a1_gscale, a1_gscale=a1_gscale,
a2_scale=a2_scale, a2_scale=a2_scale,
...@@ -564,6 +572,8 @@ def int8_w8a8_moe_quant_config( ...@@ -564,6 +572,8 @@ def int8_w8a8_moe_quant_config(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None, a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None, a2_scale: torch.Tensor | None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
...@@ -575,6 +585,8 @@ def int8_w8a8_moe_quant_config( ...@@ -575,6 +585,8 @@ def int8_w8a8_moe_quant_config(
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False, per_out_ch_quant=False,
block_shape=None, block_shape=None,
...@@ -654,6 +666,26 @@ def mxfp4_mxfp8_moe_quant_config( ...@@ -654,6 +666,26 @@ def mxfp4_mxfp8_moe_quant_config(
) )
def mxfp4_w4a8_moe_quant_config(
    w1_scale: Union[torch.Tensor, "PrecisionConfig"],
    w2_scale: Union[torch.Tensor, "PrecisionConfig"],
    a1_scale: torch.Tensor | None = None,
    a2_scale: torch.Tensor | None = None,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """Build a fused-MoE quant config with fp8 activations and mxfp4 weights.

    Args:
        w1_scale: Scale tensor (or PrecisionConfig) for the first weight
            descriptor.
        w2_scale: Scale tensor (or PrecisionConfig) for the second weight
            descriptor.
        a1_scale: Optional scale for the first activation descriptor.
        a2_scale: Optional scale for the second activation descriptor.
        w1_bias: Optional bias attached to the first weight descriptor.
        w2_bias: Optional bias attached to the second weight descriptor.
        block_shape: Accepted in the signature but not used by this
            constructor.

    Returns:
        A FusedMoEQuantConfig whose activation descriptors are tagged
        ``"fp8"`` and whose weight descriptors are tagged ``"mxfp4"``.
    """
    # Activation descriptors: only the dtype tag and the scale are provided.
    first_act = FusedMoEQuantDesc("fp8", None, a1_scale, None, None, None)
    second_act = FusedMoEQuantDesc("fp8", None, a2_scale, None, None, None)

    # Weight descriptors: dtype tag, scale, and (as the last argument) bias.
    first_weight = FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias)
    second_weight = FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias)

    return FusedMoEQuantConfig(
        _a1=first_act,
        _a2=second_act,
        _w1=first_weight,
        _w2=second_weight,
    )
def ocp_mx_moe_quant_config( def ocp_mx_moe_quant_config(
quant_dtype: str, quant_dtype: str,
w1_scale: Union[torch.Tensor, "PrecisionConfig"], w1_scale: Union[torch.Tensor, "PrecisionConfig"],
...@@ -691,6 +723,8 @@ def nvfp4_moe_quant_config( ...@@ -691,6 +723,8 @@ def nvfp4_moe_quant_config(
a2_gscale: torch.Tensor, a2_gscale: torch.Tensor,
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for nvfp4 activations and nvfp4 weights. Construct a quant config for nvfp4 activations and nvfp4 weights.
...@@ -699,6 +733,8 @@ def nvfp4_moe_quant_config( ...@@ -699,6 +733,8 @@ def nvfp4_moe_quant_config(
"nvfp4", "nvfp4",
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
a1_gscale=a1_gscale, a1_gscale=a1_gscale,
a2_gscale=a2_gscale, a2_gscale=a2_gscale,
g1_alphas=g1_alphas, g1_alphas=g1_alphas,
......
...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
) )
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8Dynamic128Sym, kFp8Dynamic128Sym,
...@@ -1583,6 +1582,11 @@ def _get_config_quant_dtype( ...@@ -1583,6 +1582,11 @@ def _get_config_quant_dtype(
return "mxfp6_e3m2" return "mxfp6_e3m2"
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
return "mxfp6_e2m3" return "mxfp6_e2m3"
elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
return torch.bfloat16
elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
return torch.float8_e4m3fn
return None return None
...@@ -1617,17 +1621,10 @@ def fused_experts_impl( ...@@ -1617,17 +1621,10 @@ def fused_experts_impl(
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif ocp_mx_scheme is not None: elif ocp_mx_scheme is not None:
if ocp_mx_scheme in { if ocp_mx_scheme.startswith("w_mxfp4"):
"w_mxfp4_a_mxfp4",
"w_mxfp4_a_mxfp6_e3m2",
"w_mxfp4_a_mxfp6_e2m3",
}:
# 16bit activation and fp4x2 packed weight # 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
elif ocp_mx_scheme in { elif ocp_mx_scheme.startswith("w_mxfp6"):
"w_mxfp6_e3m2_a_mxfp6_e3m2",
"w_mxfp6_e2m3_a_mxfp6_e2m3",
}:
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, ( assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
"hidden size mismatch" "hidden size mismatch"
) )
...@@ -1717,17 +1714,13 @@ def fused_experts_impl( ...@@ -1717,17 +1714,13 @@ def fused_experts_impl(
# TODO: On platforms for which `current_platform.supports_mx()` is True # TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP mx fused MOE kernel, # and for which we have a native OCP mx fused MOE kernel,
# this dequantization step should not be done. # this dequantization step should not be done.
if ocp_mx_scheme in { if ocp_mx_scheme.startswith("w_mxfp4"):
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
}:
# Weight has to be dequantized for mxfp4 emulation. # Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2: elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
w1 = dequant_mxfp6( w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
) )
...@@ -1736,7 +1729,7 @@ def fused_experts_impl( ...@@ -1736,7 +1729,7 @@ def fused_experts_impl(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
) )
w2_scale = None w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3: elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
w1 = dequant_mxfp6( w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
) )
...@@ -1779,6 +1772,7 @@ def fused_experts_impl( ...@@ -1779,6 +1772,7 @@ def fused_experts_impl(
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant, per_act_token_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
) )
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
...@@ -1846,6 +1840,7 @@ def fused_experts_impl( ...@@ -1846,6 +1840,7 @@ def fused_experts_impl(
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant, per_act_token_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
) )
if expert_map is not None: if expert_map is not None:
......
...@@ -221,12 +221,14 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str: ...@@ -221,12 +221,14 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
) )
# TODO(rob): move this down to the kernel.
def maybe_roundup_hidden_size( def maybe_roundup_hidden_size(
hidden_size: int, hidden_size: int,
act_dtype: torch.dtype, act_dtype: torch.dtype,
quant_config: QuantizationConfig | None,
moe_parallel_config: FusedMoEParallelConfig, moe_parallel_config: FusedMoEParallelConfig,
is_lora_enabled: bool, is_lora_enabled: bool,
model_type: str | None,
is_mxfp4_quant: bool,
) -> int: ) -> int:
""" """
Given layer hidden size and MoE configurations, round up hidden_size Given layer hidden size and MoE configurations, round up hidden_size
...@@ -235,11 +237,12 @@ def maybe_roundup_hidden_size( ...@@ -235,11 +237,12 @@ def maybe_roundup_hidden_size(
Args: Args:
hidden_size: Layer hidden-size hidden_size: Layer hidden-size
act_dtype: Data type of the layer activations. act_dtype: Data type of the layer activations.
quant_config: Fused MoE quantization configuration.
moe_parallel_config: Fused MoE parallelization strategy configuration. moe_parallel_config: Fused MoE parallelization strategy configuration.
is_lora_enabled: True if the engine is enabled with LoRA. This is_lora_enabled: True if the engine is enabled with LoRA. This
is used in the case of mxfp4 quantization in selecting the is used in the case of mxfp4 quantization in selecting the
MxFP4Backend. MxFP4Backend.
model_type: for checking if gpt-oss
is_mxfp4_quant: whether the layer is quantized with mxfp4
Return: Return:
Rounded up hidden_size if rounding up is required based on the configs. Rounded up hidden_size if rounding up is required based on the configs.
...@@ -254,7 +257,7 @@ def maybe_roundup_hidden_size( ...@@ -254,7 +257,7 @@ def maybe_roundup_hidden_size(
) )
# we are padding globally so EP buffer allocation works # we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4": if model_type == "gpt_oss" and is_mxfp4_quant:
from vllm.model_executor.layers.quantization.mxfp4 import ( from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, Mxfp4Backend,
get_mxfp4_backend, get_mxfp4_backend,
...@@ -398,15 +401,6 @@ class FusedMoE(CustomOp): ...@@ -398,15 +401,6 @@ class FusedMoE(CustomOp):
# Expert mapping used in self.load_weights # Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping self.expert_mapping = expert_mapping
# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(
hidden_size,
moe_in_dtype,
quant_config,
self.moe_parallel_config,
is_lora_enabled=self.vllm_config.lora_config is not None,
)
# For smuggling this layer into the fused moe custom op # For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
...@@ -508,7 +502,6 @@ class FusedMoE(CustomOp): ...@@ -508,7 +502,6 @@ class FusedMoE(CustomOp):
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
assert intermediate_size % self.tp_size == 0 assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results self.reduce_results = reduce_results
self.renormalize = renormalize self.renormalize = renormalize
...@@ -548,6 +541,24 @@ class FusedMoE(CustomOp): ...@@ -548,6 +541,24 @@ class FusedMoE(CustomOp):
) )
self.routing_method_type: RoutingMethodType = self.router.routing_method_type self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
hidden_size = maybe_roundup_hidden_size(
hidden_size=hidden_size,
act_dtype=moe_in_dtype,
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=(
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
),
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
)
self.hidden_size = hidden_size
self.moe_config: FusedMoEConfig = FusedMoEConfig( self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts, num_experts=self.global_num_experts,
experts_per_token=top_k, experts_per_token=top_k,
......
...@@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp6_utils import ( ...@@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize,
)
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
...@@ -241,7 +244,27 @@ def moe_kernel_quantize_input( ...@@ -241,7 +244,27 @@ def moe_kernel_quantize_input(
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
is_fp4_scale_swizzled: bool = True, is_fp4_scale_swizzled: bool = True,
ocp_mx_scheme: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
# Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation
if ocp_mx_scheme is not None:
if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
pass # No QDQ needed for these schemes
elif ocp_mx_scheme.endswith("a_fp8"):
# Perform QDQ (quantize and dequantize) on activation for emulation
# purpose, because there is no native kernel for weight in ocp_mx_scheme
# and activation in FP8. The implementation is based on existing
# non-emulation ops.
qA, qA_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=False
)
A = per_tensor_dequantize(qA, qA_scale).to(A.dtype)
# After QDQ, we don't need further quantization
return A, None
# else: For other schemes (e.g., *_a_mxfp6_e3m2, *_a_mxfp6_e2m3),
# weights are already dequantized, and we proceed with normal
# activation quantization below.
if quant_dtype == torch.float8_e4m3fn: if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8: elif quant_dtype == torch.int8:
......
...@@ -168,3 +168,19 @@ class QuantizationConfig(ABC): ...@@ -168,3 +168,19 @@ class QuantizationConfig(ABC):
Interface to update values after config initialization. Interface to update values after config initialization.
""" """
pass pass
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
    """Report whether MXFP4 quantization applies to the given layer.

    This default implementation ignores both arguments and always
    answers ``False``. Having the answer available from the config lets
    a caller decide on hidden_size rounding before moe_config is created,
    without first instantiating a quant_method.

    Args:
        prefix: The layer prefix/name in the model.
        layer: The layer module.

    Returns:
        False for this default implementation.
    """
    return False
...@@ -229,10 +229,15 @@ class Mxfp4Config(QuantizationConfig): ...@@ -229,10 +229,15 @@ class Mxfp4Config(QuantizationConfig):
) )
return None return None
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
    """Answer ``True`` unconditionally: this config always means MXFP4.

    Args:
        prefix: The layer prefix/name in the model (not inspected).
        layer: The layer module (not inspected).

    Returns:
        Always True.
    """
    return True
class Mxfp4MoEMethod(FusedMoEMethodBase): class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.weight_dtype = "mxfp4"
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None self.marlin_input_dtype = None
......
...@@ -320,38 +320,45 @@ class QuarkConfig(QuantizationConfig): ...@@ -320,38 +320,45 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported. # Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_ocp_mx( def _is_w_ocp_mx_a_x(
self, self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
weight_quant: dict[str, Any] | None,
input_quant: dict[str, Any] | None,
) -> bool: ) -> bool:
# Confirm weights and input quantized. """
if weight_quant is None or input_quant is None: This check returns True only if it is an OCP-MX weight quantization.
The activation can be any data type (e.g., FP16/BF16, FP8, or OCP-MX format).
The rationale for checking only the weight type is that
the model loading concept and process primarily concerns the weights themselves.
"""
# Confirm weights quantized.
if weight_quant is None:
logger.debug( logger.debug(
"Quark model is not in OCP MX format: " "Quark model's weight quantization is incompatible with OCP_MX format: "
"weight_quant or input_quant not set" "weight_quant is not set."
) )
return False return False
# Input and weight qscheme needs to be per group. # Weight qscheme needs to be per group.
if ( if weight_quant.get("qscheme") != "per_group":
weight_quant.get("qscheme") != "per_group" logger.debug(
or input_quant.get("qscheme") != "per_group" "Quark model's weight quantization is incompatible with OCP MX format: "
): "weight is not per_group."
logger.debug("Quark model is not in OCP MX format: not per_group") )
return False return False
# Input and weight group size needs to be 32. # Weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32: if weight_quant.get("group_size") != 32:
logger.debug("Quark model is not in OCP MX format: not group_size=32") logger.debug(
"Quark model's weight quantization is incompatible with OCP MX format: "
"group_size of weight is not 32."
)
return False return False
# Activations and weight scales need to be in e8m0 format. # Weight scales need to be in e8m0 format.
if ( if weight_quant.get("scale_format") != "e8m0":
weight_quant.get("scale_format") != "e8m0" logger.debug(
or input_quant.get("scale_format") != "e8m0" "Quark model's weight quantization is incompatible with OCP MX format: "
): "scale_format of weight is not e8m0."
logger.debug("Quark model is not in OCP MX format: not scale_format e8m0") )
return False return False
# Input and weight dtypes need to be any of fp4, # Input and weight dtypes need to be any of fp4,
...@@ -360,14 +367,31 @@ class QuarkConfig(QuantizationConfig): ...@@ -360,14 +367,31 @@ class QuarkConfig(QuantizationConfig):
"fp4", "fp4",
"fp6_e3m2", "fp6_e3m2",
"fp6_e2m3", "fp6_e2m3",
} or input_quant.get("dtype") not in {"fp4", "fp6_e3m2", "fp6_e2m3"}: }:
logger.debug( logger.debug(
"Quark model is not in OCP MX format: dtype not fp4, fp6_e3m2, fp6_e2m3" "Quark model's weight quantization is incompatible with OCP MX format: "
"dtype is not in {fp4, fp6_e3m2, fp6_e2m3}."
) )
return False return False
return True return True
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
    """Decide from the Quark config whether this layer is OCP MXFP4.

    The layer's matched quantization config is looked up directly, so the
    answer is available before moe_config is created (which is when
    hidden_size rounding needs it).

    Args:
        prefix: The layer prefix/name used to find the matching config.
        layer: The layer module used to find the matching config.

    Returns:
        True only if the weight config passes the OCP-MX weight check, its
        dtype is ``"fp4"``, and the installed torch exposes
        ``float4_e2m1fn_x2``; False otherwise.
    """
    matched = self._find_matched_config(prefix, layer)
    weight_config = matched.get("weight")
    input_config = matched.get("input_tensors")

    if not self._is_w_ocp_mx_a_x(weight_config, input_config):
        return False
    # Explicit None guard, kept from the original condition chain.
    if weight_config is None:
        return False
    if weight_config.get("dtype") != "fp4":
        return False
    # MXFP4 additionally requires torch to provide the packed fp4 dtype.
    return getattr(torch, "float4_e2m1fn_x2", None) is not None
def _find_matched_config( def _find_matched_config(
self, layer_name: str, module: torch.nn.Module self, layer_name: str, module: torch.nn.Module
) -> dict[str, Any]: ) -> dict[str, Any]:
...@@ -441,7 +465,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -441,7 +465,7 @@ class QuarkConfig(QuantizationConfig):
is_static_input_scheme=True, is_static_input_scheme=True,
input_symmetric=input_config.get("symmetric"), input_symmetric=input_config.get("symmetric"),
) )
elif self._is_ocp_mx(weight_config, input_config): elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(weight_config, input_config) return QuarkOCP_MX(weight_config, input_config)
raise NotImplementedError( raise NotImplementedError(
......
...@@ -20,26 +20,44 @@ SUPPORTED_OCP_MX_DTYPES = {"mxfp4", "mxfp6_e3m2", "mxfp6_e2m3"} ...@@ -20,26 +20,44 @@ SUPPORTED_OCP_MX_DTYPES = {"mxfp4", "mxfp6_e3m2", "mxfp6_e2m3"}
class OCP_MX_Scheme(str, Enum): class OCP_MX_Scheme(str, Enum):
w_mxfp4 = "w_mxfp4"
w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4" w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4"
w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2" w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2"
w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3" w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3"
w_mxfp4_a_fp8 = "w_mxfp4_a_fp8"
w_mxfp6_e3m2 = "w_mxfp6_e3m2"
w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2" w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2"
w_mxfp6_e3m2_a_fp8 = "w_mxfp6_e3m2_a_fp8"
w_mxfp6_e2m3 = "w_mxfp6_e2m3"
w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3" w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3"
w_mxfp6_e2m3_a_fp8 = "w_mxfp6_e2m3_a_fp8"
@classmethod @classmethod
def from_quant_dtype(cls, input_dtype: str | None, weight_dtype: str | None): def from_quant_dtype(cls, input_dtype: str | None, weight_dtype: str | None):
if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES: if input_dtype not in OCP_MX_DTYPES and weight_dtype not in OCP_MX_DTYPES:
return None return None
elif input_dtype is None and weight_dtype == "mxfp4":
return cls.w_mxfp4
elif input_dtype is None and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2
elif input_dtype is None and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3
elif input_dtype == "mxfp4" and weight_dtype == "mxfp4": elif input_dtype == "mxfp4" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp4 return cls.w_mxfp4_a_mxfp4
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4": elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e3m2 return cls.w_mxfp4_a_mxfp6_e3m2
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4": elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e2m3 return cls.w_mxfp4_a_mxfp6_e2m3
elif input_dtype == "fp8" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_fp8
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2": elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2_a_mxfp6_e3m2 return cls.w_mxfp6_e3m2_a_mxfp6_e3m2
elif input_dtype == "fp8" and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2_a_fp8
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3": elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3_a_mxfp6_e2m3 return cls.w_mxfp6_e2m3_a_mxfp6_e2m3
elif input_dtype == "fp8" and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3_a_fp8
else: else:
logger.warning( logger.warning(
"input_dtype='%s' and" "input_dtype='%s' and"
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment