Unverified Commit 38fa87ca authored by Vasiliy Kuznetsov's avatar Vasiliy Kuznetsov Committed by GitHub
Browse files

mxfp8 online quant move to new frontend (#40152)


Signed-off-by: default avatarVasiliy Kuznetsov <vasiliy@meta.com>
parent a023edfa
......@@ -17,6 +17,9 @@ llm = LLM("meta-llama/Llama-3.1-8B", quantization="fp8_per_tensor")
# Per-block FP8 quantization (128x128 block scaling for weights and 1x128 block scaling for activations)
llm = LLM("meta-llama/Llama-3.1-8B", quantization="fp8_per_block")
# MXFP8 quantization for weights and activations
llm = LLM("meta-llama/Llama-3.1-8B", quantization="mxfp8")
```
Or with the CLI:
......@@ -24,6 +27,7 @@ Or with the CLI:
```bash
vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_tensor
vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_block
vllm serve meta-llama/Llama-3.1-8B --quantization mxfp8
```
## Supported Schemes
......@@ -32,8 +36,7 @@ vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_block
| ------ | ------------- | ------------------ | ----- |
| `fp8_per_tensor` | fp8_e4m3 data, fp32 per-tensor scale | fp8_e4m3 data, fp32 per-tensor scale | On some GPUs (Ada, Hopper) linear activations use per-token scaling for better performance |
| `fp8_per_block` | fp8_e4m3 data, fp32 per-128x128-block scale | fp8_e4m3 data, fp32 per-1x128-block scale | |
Support for additional schemes will be added in future versions of vllm.
| `mxfp8` | fp8_e4m3 data, e8m0 per-1x32-block scale | fp8_e4m3 data, e8m0 per-1x32-block scale | Requires SM 100+ (Blackwell or newer) for w8a8, other GPUs use a w8a16 fallback |
## Advanced Configuration
......
......@@ -23,7 +23,8 @@ class OnlineQuantScheme(Enum):
# Linear layers remain unquantized.
INT8_PER_CHANNEL_WEIGHT_ONLY = "int8_per_channel_weight_only"
# TODO(future PRs): add more online quant schemes here: mxfp8, etc
# mxfp8, weights scaled in blocks of 1x32 elements (microscaling FP8)
MXFP8 = "mxfp8"
@config
......
......@@ -31,7 +31,6 @@ QuantizationMethods = Literal[
"inc",
"mxfp4",
"gpt_oss_mxfp4",
"mxfp8",
"cpu_awq",
"online",
# Below are values of the OnlineQuantScheme enum, specified as strings to
......@@ -41,6 +40,7 @@ QuantizationMethods = Literal[
"fp8_per_tensor",
"fp8_per_block",
"int8_per_channel_weight_only",
"mxfp8",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
......@@ -135,7 +135,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
)
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import GptOssMxfp4Config, Mxfp4Config
from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig
from .torchao import TorchAOConfig
......@@ -162,7 +161,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"gpt_oss_mxfp4": GptOssMxfp4Config,
"mxfp8": Mxfp8Config,
"cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig,
}
......
......@@ -515,14 +515,6 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
initialize_online_processing(layer)
# TODO: remove this check once the following RFC is resolved.
# https://github.com/vllm-project/vllm/issues/33314
# Subclasses (e.g. Mxfp8OnlineLinearMethod) only need the weight
# registration above and manage their own kernel, so skip fp8_linear
# kernel creation for them.
if type(self) is not Fp8OnlineLinearMethod:
return
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
......
......@@ -37,6 +37,10 @@ from vllm.model_executor.layers.quantization.online.fp8 import (
from vllm.model_executor.layers.quantization.online.int8 import (
Int8OnlineMoEMethod,
)
from vllm.model_executor.layers.quantization.online.mxfp8 import (
Mxfp8OnlineLinearMethod,
Mxfp8OnlineMoEMethod,
)
logger = init_logger(__name__)
......@@ -110,6 +114,8 @@ class OnlineQuantizationConfig(QuantizationConfig):
return UnquantizedLinearMethod()
elif linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
return Fp8PerBlockOnlineLinearMethod()
elif linear_scheme == OnlineQuantScheme.MXFP8:
return Mxfp8OnlineLinearMethod()
else:
return Fp8PerTensorOnlineLinearMethod()
elif isinstance(layer, FusedMoE):
......@@ -125,6 +131,8 @@ class OnlineQuantizationConfig(QuantizationConfig):
return Int8OnlineMoEMethod(layer=layer)
elif moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
return Fp8PerBlockOnlineMoEMethod(layer=layer)
elif moe_scheme == OnlineQuantScheme.MXFP8:
return Mxfp8OnlineMoEMethod(layer=layer)
else:
return Fp8PerTensorOnlineMoEMethod(layer=layer)
return None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Online MXFP8 (microscaling FP8, block-32) quantization config and methods."""
"""Online MXFP8 (microscaling FP8, block-32) quantization methods."""
from typing import Any
from typing import TYPE_CHECKING
import torch
from torch.nn import Module
from vllm.logger import init_logger
if TYPE_CHECKING:
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend
from vllm.model_executor.kernels.linear import init_mxfp8_linear_kernel
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
select_mxfp8_moe_backend,
)
from vllm.model_executor.layers.linear import (
LinearBase,
UnquantizedLinearMethod,
from vllm.model_executor.layers.quantization.online.fp8 import (
_Fp8OnlineLinearBase,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod,
Fp8OnlineLinearMethod,
Fp8OnlineMoEMethod,
from vllm.model_executor.layers.quantization.online.moe_base import (
OnlineMoEMethodBase,
)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
mxfp8_e4m3_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
logger = init_logger(__name__)
class Mxfp8Config(Fp8Config):
"""Config class for online MXFP8 MoE quantization."""
def __init__(
self,
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
) -> None:
if activation_scheme != "dynamic":
raise ValueError("mxfp8 only supports dynamic activation scheme.")
super().__init__(
is_checkpoint_fp8_serialized=False,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=None,
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "mxfp8"
@classmethod
def get_min_capability(cls) -> int:
# Marlin kernel supports MXFP8 on SM80+
return 80
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config":
activation_scheme = cls.get_from_keys_or(
config, ["activation_scheme"], "dynamic"
)
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
if not ignored_layers:
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
return cls(
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedLinearMethod()
return Mxfp8OnlineLinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Mxfp8OnlineMoEMethod(self, layer)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
class Mxfp8OnlineLinearMethod(_Fp8OnlineLinearBase):
"""Online MXFP8 linear method.
Loads bf16/fp16 checkpoints and quantizes weights to MXFP8 (microscaling
FP8 with block-32 scales) during weight loading.
Args:
quant_config: The MXFP8 quantization config.
"""
uses_meta_device: bool = True
def __init__(self, quant_config: "Mxfp8Config"):
self.quant_config = quant_config
def __init__(self):
super().__init__()
self.kernel = init_mxfp8_linear_kernel()
def create_weights(
......@@ -178,19 +94,15 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
return self.kernel.apply_weights(layer, x, bias)
class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
class Mxfp8OnlineMoEMethod(OnlineMoEMethodBase):
"""MoE method for online MXFP8 (block) quantization."""
uses_meta_device: bool = True
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
FusedMoEMethodBase.__init__(self, layer.moe_config)
self.quant_config = quant_config
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
fp8_backend: "Fp8MoeBackend"
experts_cls: "type[mk.FusedMoEExperts] | None"
self.weight_block_size = [1, MXFP8_BLOCK_SIZE]
self.block_quant = True
def __init__(self, *, layer: torch.nn.Module):
super().__init__(layer.moe_config)
self.weight_block_size: list[int] = [1, MXFP8_BLOCK_SIZE]
self.weight_scale_name = "weight_scale"
self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe)
......@@ -247,6 +159,74 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
return w_quant, w_scales
def _setup_kernel(
self,
layer: "FusedMoE",
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
)
# Shuffle weights to runtime format.
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> "FusedMoEQuantConfig":
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_quant_config,
)
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
quant_config = make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=self.weight_block_size,
)
self._maybe_inject_biases(quant_config, layer)
return quant_config
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
......
......@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.mxfp8 import Mxfp8OnlineLinearMethod
from vllm.model_executor.layers.quantization.online.mxfp8 import Mxfp8OnlineLinearMethod
from vllm.tracing import instrument
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment