Unverified Commit 38fa87ca authored by Vasiliy Kuznetsov's avatar Vasiliy Kuznetsov Committed by GitHub
Browse files

mxfp8 online quant move to new frontend (#40152)


Signed-off-by: default avatarVasiliy Kuznetsov <vasiliy@meta.com>
parent a023edfa
...@@ -17,6 +17,9 @@ llm = LLM("meta-llama/Llama-3.1-8B", quantization="fp8_per_tensor") ...@@ -17,6 +17,9 @@ llm = LLM("meta-llama/Llama-3.1-8B", quantization="fp8_per_tensor")
# Per-block FP8 quantization (128x128 block scaling for weights and 1x128 block scaling for activations) # Per-block FP8 quantization (128x128 block scaling for weights and 1x128 block scaling for activations)
llm = LLM("meta-llama/Llama-3.1-8B", quantization="fp8_per_block") llm = LLM("meta-llama/Llama-3.1-8B", quantization="fp8_per_block")
# MXFP8 quantization for weights and activations
llm = LLM("meta-llama/Llama-3.1-8B", quantization="mxfp8")
``` ```
Or with the CLI: Or with the CLI:
...@@ -24,6 +27,7 @@ Or with the CLI: ...@@ -24,6 +27,7 @@ Or with the CLI:
```bash ```bash
vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_tensor vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_tensor
vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_block vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_block
vllm serve meta-llama/Llama-3.1-8B --quantization mxfp8
``` ```
## Supported Schemes ## Supported Schemes
...@@ -32,8 +36,7 @@ vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_block ...@@ -32,8 +36,7 @@ vllm serve meta-llama/Llama-3.1-8B --quantization fp8_per_block
| ------ | ------------- | ------------------ | ----- | | ------ | ------------- | ------------------ | ----- |
| `fp8_per_tensor` | fp8_e4m3 data, fp32 per-tensor scale | fp8_e4m3 data, fp32 per-tensor scale | On some GPUs (Ada, Hopper) linear activations use per-token scaling for better performance | | `fp8_per_tensor` | fp8_e4m3 data, fp32 per-tensor scale | fp8_e4m3 data, fp32 per-tensor scale | On some GPUs (Ada, Hopper) linear activations use per-token scaling for better performance |
| `fp8_per_block` | fp8_e4m3 data, fp32 per-128x128-block scale | fp8_e4m3 data, fp32 per-1x128-block scale | | | `fp8_per_block` | fp8_e4m3 data, fp32 per-128x128-block scale | fp8_e4m3 data, fp32 per-1x128-block scale | |
| `mxfp8` | fp8_e4m3 data, e8m0 per-1x32-block scale | fp8_e4m3 data, e8m0 per-1x32-block scale | Requires SM 100+ (Blackwell or newer) for w8a8, other GPUs use a w8a16 fallback |
Support for additional schemes will be added in future versions of vllm.
## Advanced Configuration ## Advanced Configuration
......
...@@ -23,7 +23,8 @@ class OnlineQuantScheme(Enum): ...@@ -23,7 +23,8 @@ class OnlineQuantScheme(Enum):
# Linear layers remain unquantized. # Linear layers remain unquantized.
INT8_PER_CHANNEL_WEIGHT_ONLY = "int8_per_channel_weight_only" INT8_PER_CHANNEL_WEIGHT_ONLY = "int8_per_channel_weight_only"
# TODO(future PRs): add more online quant schemes here: mxfp8, etc # mxfp8, weights scaled in blocks of 1x32 elements (microscaling FP8)
MXFP8 = "mxfp8"
@config @config
......
...@@ -31,7 +31,6 @@ QuantizationMethods = Literal[ ...@@ -31,7 +31,6 @@ QuantizationMethods = Literal[
"inc", "inc",
"mxfp4", "mxfp4",
"gpt_oss_mxfp4", "gpt_oss_mxfp4",
"mxfp8",
"cpu_awq", "cpu_awq",
"online", "online",
# Below are values of the OnlineQuantScheme enum, specified as strings to # Below are values of the OnlineQuantScheme enum, specified as strings to
...@@ -41,6 +40,7 @@ QuantizationMethods = Literal[ ...@@ -41,6 +40,7 @@ QuantizationMethods = Literal[
"fp8_per_tensor", "fp8_per_tensor",
"fp8_per_block", "fp8_per_block",
"int8_per_channel_weight_only", "int8_per_channel_weight_only",
"mxfp8",
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -135,7 +135,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -135,7 +135,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
) )
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
from .mxfp4 import GptOssMxfp4Config, Mxfp4Config from .mxfp4 import GptOssMxfp4Config, Mxfp4Config
from .mxfp8 import Mxfp8Config
from .online.base import OnlineQuantizationConfig from .online.base import OnlineQuantizationConfig
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
...@@ -162,7 +161,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -162,7 +161,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"inc": INCConfig, "inc": INCConfig,
"mxfp4": Mxfp4Config, "mxfp4": Mxfp4Config,
"gpt_oss_mxfp4": GptOssMxfp4Config, "gpt_oss_mxfp4": GptOssMxfp4Config,
"mxfp8": Mxfp8Config,
"cpu_awq": CPUAWQConfig, "cpu_awq": CPUAWQConfig,
"online": OnlineQuantizationConfig, "online": OnlineQuantizationConfig,
} }
......
...@@ -515,14 +515,6 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod): ...@@ -515,14 +515,6 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
initialize_online_processing(layer) initialize_online_processing(layer)
# TODO: remove this check once the following RFC is resolved.
# https://github.com/vllm-project/vllm/issues/33314
# Subclasses (e.g. Mxfp8OnlineLinearMethod) only need the weight
# registration above and manage their own kernel, so skip fp8_linear
# kernel creation for them.
if type(self) is not Fp8OnlineLinearMethod:
return
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key, activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key, weight_quant_key=self.weight_quant_key,
......
...@@ -37,6 +37,10 @@ from vllm.model_executor.layers.quantization.online.fp8 import ( ...@@ -37,6 +37,10 @@ from vllm.model_executor.layers.quantization.online.fp8 import (
from vllm.model_executor.layers.quantization.online.int8 import ( from vllm.model_executor.layers.quantization.online.int8 import (
Int8OnlineMoEMethod, Int8OnlineMoEMethod,
) )
from vllm.model_executor.layers.quantization.online.mxfp8 import (
Mxfp8OnlineLinearMethod,
Mxfp8OnlineMoEMethod,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -110,6 +114,8 @@ class OnlineQuantizationConfig(QuantizationConfig): ...@@ -110,6 +114,8 @@ class OnlineQuantizationConfig(QuantizationConfig):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
elif linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK: elif linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
return Fp8PerBlockOnlineLinearMethod() return Fp8PerBlockOnlineLinearMethod()
elif linear_scheme == OnlineQuantScheme.MXFP8:
return Mxfp8OnlineLinearMethod()
else: else:
return Fp8PerTensorOnlineLinearMethod() return Fp8PerTensorOnlineLinearMethod()
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
...@@ -125,6 +131,8 @@ class OnlineQuantizationConfig(QuantizationConfig): ...@@ -125,6 +131,8 @@ class OnlineQuantizationConfig(QuantizationConfig):
return Int8OnlineMoEMethod(layer=layer) return Int8OnlineMoEMethod(layer=layer)
elif moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK: elif moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
return Fp8PerBlockOnlineMoEMethod(layer=layer) return Fp8PerBlockOnlineMoEMethod(layer=layer)
elif moe_scheme == OnlineQuantScheme.MXFP8:
return Mxfp8OnlineMoEMethod(layer=layer)
else: else:
return Fp8PerTensorOnlineMoEMethod(layer=layer) return Fp8PerTensorOnlineMoEMethod(layer=layer)
return None return None
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Online MXFP8 (microscaling FP8, block-32) quantization config and methods.""" """Online MXFP8 (microscaling FP8, block-32) quantization methods."""
from typing import Any from typing import TYPE_CHECKING
import torch import torch
from torch.nn import Module from torch.nn import Module
from vllm.logger import init_logger if TYPE_CHECKING:
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend
from vllm.model_executor.kernels.linear import init_mxfp8_linear_kernel from vllm.model_executor.kernels.linear import init_mxfp8_linear_kernel
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import ( from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import (
select_mxfp8_moe_backend, select_mxfp8_moe_backend,
) )
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization.online.fp8 import (
LinearBase, _Fp8OnlineLinearBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.fp8 import ( from vllm.model_executor.layers.quantization.online.moe_base import (
Fp8Config, OnlineMoEMethodBase,
Fp8KVCacheMethod,
Fp8OnlineLinearMethod,
Fp8OnlineMoEMethod,
) )
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE, MXFP8_BLOCK_SIZE,
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.utils import replace_parameter from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = init_logger(__name__)
class Mxfp8Config(Fp8Config): class Mxfp8OnlineLinearMethod(_Fp8OnlineLinearBase):
"""Config class for online MXFP8 MoE quantization."""
def __init__(
self,
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
) -> None:
if activation_scheme != "dynamic":
raise ValueError("mxfp8 only supports dynamic activation scheme.")
super().__init__(
is_checkpoint_fp8_serialized=False,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=None,
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "mxfp8"
@classmethod
def get_min_capability(cls) -> int:
# Marlin kernel supports MXFP8 on SM80+
return 80
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Mxfp8Config":
activation_scheme = cls.get_from_keys_or(
config, ["activation_scheme"], "dynamic"
)
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
if not ignored_layers:
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
return cls(
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
if isinstance(layer, LinearBase):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedLinearMethod()
return Mxfp8OnlineLinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedFusedMoEMethod(layer.moe_config)
return Mxfp8OnlineMoEMethod(self, layer)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
"""Online MXFP8 linear method. """Online MXFP8 linear method.
Loads bf16/fp16 checkpoints and quantizes weights to MXFP8 (microscaling Loads bf16/fp16 checkpoints and quantizes weights to MXFP8 (microscaling
FP8 with block-32 scales) during weight loading. FP8 with block-32 scales) during weight loading.
Args:
quant_config: The MXFP8 quantization config.
""" """
uses_meta_device: bool = True def __init__(self):
super().__init__()
def __init__(self, quant_config: "Mxfp8Config"):
self.quant_config = quant_config
self.kernel = init_mxfp8_linear_kernel() self.kernel = init_mxfp8_linear_kernel()
def create_weights( def create_weights(
...@@ -178,19 +94,15 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): ...@@ -178,19 +94,15 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
return self.kernel.apply_weights(layer, x, bias) return self.kernel.apply_weights(layer, x, bias)
class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod): class Mxfp8OnlineMoEMethod(OnlineMoEMethodBase):
"""MoE method for online MXFP8 (block) quantization.""" """MoE method for online MXFP8 (block) quantization."""
uses_meta_device: bool = True fp8_backend: "Fp8MoeBackend"
experts_cls: "type[mk.FusedMoEExperts] | None"
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
FusedMoEMethodBase.__init__(self, layer.moe_config)
self.quant_config = quant_config
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
self.weight_block_size = [1, MXFP8_BLOCK_SIZE] def __init__(self, *, layer: torch.nn.Module):
self.block_quant = True super().__init__(layer.moe_config)
self.weight_block_size: list[int] = [1, MXFP8_BLOCK_SIZE]
self.weight_scale_name = "weight_scale" self.weight_scale_name = "weight_scale"
self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe) self.fp8_backend, self.experts_cls = select_mxfp8_moe_backend(config=self.moe)
...@@ -247,6 +159,74 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod): ...@@ -247,6 +159,74 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
return w_quant, w_scales return w_quant, w_scales
def _setup_kernel(
self,
layer: "FusedMoE",
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
)
# Shuffle weights to runtime format.
w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
fp8_backend=self.fp8_backend,
layer=layer,
w13=w13,
w2=w2,
w13_scale=w13_scale,
w2_scale=w2_scale,
w13_input_scale=w13_input_scale,
w2_input_scale=w2_input_scale,
)
replace_parameter(layer, "w13_weight", w13)
replace_parameter(layer, "w2_weight", w2)
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
experts_cls=self.experts_cls,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> "FusedMoEQuantConfig":
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
make_fp8_moe_quant_config,
)
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
quant_config = make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=self.weight_block_size,
)
self._maybe_inject_biases(quant_config, layer)
return quant_config
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
......
...@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( ...@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
) )
from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.mxfp8 import Mxfp8OnlineLinearMethod from vllm.model_executor.layers.quantization.online.mxfp8 import Mxfp8OnlineLinearMethod
from vllm.tracing import instrument from vllm.tracing import instrument
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
fp8_gemm_nt, fp8_gemm_nt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment