Unverified Commit df1e30e7 authored by EdalatiAli's avatar EdalatiAli Committed by GitHub
Browse files

[Quant] add CompressedTensorsW8A8Mxfp8 for linear and MoE layers (#38815)


Signed-off-by: default avatarEdalatiAli <aliedalati@cohere.com>
parent bd8bd523
......@@ -28,6 +28,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8,
CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16,
)
......@@ -632,3 +633,38 @@ def test_get_quant_method_returns_none_for_unmatched_parallel_lm_head():
assert method is None, (
f"Expected None for unmatched ParallelLMHead, got {type(method).__name__}"
)
@pytest.mark.skipif(
not current_platform.is_cuda() or not current_platform.has_device_capability(75),
reason="MXFP8 requires Turing (sm_75+) or newer.",
)
def test_compressed_tensors_mxfp8_moe_setup(vllm_runner):
"""Verify MXFP8 scheme, dtypes, and generation for a MoE model."""
model_path = "AliEdalati97/Qwen3-30B-A3B-MXFP8"
with vllm_runner(
model_path,
enforce_eager=True,
load_format="dummy",
hf_overrides={"num_hidden_layers": 4},
) as llm:
def check_model(model):
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_mxfp8 import ( # noqa: E501
CompressedTensorsW8A8Mxfp8MoEMethod,
)
layer = model.model.layers[0]
qkv = layer.self_attn.qkv_proj
assert isinstance(qkv.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv.scheme, CompressedTensorsW8A8Mxfp8)
experts = layer.mlp.experts
assert isinstance(experts, FusedMoE)
assert isinstance(experts.quant_method, CompressedTensorsW8A8Mxfp8MoEMethod)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
assert output
......@@ -49,6 +49,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW4A16Mxfp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A8Mxfp8,
CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16,
)
......@@ -403,6 +404,27 @@ class CompressedTensorsConfig(QuantizationConfig):
and is_symmetric
)
@staticmethod
def _is_mxfp8(quant_args: QuantizationArgs) -> bool:
if quant_args is None:
return False
is_group_quant = quant_args.strategy == QuantizationStrategy.GROUP.value
is_symmetric = quant_args.symmetric
is_group_size_32 = quant_args.group_size == 32
is_float_type = quant_args.type == QuantizationType.FLOAT
is_8_bits = quant_args.num_bits == 8
is_mxfp8_scale_dtype = quant_args.scale_dtype == torch.uint8
return (
is_group_quant
and is_float_type
and is_8_bits
and is_group_size_32
and is_symmetric
and is_mxfp8_scale_dtype
)
@staticmethod
def _is_static_tensor_w8a8(
weight_quant: QuantizationArgs, input_quant: QuantizationArgs
......@@ -606,6 +628,9 @@ class CompressedTensorsConfig(QuantizationConfig):
if self._is_mxfp4(weight_quant):
return CompressedTensorsW4A16Mxfp4()
if self._is_mxfp8(weight_quant):
return CompressedTensorsW8A8Mxfp8()
if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
return CompressedTensorsW4A8Fp8(
num_bits=weight_quant.num_bits,
......
......@@ -68,6 +68,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsW4A4Mxfp4MoEMethod(layer.moe_config)
if quant_config._is_mxfp8(weight_quant):
from .compressed_tensors_moe_w8a8_mxfp8 import (
CompressedTensorsW8A8Mxfp8MoEMethod,
)
return CompressedTensorsW8A8Mxfp8MoEMethod(layer.moe_config)
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
# group_size=None means channelwise
group_size = weight_quant.group_size or -1
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
select_mxfp8_moe_backend,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
MXFP8_SCALE_DTYPE,
MXFP8_VALUE_DTYPE,
)
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
class CompressedTensorsW8A8Mxfp8MoEMethod(CompressedTensorsMoEMethod):
"""Compressed-tensors MoE method for pre-quantized MXFP8 (W8A8) checkpoints.
Loads FP8 (E4M3) weights with E8M0 uint8 per-group scales (group_size=32)
from checkpoint. Activations are dynamically quantized to MXFP8 at runtime.
Supports FlashInfer TRT-LLM and Marlin backends (auto-selected).
"""
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
self.weight_block_size = [1, MXFP8_BLOCK_SIZE]
self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe)
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.num_experts = num_experts
layer.params_dtype = params_dtype
w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_num_shards * intermediate_size_per_partition,
hidden_size,
dtype=MXFP8_VALUE_DTYPE,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=MXFP8_VALUE_DTYPE,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
w13_num_shards * intermediate_size_per_partition,
hidden_size // MXFP8_BLOCK_SIZE,
dtype=MXFP8_SCALE_DTYPE,
),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // MXFP8_BLOCK_SIZE,
dtype=MXFP8_SCALE_DTYPE,
),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: FusedMoE) -> None:
layer.weight_block_size = self.weight_block_size
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=layer.w13_weight,
w2=layer.w2_weight,
w13_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w13_input_scale=layer.w13_input_scale,
w2_input_scale=layer.w2_input_scale,
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, "w13_weight_scale", w13_scale)
replace_parameter(layer, "w2_weight_scale", w2_scale)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.weight_block_size,
)
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel "
"initialization logic. This function should not be called."
)
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor:
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
......@@ -9,6 +9,7 @@ from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a8_mxfp8 import CompressedTensorsW8A8Mxfp8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
......@@ -28,4 +29,5 @@ __all__ = [
"CompressedTensorsW4A4Fp4",
"CompressedTensorsW4A8Int",
"CompressedTensorsW4A8Fp8",
"CompressedTensorsW8A8Mxfp8",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.model_executor.kernels.linear import init_mxfp8_linear_kernel
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
MXFP8_SCALE_DTYPE,
MXFP8_VALUE_DTYPE,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
)
__all__ = ["CompressedTensorsW8A8Mxfp8"]
class CompressedTensorsW8A8Mxfp8(CompressedTensorsScheme):
"""
Compressed tensors scheme for MXFP8 quantization (W8A8).
Loads pre-quantized MXFP8 weights from compressed-tensors checkpoints.
Activations are dynamically quantized to MXFP8 at runtime.
MXFP8 format:
- 8-bit float weights (E4M3) stored as float8_e4m3fn
- Per-group E8M0 scales (uint8) with group_size=32
- Activations dynamically quantized to MXFP8 during inference
"""
def __init__(self):
self.kernel = init_mxfp8_linear_kernel()
@classmethod
def get_min_capability(cls) -> int:
return 75
def create_weights(
self,
layer: torch.nn.Module,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.params_dtype = params_dtype
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=MXFP8_VALUE_DTYPE,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // MXFP8_BLOCK_SIZE,
dtype=MXFP8_SCALE_DTYPE,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment