Unverified Commit b129136c authored by xuebwang-amd's avatar xuebwang-amd Committed by GitHub
Browse files

[ROCm][Quantization] GPT_OSS in amd-quark format model loading and emulations (#29008)


Signed-off-by: default avatarxuebwang-amd <xuebwang@amd.com>
Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent 599e4335
...@@ -22,7 +22,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor ...@@ -22,7 +22,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout from triton_kernels.tensor_details import layout
from triton_kernels.testing import assert_close from triton_kernels.testing import assert_close
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import mxfp4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
triton_kernel_moe_forward, triton_kernel_moe_forward,
) )
...@@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init): ...@@ -298,12 +298,18 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
pc2, pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8) ) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
quant_config = FusedMoEQuantConfig.make( if a_dtype == "bf16" and w_dtype == "mx4":
w1_bias=w1_bias_tri, quant_config = mxfp4_w4a16_moe_quant_config(
w2_bias=w2_bias_tri, w1_scale=pc1,
w1_scale=pc1, w2_scale=pc2,
w2_scale=pc2, w1_bias=w1_bias_tri,
) w2_bias=w2_bias_tri,
)
else:
raise NotImplementedError(
f"Quantization configuration for activation={a_dtype} and weight={w_dtype} "
f"has not been implemented."
)
out_triton_monolithic = triton_kernel_moe_forward( out_triton_monolithic = triton_kernel_moe_forward(
hidden_states=x_tri, hidden_states=x_tri,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test attention quantization of gpt-oss model. """
The qkv_proj and o_proj in self_attention can be either quantized or excluded. End-to-end accuracy test for GPT-OSS model quantization.
Run `pytest tests/models/quantization/test_gpt_oss_attn_quantization.py`. Config:
Task: gsm8k_platinum
Filter: flexible-extract
n-shot: 5
Metric: exact_match
Run: pytest tests/models/quantization/test_gpt_oss.py
""" """
import importlib import importlib
...@@ -16,11 +21,18 @@ import lm_eval ...@@ -16,11 +21,18 @@ import lm_eval
import pytest import pytest
from packaging import version from packaging import version
MODEL_NAMES = ["amd/gpt-oss-20b-customized-attention-quantization"] MODEL_ACCURACIES = {
# Full quantization: attention linears and MoE linears
"amd/gpt-oss-20b-WFP8-AFP8-KVFP8": 0.89,
# MoE linears only quantization
"amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8": 0.89,
# MoE linears only quantization
# "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-MXFP4-KV-FP8": 0.90,
}
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse( QUARK_MXFP4_AVAILABLE = importlib.util.find_spec("quark") is not None and version.parse(
importlib.metadata.version("amd-quark") importlib.metadata.version("amd-quark")
) >= version.parse("0.8.99") ) >= version.parse("0.9.0")
def has_huggingface_access(repo): def has_huggingface_access(repo):
...@@ -32,7 +44,7 @@ def has_huggingface_access(repo): ...@@ -32,7 +44,7 @@ def has_huggingface_access(repo):
HF_HUB_AMD_ORG_ACCESS = all( HF_HUB_AMD_ORG_ACCESS = all(
[has_huggingface_access(model_name) for model_name in MODEL_NAMES] [has_huggingface_access(model_name) for model_name in MODEL_ACCURACIES]
) )
...@@ -46,14 +58,19 @@ class ModelCase: ...@@ -46,14 +58,19 @@ class ModelCase:
class EvaluationConfig: class EvaluationConfig:
model_name: str model_name: str
def get_model_args(self) -> str: def get_model_args(self, tp_size: int):
return ( return {
f"pretrained={self.model_name}," "pretrained": self.model_name,
"tensor_parallel_size=4,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=False" "chat_template_args": {"reasoning_effort": "low"},
) "enable_thinking": True,
"think_end_token": "200008",
"tensor_parallel_size": tp_size,
EXPECTED_ACCURACIES = {"arc_challenge": 0.20} "dtype": "auto",
"gpu_memory_utilization": 0.95,
"trust_remote_code": False,
"enable_prefix_caching": False,
"enforce_eager": False,
}
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available") @pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
...@@ -61,19 +78,32 @@ EXPECTED_ACCURACIES = {"arc_challenge": 0.20} ...@@ -61,19 +78,32 @@ EXPECTED_ACCURACIES = {"arc_challenge": 0.20}
not HF_HUB_AMD_ORG_ACCESS, not HF_HUB_AMD_ORG_ACCESS,
reason="Read access to huggingface.co/amd is required for this test.", reason="Read access to huggingface.co/amd is required for this test.",
) )
@pytest.mark.parametrize("model_name", MODEL_NAMES) @pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("task_name, expected_accuracy", EXPECTED_ACCURACIES.items()) @pytest.mark.parametrize("model_name, expected_accuracy", MODEL_ACCURACIES.items())
def test_gpt_oss_attention_quantization( def test_gpt_oss_attention_quantization(
model_name: str, task_name: str, expected_accuracy: float model_name: str, tp_size: int, expected_accuracy: float
): ):
measured_accuracy = lm_eval.simple_evaluate( model_args = EvaluationConfig(model_name).get_model_args(tp_size)
extra_run_kwargs = {
"gen_kwargs": {"max_gen_toks": 8000},
"apply_chat_template": True,
"fewshot_as_multiturn": True,
"num_fewshot": 5,
}
lm_eval_out = lm_eval.simple_evaluate(
model="vllm", model="vllm",
model_args=EvaluationConfig(model_name).get_model_args(), model_args=model_args,
tasks=task_name, tasks="gsm8k_platinum",
batch_size="auto", batch_size="auto",
)["results"][task_name]["acc,none"] **extra_run_kwargs,
)
measured_accuracy = float(
lm_eval_out["results"]["gsm8k_platinum"]["exact_match,flexible-extract"]
)
rtol = 0.05 rtol = 0.02
assert ( assert (
measured_accuracy - rtol < expected_accuracy measured_accuracy - rtol < expected_accuracy
and measured_accuracy + rtol > expected_accuracy and measured_accuracy + rtol > expected_accuracy
......
...@@ -386,6 +386,10 @@ class FusedMoEQuantConfig: ...@@ -386,6 +386,10 @@ class FusedMoEQuantConfig:
def use_nvfp4_w4a4(self) -> bool: def use_nvfp4_w4a4(self) -> bool:
return self.quant_dtype == "nvfp4" return self.quant_dtype == "nvfp4"
@property
def use_mxfp4_w4a8(self) -> bool:
return self._a1.dtype == "fp8" and self._w1.dtype == "mxfp4"
def config_name(self, dtype: torch.dtype) -> str | None: def config_name(self, dtype: torch.dtype) -> str | None:
""" """
Return a string used to construct the filename that contains the Return a string used to construct the filename that contains the
...@@ -532,6 +536,8 @@ def fp8_w8a8_moe_quant_config( ...@@ -532,6 +536,8 @@ def fp8_w8a8_moe_quant_config(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None = None, a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None, a2_scale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
per_out_ch_quant: bool = False, per_out_ch_quant: bool = False,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
...@@ -549,6 +555,8 @@ def fp8_w8a8_moe_quant_config( ...@@ -549,6 +555,8 @@ def fp8_w8a8_moe_quant_config(
g1_alphas=g1_alphas, g1_alphas=g1_alphas,
w2_scale=w2_scale, w2_scale=w2_scale,
g2_alphas=g2_alphas, g2_alphas=g2_alphas,
w1_bias=w1_bias,
w2_bias=w2_bias,
a1_scale=a1_scale, a1_scale=a1_scale,
a1_gscale=a1_gscale, a1_gscale=a1_gscale,
a2_scale=a2_scale, a2_scale=a2_scale,
...@@ -564,6 +572,8 @@ def int8_w8a8_moe_quant_config( ...@@ -564,6 +572,8 @@ def int8_w8a8_moe_quant_config(
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None, a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None, a2_scale: torch.Tensor | None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
...@@ -575,6 +585,8 @@ def int8_w8a8_moe_quant_config( ...@@ -575,6 +585,8 @@ def int8_w8a8_moe_quant_config(
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False, per_out_ch_quant=False,
block_shape=None, block_shape=None,
...@@ -654,6 +666,26 @@ def mxfp4_mxfp8_moe_quant_config( ...@@ -654,6 +666,26 @@ def mxfp4_mxfp8_moe_quant_config(
) )
def mxfp4_w4a8_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for fp8 activations and mxfp4 weights.
"""
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc("fp8", None, a1_scale, None, None, None),
_a2=FusedMoEQuantDesc("fp8", None, a2_scale, None, None, None),
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
)
def ocp_mx_moe_quant_config( def ocp_mx_moe_quant_config(
quant_dtype: str, quant_dtype: str,
w1_scale: Union[torch.Tensor, "PrecisionConfig"], w1_scale: Union[torch.Tensor, "PrecisionConfig"],
...@@ -691,6 +723,8 @@ def nvfp4_moe_quant_config( ...@@ -691,6 +723,8 @@ def nvfp4_moe_quant_config(
a2_gscale: torch.Tensor, a2_gscale: torch.Tensor,
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w2_scale: torch.Tensor,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> FusedMoEQuantConfig: ) -> FusedMoEQuantConfig:
""" """
Construct a quant config for mxfp4 activations and nvp4 weights. Construct a quant config for mxfp4 activations and nvp4 weights.
...@@ -699,6 +733,8 @@ def nvfp4_moe_quant_config( ...@@ -699,6 +733,8 @@ def nvfp4_moe_quant_config(
"nvfp4", "nvfp4",
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
a1_gscale=a1_gscale, a1_gscale=a1_gscale,
a2_gscale=a2_gscale, a2_gscale=a2_gscale,
g1_alphas=g1_alphas, g1_alphas=g1_alphas,
......
...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
) )
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6 from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8Dynamic128Sym, kFp8Dynamic128Sym,
...@@ -1583,6 +1582,11 @@ def _get_config_quant_dtype( ...@@ -1583,6 +1582,11 @@ def _get_config_quant_dtype(
return "mxfp6_e3m2" return "mxfp6_e3m2"
elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}: elif ocp_mx_scheme in {"w_mxfp4_a_mxfp6_e2m3", "w_mxfp6_e2m3_a_mxfp6_e2m3"}:
return "mxfp6_e2m3" return "mxfp6_e2m3"
elif ocp_mx_scheme in {"w_mxfp4", "w_mxfp6_e3m2", "w_mxfp6_e2m3"}:
return torch.bfloat16
elif ocp_mx_scheme in {"w_mxfp4_a_fp8", "w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"}:
return torch.float8_e4m3fn
return None return None
...@@ -1617,17 +1621,10 @@ def fused_experts_impl( ...@@ -1617,17 +1621,10 @@ def fused_experts_impl(
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch" assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif ocp_mx_scheme is not None: elif ocp_mx_scheme is not None:
if ocp_mx_scheme in { if ocp_mx_scheme.startswith("w_mxfp4"):
"w_mxfp4_a_mxfp4",
"w_mxfp4_a_mxfp6_e3m2",
"w_mxfp4_a_mxfp6_e2m3",
}:
# 16bit activation and fp4x2 packed weight # 16bit activation and fp4x2 packed weight
assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch" assert hidden_states.size(1) == w1.size(2) * 2, "hidden size mismatch"
elif ocp_mx_scheme in { elif ocp_mx_scheme.startswith("w_mxfp6"):
"w_mxfp6_e3m2_a_mxfp6_e3m2",
"w_mxfp6_e2m3_a_mxfp6_e2m3",
}:
assert hidden_states.size(1) == (w1.size(2) * 4) // 3, ( assert hidden_states.size(1) == (w1.size(2) * 4) // 3, (
"hidden size mismatch" "hidden size mismatch"
) )
...@@ -1717,17 +1714,13 @@ def fused_experts_impl( ...@@ -1717,17 +1714,13 @@ def fused_experts_impl(
# TODO: On platforms for which `current_platform.supports_mx()` is True # TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP mx fused MOE kernel, # and for which we have a native OCP mx fused MOE kernel,
# this dequantization step should not be done. # this dequantization step should not be done.
if ocp_mx_scheme in { if ocp_mx_scheme.startswith("w_mxfp4"):
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
}:
# Weight has to be dequantized for mxfp4 emulation. # Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2: elif ocp_mx_scheme.startswith("w_mxfp6_e3m2"):
w1 = dequant_mxfp6( w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
) )
...@@ -1736,7 +1729,7 @@ def fused_experts_impl( ...@@ -1736,7 +1729,7 @@ def fused_experts_impl(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
) )
w2_scale = None w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3: elif ocp_mx_scheme.startswith("w_mxfp6_e2m3"):
w1 = dequant_mxfp6( w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
) )
...@@ -1779,6 +1772,7 @@ def fused_experts_impl( ...@@ -1779,6 +1772,7 @@ def fused_experts_impl(
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant, per_act_token_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
) )
# SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
...@@ -1846,6 +1840,7 @@ def fused_experts_impl( ...@@ -1846,6 +1840,7 @@ def fused_experts_impl(
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
per_act_token_quant=per_channel_quant, per_act_token_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
ocp_mx_scheme=ocp_mx_scheme,
) )
if expert_map is not None: if expert_map is not None:
......
...@@ -221,12 +221,14 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str: ...@@ -221,12 +221,14 @@ def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
) )
# TODO(rob): move this down to the kernel.
def maybe_roundup_hidden_size( def maybe_roundup_hidden_size(
hidden_size: int, hidden_size: int,
act_dtype: torch.dtype, act_dtype: torch.dtype,
quant_config: QuantizationConfig | None,
moe_parallel_config: FusedMoEParallelConfig, moe_parallel_config: FusedMoEParallelConfig,
is_lora_enabled: bool, is_lora_enabled: bool,
model_type: str | None,
is_mxfp4_quant: bool,
) -> int: ) -> int:
""" """
Given layer hidden size and MoE configurations, round up hidden_size Given layer hidden size and MoE configurations, round up hidden_size
...@@ -235,11 +237,12 @@ def maybe_roundup_hidden_size( ...@@ -235,11 +237,12 @@ def maybe_roundup_hidden_size(
Args: Args:
hidden_size: Layer hidden-size hidden_size: Layer hidden-size
act_dtype: Data type of the layer activations. act_dtype: Data type of the layer activations.
quant_config: Fused MoE quantization configuration.
moe_parallel_config: Fused MoE parallelization strategy configuration. moe_parallel_config: Fused MoE parallelization strategy configuration.
is_lora_enabled: True if the engine is enabled with LoRA. This is_lora_enabled: True if the engine is enabled with LoRA. This
is used in the case of mxfp4 quantization in selecting the is used in the case of mxfp4 quantization in selecting the
MxFP4Backend. MxFP4Backend.
model_type: for checking if gpt-oss
is_mxfp4_quant: whether the layer is quantized with mxfp4
Return: Return:
Rounded up hidden_size if rounding up is required based on the configs. Rounded up hidden_size if rounding up is required based on the configs.
...@@ -254,7 +257,7 @@ def maybe_roundup_hidden_size( ...@@ -254,7 +257,7 @@ def maybe_roundup_hidden_size(
) )
# we are padding globally so EP buffer allocation works # we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4": if model_type == "gpt_oss" and is_mxfp4_quant:
from vllm.model_executor.layers.quantization.mxfp4 import ( from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, Mxfp4Backend,
get_mxfp4_backend, get_mxfp4_backend,
...@@ -398,15 +401,6 @@ class FusedMoE(CustomOp): ...@@ -398,15 +401,6 @@ class FusedMoE(CustomOp):
# Expert mapping used in self.load_weights # Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping self.expert_mapping = expert_mapping
# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(
hidden_size,
moe_in_dtype,
quant_config,
self.moe_parallel_config,
is_lora_enabled=self.vllm_config.lora_config is not None,
)
# For smuggling this layer into the fused moe custom op # For smuggling this layer into the fused moe custom op
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
...@@ -508,7 +502,6 @@ class FusedMoE(CustomOp): ...@@ -508,7 +502,6 @@ class FusedMoE(CustomOp):
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
assert intermediate_size % self.tp_size == 0 assert intermediate_size % self.tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results self.reduce_results = reduce_results
self.renormalize = renormalize self.renormalize = renormalize
...@@ -548,6 +541,24 @@ class FusedMoE(CustomOp): ...@@ -548,6 +541,24 @@ class FusedMoE(CustomOp):
) )
self.routing_method_type: RoutingMethodType = self.router.routing_method_type self.routing_method_type: RoutingMethodType = self.router.routing_method_type
# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
hidden_size = maybe_roundup_hidden_size(
hidden_size=hidden_size,
act_dtype=moe_in_dtype,
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
model_type=(
self.vllm_config.model_config.hf_config.model_type
if self.vllm_config.model_config is not None
else None
),
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
)
self.hidden_size = hidden_size
self.moe_config: FusedMoEConfig = FusedMoEConfig( self.moe_config: FusedMoEConfig = FusedMoEConfig(
num_experts=self.global_num_experts, num_experts=self.global_num_experts,
experts_per_token=top_k, experts_per_token=top_k,
......
...@@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp6_utils import ( ...@@ -23,6 +23,9 @@ from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize,
)
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
...@@ -241,7 +244,27 @@ def moe_kernel_quantize_input( ...@@ -241,7 +244,27 @@ def moe_kernel_quantize_input(
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
is_fp4_scale_swizzled: bool = True, is_fp4_scale_swizzled: bool = True,
ocp_mx_scheme: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
# Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation
if ocp_mx_scheme is not None:
if ocp_mx_scheme in {"w_mxfp4", "w_mxfp4_a_mxfp4"}:
pass # No QDQ needed for these schemes
elif ocp_mx_scheme.endswith("a_fp8"):
# Perform QDQ (quantize and dequantize) on activation for emulation
# purpose, because there is no native kernel for weight in ocp_mx_scheme
# and activation in FP8. The implementation is based on existing
# non-emulation ops.
qA, qA_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=False
)
A = per_tensor_dequantize(qA, qA_scale).to(A.dtype)
# After QDQ, we don't need further quantization
return A, None
# else: For other schemes (e.g., *_a_mxfp6_e3m2, *_a_mxfp6_e2m3),
# weights are already dequantized, and we proceed with normal
# activation quantization below.
if quant_dtype == torch.float8_e4m3fn: if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8: elif quant_dtype == torch.int8:
......
...@@ -168,3 +168,19 @@ class QuantizationConfig(ABC): ...@@ -168,3 +168,19 @@ class QuantizationConfig(ABC):
Interface to update values after config initialization. Interface to update values after config initialization.
""" """
pass pass
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
"""
Determine if mxfp4 quantization will be used for this config.
This allows hidden_size rounding to happen before moe_config creation
without needing to instantiate quant_method first.
Args:
prefix: The layer prefix/name in the model
layer: The layer module
Returns:
True if this config uses MXFP4 quantization, False otherwise
"""
return False
...@@ -229,10 +229,15 @@ class Mxfp4Config(QuantizationConfig): ...@@ -229,10 +229,15 @@ class Mxfp4Config(QuantizationConfig):
) )
return None return None
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
"""MXFP4 config always uses MXFP4 quantization."""
return True
class Mxfp4MoEMethod(FusedMoEMethodBase): class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.weight_dtype = "mxfp4"
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
self.marlin_input_dtype = None self.marlin_input_dtype = None
......
...@@ -320,38 +320,45 @@ class QuarkConfig(QuantizationConfig): ...@@ -320,38 +320,45 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported. # Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _is_ocp_mx( def _is_w_ocp_mx_a_x(
self, self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
weight_quant: dict[str, Any] | None,
input_quant: dict[str, Any] | None,
) -> bool: ) -> bool:
# Confirm weights and input quantized. """
if weight_quant is None or input_quant is None: This check returns True only if it is an OCP-MX weight quantization.
The activation can be any data type (e.g., FP16/BF16, FP8, or OCP-MX format).
The rationale for checking only the weight type is that
the model loading concept and process primarily concerns the weights themselves.
"""
# Confirm weights quantized.
if weight_quant is None:
logger.debug( logger.debug(
"Quark model is not in OCP MX format: " "Quark model's weight quantization is incompatible with OCP_MX format: "
"weight_quant or input_quant not set" "weight_quant is not set."
) )
return False return False
# Input and weight qscheme needs to be per group. # Input and weight qscheme needs to be per group.
if ( if weight_quant.get("qscheme") != "per_group":
weight_quant.get("qscheme") != "per_group" logger.debug(
or input_quant.get("qscheme") != "per_group" "Quark model's weight quantization is incompatible with OCP MX format: "
): "weight is not per_group."
logger.debug("Quark model is not in OCP MX format: not per_group") )
return False return False
# Input and weight group size needs to be 32. # Input and weight group size needs to be 32.
if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32: if weight_quant.get("group_size") != 32:
logger.debug("Quark model is not in OCP MX format: not group_size=32") logger.debug(
"Quark model's weight quantization is incompatible with OCP MX format: "
"group_size of weight is not 32."
)
return False return False
# Activations and weight scales need to be in e8m0 format. # Activations and weight scales need to be in e8m0 format.
if ( if weight_quant.get("scale_format") != "e8m0":
weight_quant.get("scale_format") != "e8m0" logger.debug(
or input_quant.get("scale_format") != "e8m0" "Quark model's weight quantization is incompatible with OCP MX format: "
): "scale_format of weight is not e8m0."
logger.debug("Quark model is not in OCP MX format: not scale_format e8m0") )
return False return False
# Input and weight dtypes need to be any of fp4, # Input and weight dtypes need to be any of fp4,
...@@ -360,14 +367,31 @@ class QuarkConfig(QuantizationConfig): ...@@ -360,14 +367,31 @@ class QuarkConfig(QuantizationConfig):
"fp4", "fp4",
"fp6_e3m2", "fp6_e3m2",
"fp6_e2m3", "fp6_e2m3",
} or input_quant.get("dtype") not in {"fp4", "fp6_e3m2", "fp6_e2m3"}: }:
logger.debug( logger.debug(
"Quark model is not in OCP MX format: dtype not fp4, fp6_e3m2, fp6_e2m3" "Quark model's weight quantization is incompatible with OCP MX format: "
"dtype is not in {fp4, fp6_e3m2, fp6_e2m3}."
) )
return False return False
return True return True
def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool:
"""
For Quark, determine if it's OCP MXFP4 by checking config directly.
This allows hidden_size rounding to happen before moe_config creation.
"""
layer_quant_config = self._find_matched_config(prefix, layer)
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
return (
self._is_w_ocp_mx_a_x(weight_config, input_config)
and weight_config is not None
and weight_config.get("dtype") == "fp4"
and getattr(torch, "float4_e2m1fn_x2", None) is not None
)
def _find_matched_config( def _find_matched_config(
self, layer_name: str, module: torch.nn.Module self, layer_name: str, module: torch.nn.Module
) -> dict[str, Any]: ) -> dict[str, Any]:
...@@ -441,7 +465,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -441,7 +465,7 @@ class QuarkConfig(QuantizationConfig):
is_static_input_scheme=True, is_static_input_scheme=True,
input_symmetric=input_config.get("symmetric"), input_symmetric=input_config.get("symmetric"),
) )
elif self._is_ocp_mx(weight_config, input_config): elif self._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX(weight_config, input_config) return QuarkOCP_MX(weight_config, input_config)
raise NotImplementedError( raise NotImplementedError(
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
...@@ -18,9 +19,15 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -18,9 +19,15 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config,
mxfp4_w4a8_moe_quant_config,
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config, ocp_mx_moe_quant_config,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
get_mxfp4_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin, prepare_fp8_moe_layer_for_marlin,
) )
...@@ -37,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -37,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils.math_utils import round_up
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -46,6 +54,7 @@ __all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"] ...@@ -46,6 +54,7 @@ __all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
class QuarkMoEMethod(FusedMoEMethodBase): class QuarkMoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__(moe) super().__init__(moe)
self.has_bias = self.moe.has_bias
@staticmethod @staticmethod
def get_moe_method( def get_moe_method(
...@@ -67,7 +76,7 @@ class QuarkMoEMethod(FusedMoEMethodBase): ...@@ -67,7 +76,7 @@ class QuarkMoEMethod(FusedMoEMethodBase):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config) return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config): elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_ocp_mx(weight_config, input_config): elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config) return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
else: else:
raise RuntimeError("Unsupported FusedMoe scheme") raise RuntimeError("Unsupported FusedMoe scheme")
...@@ -86,6 +95,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -86,6 +95,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
self.weight_qscheme = self.weight_quant.get("qscheme") self.weight_qscheme = self.weight_quant.get("qscheme")
self.input_qscheme = self.input_quant.get("qscheme") self.input_qscheme = self.input_quant.get("qscheme")
self.weight_dtype = self.weight_quant.get("dtype", "").replace(
"fp8_e4m3", "fp8"
)
self.input_dtype = self.input_quant.get("dtype", "").replace("fp8_e4m3", "fp8")
per_tensor = ( per_tensor = (
self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor" self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor"
) )
...@@ -121,6 +134,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -121,6 +134,10 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
self.model_type = getattr(
get_current_vllm_config().model_config.hf_config, "model_type", None
)
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -166,9 +183,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -166,9 +183,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
if self.weight_qscheme == "per_tensor": if self.weight_qscheme == "per_tensor":
# Allocate 2 scales for w1 and w3 respectively. # Allocate 2 scales for w1 and w3 respectively.
# They are combined to a single scale after weight loading. # They are combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter( if self.model_type != "gpt_oss":
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False w13_weight_scale = torch.nn.Parameter(
) torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
else:
# For gpt_oss, the w1(gate) & w3(up) are fused as one.
# Therefore, only one weight scale for each expert.
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 1, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter( w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False torch.ones(num_experts, dtype=torch.float32), requires_grad=False
...@@ -220,6 +244,27 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -220,6 +244,27 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
if self.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
else:
layer.w13_bias, layer.w2_bias = None, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale. # Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ. # We take the max of all the scales in case they differ.
...@@ -278,21 +323,40 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -278,21 +323,40 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
assert layer.w13_weight_scale is not None assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0 # For gpt_oss, w1 and w3 are fused into a single combined
for shard_id in range(2): # gate_up_proj tensor with size 2*intermediate_size_per_partition
# and only one scale per expert.
# Process the entire weight tensor as one shard.
if self.model_type == "gpt_oss":
for expert_id in range(layer.local_num_experts):
# Process all 2*intermediate_size_per_partition rows at once
dq_weight = per_tensor_dequantize( dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :], layer.w13_weight[expert_id],
layer.w13_weight_scale[expert_id][shard_id], layer.w13_weight_scale[expert_id][0],
) )
layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( layer.w13_weight[expert_id], _ = ops.scaled_fp8_quant(
ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) dq_weight, max_w13_scales[expert_id]
) )
start += shard_size else:
# For non-gpt_oss, process w1 and w3 shards separately
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
(
layer.w13_weight[expert_id][start : start + shard_size, :],
_,
) = ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter( layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False max_w13_scales, requires_grad=False
) )
# quark's scale is 1 dim. # quark's scale is 1 dim.
elif self.weight_qscheme == "per_channel": elif self.weight_qscheme == "per_channel":
if self.act_quant_group_shape == GroupShape.PER_TOKEN: if self.act_quant_group_shape == GroupShape.PER_TOKEN:
...@@ -343,6 +407,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): ...@@ -343,6 +407,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
per_act_token_quant=self.input_qscheme == "per_channel", per_act_token_quant=self.input_qscheme == "per_channel",
per_out_ch_quant=self.weight_qscheme == "per_channel", per_out_ch_quant=self.weight_qscheme == "per_channel",
) )
...@@ -563,7 +629,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -563,7 +629,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def __init__( def __init__(
self, self,
weight_config: dict[str, Any], weight_config: dict[str, Any],
input_config: dict[str, Any], input_config: dict[str, Any] | None,
moe: FusedMoEConfig, moe: FusedMoEConfig,
): ):
super().__init__(moe) super().__init__(moe)
...@@ -571,35 +637,79 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -571,35 +637,79 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
self.input_quant = input_config self.input_quant = input_config
weight_qscheme = self.weight_quant.get("qscheme") weight_qscheme = self.weight_quant.get("qscheme")
input_qscheme = self.input_quant.get("qscheme") if not weight_qscheme == "per_group":
if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
raise ValueError( raise ValueError(
"For MX(FP4) Fused MoE layers, only per-group scales " "For MX(FP4) Fused MoE layers, only per-group scales "
"for weights and activations are supported. Found " f"for weights are supported. Found {weight_qscheme}."
f"{weight_qscheme}, {input_qscheme}"
) # noqa E501 ) # noqa E501
self.static_input_scales = not self.input_quant.get("is_dynamic")
self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp") self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp") if self.input_quant is not None:
input_quant = self.input_quant["dtype"]
if input_quant in ["fp4", "fp6_e3m2", "fp6_e2m3"]:
self.input_dtype = input_quant.replace("fp", "mxfp")
elif input_quant == "fp8_e4m3":
self.input_dtype = input_quant.replace("fp8_e4m3", "fp8")
else:
raise NotImplementedError(
f"Current input dtype {input_quant} is not compatible \
with OCP MX (weight) MoE quantization. Please open an issue"
)
else:
self.input_dtype = None
self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None) self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)
self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype( self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
self.input_dtype, self.weight_dtype self.input_dtype, self.weight_dtype
) )
if self.static_input_scales: if self.ocp_mx_scheme is None:
raise ValueError(
f"Unsupported OCP MX dtype combination for MoE: "
f"input_dtype={self.input_dtype}, weight_dtype={self.weight_dtype}. "
f"Please check that the combination is supported in OCP_MX_Scheme."
)
self.mxfp4_backend: Mxfp4Backend | None = None
if self.ocp_mx_scheme == "w_mxfp4":
self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
if self.input_quant is not None:
self.static_input_scales = not self.input_quant.get("is_dynamic")
else:
self.static_input_scales = False
if any(
self.ocp_mx_scheme.endswith(a_scheme)
for a_scheme in ["a_mxfp4", "a_mxfp6_e3m2", "a_mxfp6_e2m3"]
):
if self.static_input_scales:
raise NotImplementedError(
"QuarkOCP_MX_MoEMethod with static input scales is currently "
f"not implemented for OCP MX scheme {self.ocp_mx_scheme}. "
"Please open an issue."
)
elif self.ocp_mx_scheme.endswith("a_fp8") and not self.static_input_scales:
raise NotImplementedError( raise NotImplementedError(
"QuarkOCP_MX_MoEMethod with static input scales is currently " "QuarkOCP_MX_MoEMethod with dynamic input scales is currently "
"not implemented. Please open an issue." f"not implemented for OCP MX scheme {self.ocp_mx_scheme}. "
"Please open an issue."
) )
self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled() self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()
self.emulate = not current_platform.supports_mx() or not ( self.model_type = getattr(
self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4" get_current_vllm_config().model_config.hf_config, "model_type", None
) )
self._emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
self.emulate = True if self.model_type == "gpt_oss" else self._emulate
if self.emulate: if self.emulate:
logger.warning_once( logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, " f"The current mode (supports_mx={current_platform.supports_mx()}, "
...@@ -640,12 +750,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -640,12 +750,23 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
) )
params_dtype = torch.uint8 params_dtype = torch.uint8
if self.model_type == "gpt_oss":
if current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 256
)
else:
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 64
)
else:
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
num_experts, num_experts,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition_after_pad,
self.get_packed_dim(hidden_size, self.weight_dtype), self.get_packed_dim(hidden_size, self.weight_dtype),
dtype=params_dtype, dtype=params_dtype,
), ),
...@@ -659,7 +780,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -659,7 +780,9 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
torch.empty( torch.empty(
num_experts, num_experts,
hidden_size, hidden_size,
self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype), self.get_packed_dim(
intermediate_size_per_partition_after_pad, self.weight_dtype
),
dtype=params_dtype, dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
...@@ -672,7 +795,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -672,7 +795,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
w13_weight_scale = torch.nn.Parameter( w13_weight_scale = torch.nn.Parameter(
torch.ones( torch.ones(
num_experts, num_experts,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition_after_pad,
hidden_size // OCP_MX_BLOCK_SIZE, hidden_size // OCP_MX_BLOCK_SIZE,
dtype=params_dtype, dtype=params_dtype,
), ),
...@@ -682,7 +805,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -682,7 +805,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
torch.ones( torch.ones(
num_experts, num_experts,
hidden_size, hidden_size,
intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, intermediate_size_per_partition_after_pad // OCP_MX_BLOCK_SIZE,
dtype=params_dtype, dtype=params_dtype,
), ),
requires_grad=False, requires_grad=False,
...@@ -693,8 +816,96 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -693,8 +816,96 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale)
if self.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition_after_pad,
dtype=torch.float32,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
else:
layer.w13_bias, layer.w2_bias = None, None
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
if self.static_input_scales:
# firstly, process activations if fp8 static input
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
layer.w2_input_scale
):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
_, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fnuz),
torch.empty_like(
layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype
),
layer.w13_input_scale,
)
_, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fnuz),
torch.empty_like(
layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype
),
layer.w2_input_scale,
)
# Reset the parameter
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)
# secondly, process mxfp weights
if self.emulate: if self.emulate:
torch.cuda.empty_cache()
return return
from aiter.utility.fp4_utils import e8m0_shuffle from aiter.utility.fp4_utils import e8m0_shuffle
...@@ -725,15 +936,40 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -725,15 +936,40 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
return ocp_mx_moe_quant_config( if self.ocp_mx_scheme == "w_mxfp4":
quant_dtype=self.input_dtype, return mxfp4_w4a16_moe_quant_config(
weight_dtype=self.weight_dtype, w1_scale=layer.w13_weight_scale,
w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale,
w2_scale=layer.w2_weight_scale, w1_bias=layer.w13_bias,
a1_scale=None, w2_bias=layer.w2_bias,
a2_scale=None, )
block_shape=None, elif self.ocp_mx_scheme == "w_mxfp4_a_fp8":
) return mxfp4_w4a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
block_shape=None,
)
elif self.ocp_mx_scheme in ["w_mxfp6_e3m2_a_fp8", "w_mxfp6_e2m3_a_fp8"]:
raise NotImplementedError(
"Currently there is no corresponding fused moe quant config configured "
f"in vLLM for OCP MX scheme {self.ocp_mx_scheme}. Please open an issue."
)
else:
return ocp_mx_moe_quant_config(
quant_dtype=self.input_dtype,
weight_dtype=self.weight_dtype,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias,
a1_scale=None,
a2_scale=None,
block_shape=None,
)
def apply( def apply(
self, self,
...@@ -743,24 +979,34 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -743,24 +979,34 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate: if not self.emulate:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( if (
rocm_aiter_fused_experts, self.model_type == "gpt_oss"
) and self.mxfp4_backend == Mxfp4Backend.TRITON
):
raise NotImplementedError(
"Triton kernel implemented fused MoE for GPT_OSS model "
"in Quark(MoE) format is not integrated or provided yet."
)
out = rocm_aiter_fused_experts( else:
x, from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
layer.w13_weight, rocm_aiter_fused_experts,
layer.w2_weight, )
topk_weights=topk_weights,
topk_ids=topk_ids, return rocm_aiter_fused_experts(
activation=layer.activation, x,
quant_config=self.moe_quant_config, layer.w13_weight,
expert_map=layer.expert_map, layer.w2_weight,
) topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
)
else: else:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
out = fused_experts( return fused_experts(
x, x,
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
...@@ -773,5 +1019,3 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -773,5 +1019,3 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
) )
return out
...@@ -20,26 +20,44 @@ SUPPORTED_OCP_MX_DTYPES = {"mxfp4", "mxfp6_e3m2", "mxfp6_e2m3"} ...@@ -20,26 +20,44 @@ SUPPORTED_OCP_MX_DTYPES = {"mxfp4", "mxfp6_e3m2", "mxfp6_e2m3"}
class OCP_MX_Scheme(str, Enum): class OCP_MX_Scheme(str, Enum):
w_mxfp4 = "w_mxfp4"
w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4" w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4"
w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2" w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2"
w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3" w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3"
w_mxfp4_a_fp8 = "w_mxfp4_a_fp8"
w_mxfp6_e3m2 = "w_mxfp6_e3m2"
w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2" w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2"
w_mxfp6_e3m2_a_fp8 = "w_mxfp6_e3m2_a_fp8"
w_mxfp6_e2m3 = "w_mxfp6_e2m3"
w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3" w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3"
w_mxfp6_e2m3_a_fp8 = "w_mxfp6_e2m3_a_fp8"
@classmethod @classmethod
def from_quant_dtype(cls, input_dtype: str | None, weight_dtype: str | None): def from_quant_dtype(cls, input_dtype: str | None, weight_dtype: str | None):
if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES: if input_dtype not in OCP_MX_DTYPES and weight_dtype not in OCP_MX_DTYPES:
return None return None
elif input_dtype is None and weight_dtype == "mxfp4":
return cls.w_mxfp4
elif input_dtype is None and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2
elif input_dtype is None and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3
elif input_dtype == "mxfp4" and weight_dtype == "mxfp4": elif input_dtype == "mxfp4" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp4 return cls.w_mxfp4_a_mxfp4
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4": elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e3m2 return cls.w_mxfp4_a_mxfp6_e3m2
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4": elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e2m3 return cls.w_mxfp4_a_mxfp6_e2m3
elif input_dtype == "fp8" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_fp8
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2": elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2_a_mxfp6_e3m2 return cls.w_mxfp6_e3m2_a_mxfp6_e3m2
elif input_dtype == "fp8" and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2_a_fp8
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3": elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3_a_mxfp6_e2m3 return cls.w_mxfp6_e2m3_a_mxfp6_e2m3
elif input_dtype == "fp8" and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3_a_fp8
else: else:
logger.warning( logger.warning(
"input_dtype='%s' and" "input_dtype='%s' and"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable import typing
from collections.abc import Callable, Iterable
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -25,13 +26,17 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -25,13 +26,17 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.utils import rocm_unquantized_gemm from vllm.model_executor.layers.utils import rocm_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -98,6 +103,7 @@ class OAIAttention(nn.Module): ...@@ -98,6 +103,7 @@ class OAIAttention(nn.Module):
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.num_attention_heads, total_num_heads=self.num_attention_heads,
total_num_kv_heads=self.num_key_value_heads, total_num_kv_heads=self.num_key_value_heads,
bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=f"{prefix}.qkv_proj",
) )
...@@ -105,6 +111,7 @@ class OAIAttention(nn.Module): ...@@ -105,6 +111,7 @@ class OAIAttention(nn.Module):
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
input_size=self.num_attention_heads * self.head_dim, input_size=self.num_attention_heads * self.head_dim,
output_size=self.hidden_size, output_size=self.hidden_size,
bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
) )
...@@ -306,6 +313,19 @@ class GptOssModel(nn.Module): ...@@ -306,6 +313,19 @@ class GptOssModel(nn.Module):
return x, aux_hidden_states return x, aux_hidden_states
return x return x
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, weight scales, activation scales
# (param_name, weight_name, expert_id, shard_id)
# NOTE: this is only used for quark.
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_local_experts,
num_redundant_experts=0,
)
def _load_weights_mxfp4( def _load_weights_mxfp4(
self, self,
ep_rank_end: int, ep_rank_end: int,
...@@ -318,7 +338,6 @@ class GptOssModel(nn.Module): ...@@ -318,7 +338,6 @@ class GptOssModel(nn.Module):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
mxfp4_block = 32
use_ep = self.parallel_config.enable_expert_parallel use_ep = self.parallel_config.enable_expert_parallel
num_experts = self.config.num_local_experts num_experts = self.config.num_local_experts
...@@ -333,9 +352,11 @@ class GptOssModel(nn.Module): ...@@ -333,9 +352,11 @@ class GptOssModel(nn.Module):
) )
intermediate_size = self.config.intermediate_size intermediate_size = self.config.intermediate_size
intermediate_size_block = intermediate_size // mxfp4_block intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size) per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
per_rank_intermediate_size = per_rank_intermediate_size_block * mxfp4_block per_rank_intermediate_size = (
per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
)
# Calculate common slicing bounds for current rank # Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size tp_rank_start = tp_rank * per_rank_intermediate_size
...@@ -370,7 +391,9 @@ class GptOssModel(nn.Module): ...@@ -370,7 +391,9 @@ class GptOssModel(nn.Module):
narrow_weight = weight[ep_rank_start:ep_rank_end, ...] narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
else: else:
narrow_weight = weight[ narrow_weight = weight[
..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block ...,
tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
// OCP_MX_BLOCK_SIZE,
] ]
param = params_dict[name] param = params_dict[name]
...@@ -495,6 +518,449 @@ class GptOssModel(nn.Module): ...@@ -495,6 +518,449 @@ class GptOssModel(nn.Module):
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
def _load_weights_quark(
self,
ep_rank_end: int,
ep_rank_start: int,
heads_per_rank: int,
head_start: int,
weights: Iterable[tuple[str, torch.Tensor]],
stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
use_ep = self.parallel_config.enable_expert_parallel
num_experts = self.config.num_local_experts
if use_ep:
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
else:
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size=get_tensor_model_parallel_world_size(),
dp_size=get_dp_group().world_size,
dp_rank=get_dp_group().rank_in_group,
pcp_size=get_pcp_group().world_size,
pcp_rank=get_pcp_group().rank_in_group,
)
def _get_moe_weight_dtype(layer_id: int = 0) -> str | None:
"""Helper function to get MoE quantization weight dtype.
Args:
layer_id: Layer index to check (default 0, as all layers should
have the same quantization method)
Returns:
Weight dtype string (e.g., "mxfp4", "fp8") or None if not available
"""
if hasattr(self.layers[layer_id].mlp.experts.quant_method, "weight_dtype"):
return self.layers[layer_id].mlp.experts.quant_method.weight_dtype
return None
intermediate_size = self.config.intermediate_size
moe_weight_dtype = _get_moe_weight_dtype(layer_id=0)
if moe_weight_dtype == "mxfp4":
# MXFP4 requires OCP_MX_BLOCK_SIZE alignment
intermediate_size_block = intermediate_size // OCP_MX_BLOCK_SIZE
per_rank_intermediate_size_block = cdiv(intermediate_size_block, tp_size)
per_rank_intermediate_size = (
per_rank_intermediate_size_block * OCP_MX_BLOCK_SIZE
)
else:
# FP8 and other formats don't need alignment
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if is_pp_missing_parameter(name, self):
continue
layer_id, expert_id, fused_name = None, None, None
moe_quant_method = None
if "experts" in name:
parts = name.split(".")
ids = [s for s in parts if s.isdigit()]
# for amd-quark format that each expert is seperated
# need to extract the parameter name with experts fused.
# example model: amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
if len(ids) == 2:
layer_id, expert_id = int(ids[0]), int(ids[-1])
parts.pop(len(parts) - 1 - parts[::-1].index(str(expert_id)))
fused_name = ".".join(parts)
# for openai mxfp4 format that all experts are combined
# no need to extract the parameter name with experts fused.
# models: openai/gpt-oss-20b, openai/gpt-oss-120b
elif len(ids) == 1:
layer_id, expert_id = int(ids[0]), None
fused_name = name
else:
raise NameError(
f"Layer {name} contains more than 2 numeric indices. This is "
"an unexpected condition. Please open an issue if encountered."
)
moe_quant_method = _get_moe_weight_dtype(layer_id=layer_id)
def kv_cache_scale_loader(
quant_config: QuantizationConfig,
name: str,
params_dict: dict[str, typing.Any],
weight: torch.Tensor,
default_weight_loader: Callable[..., None],
loaded_params: set[str],
) -> tuple[bool, set[str]]:
"""
Load KV cache output scales.
Returns:
Tuple of (bool, set):
- bool: True if KV-cache scale was loaded into loaded_params
- set: Updated set of loaded_params if True else the original set
"""
# load explicit cached KV output scale from quant_config
if quant_config is not None and (
scale_name := quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
if weight.numel() != 1:
raise ValueError(
f"KV cache scale '{scale_name}' is expected to be a "
f"scalar, but got a tensor of shape {weight.shape}."
)
# Ensure weight is a scalar before passing to loader.
weight_loader(param, weight.flatten()[0])
loaded_params.add(scale_name)
return True, loaded_params
return False, loaded_params
load_kv_cache_scale_completed, loaded_params = kv_cache_scale_loader(
self.quant_config,
name,
params_dict,
loaded_weight,
default_weight_loader,
loaded_params,
)
if load_kv_cache_scale_completed:
continue
if (
all(key in name for key in ["input_scale", "mlp.experts"])
and expert_id is not None
):
assert loaded_weight.numel() == 1
expert_data = params_dict[fused_name].data[expert_id]
expert_data.copy_(loaded_weight)
loaded_params.add(fused_name)
continue
# Unified handler for mxfp4 weights and scales
elif moe_quant_method == "mxfp4" and any(
name.endswith(suffix)
for suffix in [
".w13_weight_scale",
".w2_weight_scale",
".w13_weight",
".w2_weight",
]
):
is_w13 = ".w13_" in name
is_scale = "_scale" in name
# Reshape weight for mxfp4 if needed (not for scales)
if not is_scale and expert_id is None:
if is_w13:
if loaded_weight.dim() < 3:
raise ValueError(
f"Expected w13_weight to have at least 3 "
f"dimensions, got shape "
f"{loaded_weight.shape}"
)
if loaded_weight.shape[0] != num_experts:
raise ValueError(
f"Expected w13_weight first dimension to be "
f"{num_experts}, got "
f"{loaded_weight.shape[0]}"
)
loaded_weight = loaded_weight.view(
num_experts, 2 * intermediate_size, -1
).contiguous()
else:
if loaded_weight.dim() < 3:
raise ValueError(
f"Expected w2_weight to have at least 3 "
f"dimensions, got shape "
f"{loaded_weight.shape}"
)
if loaded_weight.shape[0] != num_experts:
raise ValueError(
f"Expected w2_weight first dimension to be "
f"{num_experts}, got "
f"{loaded_weight.shape[0]}"
)
loaded_weight = loaded_weight.view(
num_experts, -1, intermediate_size // 2
).contiguous()
if use_ep:
sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
else:
if is_w13:
if expert_id is None:
sliced_weight = loaded_weight[
:, 2 * tp_rank_start : 2 * tp_rank_end, ...
]
else:
sliced_weight = loaded_weight[
2 * tp_rank_start : 2 * tp_rank_end, ...
]
else:
if is_scale:
sliced_weight = loaded_weight[
...,
tp_rank_start // OCP_MX_BLOCK_SIZE : tp_rank_end
// OCP_MX_BLOCK_SIZE,
]
else:
sliced_weight = loaded_weight[
..., tp_rank_start // 2 : tp_rank_end // 2
]
# NOTE(rob): because gpt-oss ckpt has "unique" structure with
# fused gate_up_proj fused on disk, we cannot use the existing
# weight loaders without added complexity, so just do the
# direct load here.
param = params_dict[fused_name]
expert_data = param.data[expert_id]
dim1 = sliced_weight.shape[0]
dim2 = sliced_weight.shape[1]
expert_data.data[:dim1, :dim2].copy_(sliced_weight)
loaded_params.add(fused_name)
continue
elif name.endswith(".w13_weight") and moe_quant_method == "fp8":
if use_ep:
narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
else:
if expert_id is None:
narrow_weight = loaded_weight[
:, 2 * tp_rank_start : 2 * tp_rank_end, :
]
else:
narrow_weight = loaded_weight[
2 * tp_rank_start : 2 * tp_rank_end, :
]
assert fused_name is not None
param = params_dict[fused_name]
if expert_id is None:
param.data.copy_(narrow_weight)
else:
param.data[expert_id].copy_(narrow_weight)
loaded_params.add(fused_name)
continue
elif name.endswith(".w13_weight_scale") and moe_quant_method == "fp8":
assert fused_name is not None
param = params_dict[fused_name]
# Check if this is per-channel or per-tensor scale
if loaded_weight.numel() > 1 and loaded_weight.dim() == 1:
if use_ep:
narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = loaded_weight[
2 * tp_rank_start : 2 * tp_rank_end
]
else:
narrow_weight = loaded_weight
if expert_id is None:
param.data.copy_(narrow_weight)
else:
param.data[expert_id].copy_(narrow_weight)
loaded_params.add(fused_name)
continue
elif name.endswith(".w13_input_scale") and moe_quant_method == "fp8":
assert fused_name is not None
param = params_dict[fused_name]
if expert_id is None:
param.data.copy_(loaded_weight)
else:
param.data[expert_id].copy_(loaded_weight)
loaded_params.add(fused_name)
continue
elif name.endswith(".w2_weight") and moe_quant_method == "fp8":
if use_ep:
narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
else:
if expert_id is None:
narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]
else:
narrow_weight = loaded_weight[..., tp_rank_start:tp_rank_end]
assert fused_name is not None
param = params_dict[fused_name]
if expert_id is None:
param.data.copy_(narrow_weight)
else:
param.data[expert_id].copy_(narrow_weight)
loaded_params.add(fused_name)
continue
elif name.endswith(".w2_weight_scale") and moe_quant_method == "fp8":
assert fused_name is not None
param = params_dict[fused_name]
if use_ep:
narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
else:
narrow_weight = loaded_weight
if expert_id is None:
param.data.copy_(narrow_weight)
else:
param.data[expert_id].copy_(narrow_weight)
loaded_params.add(fused_name)
continue
# Unified handler for bias loading (w13_bias and w2_bias)
elif name.endswith(".w13_bias") or name.endswith(".w2_bias"):
is_w13_bias = name.endswith(".w13_bias")
if use_ep:
sliced_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
else:
if is_w13_bias:
if expert_id is None:
sliced_weight = loaded_weight[
:, 2 * tp_rank_start : 2 * tp_rank_end
]
else:
sliced_weight = loaded_weight[
2 * tp_rank_start : 2 * tp_rank_end
]
else:
sliced_weight = loaded_weight
if tp_rank != 0:
sliced_weight = sliced_weight.zero_()
# NOTE(rob): because gpt-oss ckpt has "unique" structure with
# fused gate_up_proj fused on disk, we cannot use the existing
# weight loaders without added complexity, so just do the
# direct load here.
assert fused_name is not None
param = params_dict[fused_name]
expert_data = param.data[expert_id]
dim1 = sliced_weight.shape[0]
expert_data.data[:dim1].copy_(sliced_weight)
loaded_params.add(fused_name)
continue
elif "sinks" in name:
# Handle attention sinks (distributed across ranks)
param = params_dict[name]
narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
for mapping in expert_params_mapping:
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
param_name, weight_name, mapping_expert_id, shard_id = mapping
weight_name = (
weight_name[:-1] if weight_name.endswith(".") else weight_name
)
if weight_name not in name:
continue
param = params_dict[fused_name]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
# Use checkpoint's expert_id for quark format (when expert_id
# is extracted from weight name), otherwise use mapping's expert_id
actual_expert_id = (
expert_id if expert_id is not None else mapping_expert_id
)
success = weight_loader(
param,
loaded_weight,
fused_name,
shard_id=shard_id,
expert_id=actual_expert_id,
return_success=True,
)
if success:
name = fused_name
loaded_params.add(name)
break
else:
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _load_weights_other( def _load_weights_other(
self, self,
ep_rank_end: int, ep_rank_end: int,
...@@ -635,6 +1101,7 @@ class GptOssModel(nn.Module): ...@@ -635,6 +1101,7 @@ class GptOssModel(nn.Module):
if hasattr(self.config, "quantization_config") if hasattr(self.config, "quantization_config")
else None else None
) )
if quant_method == "mxfp4": if quant_method == "mxfp4":
return self._load_weights_mxfp4( return self._load_weights_mxfp4(
ep_rank_end, ep_rank_end,
...@@ -644,6 +1111,15 @@ class GptOssModel(nn.Module): ...@@ -644,6 +1111,15 @@ class GptOssModel(nn.Module):
weights, weights,
stacked_params_mapping, stacked_params_mapping,
) )
elif quant_method == "quark":
return self._load_weights_quark(
ep_rank_end,
ep_rank_start,
heads_per_rank,
head_start,
weights,
stacked_params_mapping,
)
else: else:
return self._load_weights_other( return self._load_weights_other(
ep_rank_end, ep_rank_end,
...@@ -676,6 +1152,15 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): ...@@ -676,6 +1152,15 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
# MoE Bias # MoE Bias
".gate_up_proj_bias": ".w13_bias", ".gate_up_proj_bias": ".w13_bias",
".down_proj_bias": ".w2_bias", ".down_proj_bias": ".w2_bias",
# For quark format
".gate_up_proj.weight": ".w13_weight",
".gate_up_proj.weight_scale": ".w13_weight_scale",
".gate_up_proj.bias": ".w13_bias",
".gate_up_proj.input_scale": ".w13_input_scale",
".down_proj.weight": ".w2_weight",
".down_proj.weight_scale": ".w2_weight_scale",
".down_proj.bias": ".w2_bias",
".down_proj.input_scale": ".w2_input_scale",
}, },
) )
...@@ -725,18 +1210,6 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA): ...@@ -725,18 +1210,6 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
logits = self.logits_processor(self.lm_head, hidden_states) logits = self.logits_processor(self.lm_head, hidden_states)
return logits return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, weight scales, activation scales
# (param_name, weight_name, expert_id, shard_id)
return FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_local_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment