Commit ad2b1277 (unverified), authored by Asaf Gardin, committed by GitHub
Browse files

[Quantization] Consolidate experts_int8 with fp8 online quantization (#38463)


Signed-off-by: Josephasafg <ajgard7@gmail.com>
parent b897f00c
...@@ -38,6 +38,5 @@ def test_model_experts_int8_startup( ...@@ -38,6 +38,5 @@ def test_model_experts_int8_startup(
dtype=dtype, dtype=dtype,
enforce_eager=True, enforce_eager=True,
quantization="experts_int8", quantization="experts_int8",
allow_deprecated_quantization=True,
) as vllm_model: ) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens) vllm_model.generate_greedy(example_prompts, max_tokens)
...@@ -19,6 +19,10 @@ class OnlineQuantScheme(Enum): ...@@ -19,6 +19,10 @@ class OnlineQuantScheme(Enum):
# blocks of 128x128 elements (popularized by DeepSeek) # blocks of 128x128 elements (popularized by DeepSeek)
FP8_PER_BLOCK = "fp8_per_block" FP8_PER_BLOCK = "fp8_per_block"
# int8, weight-only per-channel quantization for MoE expert weights.
# Linear layers remain unquantized.
INT8_PER_CHANNEL_WEIGHT_ONLY = "int8_per_channel_weight_only"
# TODO(future PRs): add more online quant schemes here: mxfp8, etc # TODO(future PRs): add more online quant schemes here: mxfp8, etc
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.runner.shared_experts import (
SharedExperts,
)
logger = init_logger(__name__)
def select_int8_moe_backend(
    config: FusedMoEConfig,
) -> type[mk.FusedMoEExperts]:
    """Return the experts class to use for INT8 weight-only MoE.

    ``TritonExperts`` is the only candidate considered here. It is validated
    against ``config`` through ``TritonExperts.is_supported_config`` using the
    ``Standard`` activation format.

    Args:
        config: MoE configuration the backend must be able to serve.

    Returns:
        The ``TritonExperts`` class (the class itself, not an instance).

    Raises:
        ValueError: If ``is_supported_config`` reports the configuration as
            unsupported; the reported reason is included in the message.
    """
    # Function-scope import, kept local to this function.
    from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts

    activation_format = mk.FusedMoEActivationFormat.Standard
    is_supported, why_not = TritonExperts.is_supported_config(
        TritonExperts,
        config,
        None,
        None,
        activation_format,
    )

    if is_supported:
        logger.info_once("Using Triton INT8 MoE backend", scope="local")
        return TritonExperts

    raise ValueError(
        "INT8 Triton MoE backend does not support the "
        f"deployment configuration: {why_not}"
    )
def make_int8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Build the MoE quant config for int8 weights without zero points.

    Args:
        w1_scale: Scales forwarded as ``w1_scale``.
        w2_scale: Scales forwarded as ``w2_scale``.

    Returns:
        The result of ``int8_w8a16_moe_quant_config`` with both zero-point
        arguments set to ``None``.
    """
    no_zero_point = None
    quant_config = int8_w8a16_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=no_zero_point,
        w2_zp=no_zero_point,
    )
    return quant_config
def make_int8_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    shared_experts: SharedExperts | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a ``FusedMoEKernel`` for INT8 MoE.

    Args:
        moe_quant_config: Quantization config handed to both the
            prepare/finalize factory and the experts constructor.
        moe_config: MoE configuration; also supplies ``disable_inplace``.
        experts_cls: Experts class to instantiate.
        routing_tables: Optional routing tables forwarded to the
            prepare/finalize factory.
        shared_experts: Optional shared experts forwarded to the kernel.

    Returns:
        A kernel built from the prepare/finalize object and the experts
        instance, with ``inplace`` set to ``not moe_config.disable_inplace``.

    Raises:
        AssertionError: If ``maybe_make_prepare_finalize`` returns ``None``.
    """
    pf = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
    )
    assert pf is not None

    pf_name = pf.__class__.__name__
    logger.info_once("Using %s", pf_name, scope="local")

    experts_impl = experts_cls(
        moe_config=moe_config,
        quant_config=moe_quant_config,
    )
    run_inplace = not moe_config.disable_inplace
    return mk.FusedMoEKernel(
        pf,
        experts_impl,
        shared_experts=shared_experts,
        inplace=run_inplace,
    )
...@@ -40,6 +40,7 @@ QuantizationMethods = Literal[ ...@@ -40,6 +40,7 @@ QuantizationMethods = Literal[
# shorthand for creating a more complicated online quant config object # shorthand for creating a more complicated online quant config object
"fp8_per_tensor", "fp8_per_tensor",
"fp8_per_block", "fp8_per_block",
"int8_per_channel_weight_only",
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -47,7 +48,6 @@ DEPRECATED_QUANTIZATION_METHODS = [ ...@@ -47,7 +48,6 @@ DEPRECATED_QUANTIZATION_METHODS = [
"tpu_int8", "tpu_int8",
"fbgemm_fp8", "fbgemm_fp8",
"fp_quant", "fp_quant",
"experts_int8",
] ]
# The customized quantization methods which will be added to this dict. # The customized quantization methods which will be added to this dict.
......
...@@ -5,27 +5,25 @@ from typing import Any ...@@ -5,27 +5,25 @@ from typing import Any
import torch import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.layers.quantization.online.int8 import (
Int8OnlineMoEMethod,
)
class ExpertsInt8Config(QuantizationConfig): class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization.""" """Online int8 quantization for MoE expert weights.
Linear layers are left unquantized.
Backward-compatible config for ``--quantization experts_int8``.
Prefer ``--quantization int8_per_channel_weight_only``.
"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -56,149 +54,5 @@ class ExpertsInt8Config(QuantizationConfig): ...@@ -56,149 +54,5 @@ class ExpertsInt8Config(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return ExpertsInt8MoEMethod(self, layer.moe_config) return Int8OnlineMoEMethod(layer=layer)
return None return None
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
    """MoE method that keeps expert weights in int8 with per-row fp32 scales.

    ``create_weights`` wraps the layer's weight loader so that every incoming
    shard is quantized in place, and its scales recorded on the layer, before
    the original loader is invoked.
    """

    def __init__(
        self,
        quant_config: ExpertsInt8Config,
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register int8 weight and fp32 scale parameters on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden dimension of the expert weights.
            intermediate_size_per_partition: Intermediate dimension owned by
                this partition.
            params_dtype: Unused here; weights are always allocated as int8.
            **extra_weight_attrs: Attributes attached to the weight
                parameters. Must contain ``"weight_loader"``, which is
                replaced by a quantizing wrapper.

        Raises:
            AssertionError: If ``"weight_loader"`` is missing.
        """
        int8_dtype = torch.int8

        assert "weight_loader" in extra_weight_attrs
        weight_loader = extra_weight_attrs["weight_loader"]
        wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
            layer, weight_loader
        )
        extra_weight_attrs["weight_loader"] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=int8_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=int8_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # One fp32 scale per output row of each expert. The scale parameters
        # do not receive ``extra_weight_attrs``; they are filled by the
        # quantizing loader wrapper instead.
        w13_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(
            torch.zeros(num_experts, hidden_size, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_scale", w2_scale)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Return the int8 quant config built from the layer's scales."""
        return int8_w8a16_moe_quant_config(
            w1_scale=layer.w13_scale, w2_scale=layer.w2_scale, w1_zp=None, w2_zp=None
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the experts through ``fused_experts``.

        ``shared_experts_input`` is accepted but not used by this method.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=not self.moe.disable_inplace,
            activation=layer.activation,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            quant_config=self.moe_quant_config,
        )

    @staticmethod
    def quantizing_weight_loader(layer, weight_loader):
        """Wrap ``weight_loader`` so shards are quantized before loading.

        Args:
            layer: Module holding ``w13_scale``, ``w2_scale`` and
                ``intermediate_size_per_partition``.
            weight_loader: Original loader, called with the same arguments
                after the current rank's slice has been quantized.

        Returns:
            A loader with the same call signature as ``weight_loader``.
        """

        def quantize_and_call_weight_loader(
            param: torch.nn.Parameter,
            loaded_weight: torch.Tensor,
            weight_name: str,
            shard_id: str,
            expert_id: int,
        ):
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = layer.intermediate_size_per_partition
            # Slice of the intermediate dimension owned by this TP rank.
            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
            device = get_tp_group().device
            loaded_weight = loaded_weight.to(device)
            # w1, gate_proj case: Load into first shard of w13.
            if shard_id == "w1":
                scales = quantize_in_place_and_get_scales(loaded_weight[shard, :])
                layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, 0])
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == "w3":
                scales = quantize_in_place_and_get_scales(loaded_weight[shard, :])
                layer.w13_scale.data[expert_id, shard_size : 2 * shard_size].copy_(
                    scales[:, 0]
                )
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == "w2":
                scales = quantize_in_place_and_get_scales(loaded_weight[:, shard])
                layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
            else:
                # The accepted ids are the strings compared above, so the
                # message names them rather than integer indices.
                raise ValueError(
                    f"Shard id must be one of 'w1', 'w2', 'w3' but got {shard_id}"
                )
            weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)

        return quantize_and_call_weight_loader
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
    """Quantize a 2-D tensor row-wise to the int8 value range, in place.

    Each row is divided by ``max(|row|) / 127``, rounded and clamped to
    ``[-127, 127]``. The tensor keeps its original dtype; only its values
    are changed.

    Args:
        weight: Tensor whose rows (dim 0) are quantized independently. It is
            modified in place.

    Returns:
        Per-row scales with shape ``(rows, 1)``. An all-zero row has scale 0.
    """
    vmax = torch.iinfo(torch.int8).max
    scales = torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax
    # An all-zero row has a zero scale; dividing by it would turn the row
    # into NaN (0 / 0). Divide such rows by 1 instead so they stay zero. The
    # returned scale is left untouched, so non-zero rows behave as before.
    divisor = torch.where(scales == 0, torch.ones_like(scales), scales)
    weight.div_(divisor)
    weight.round_()
    weight.clamp_(-vmax, vmax)
    return scales
...@@ -9,6 +9,7 @@ from vllm.config.quantization import ( ...@@ -9,6 +9,7 @@ from vllm.config.quantization import (
OnlineQuantizationConfigArgs, OnlineQuantizationConfigArgs,
OnlineQuantScheme, OnlineQuantScheme,
) )
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
) )
...@@ -33,6 +34,11 @@ from vllm.model_executor.layers.quantization.online.fp8 import ( ...@@ -33,6 +34,11 @@ from vllm.model_executor.layers.quantization.online.fp8 import (
Fp8PerTensorOnlineLinearMethod, Fp8PerTensorOnlineLinearMethod,
Fp8PerTensorOnlineMoEMethod, Fp8PerTensorOnlineMoEMethod,
) )
from vllm.model_executor.layers.quantization.online.int8 import (
Int8OnlineMoEMethod,
)
logger = init_logger(__name__)
class OnlineQuantizationConfig(QuantizationConfig): class OnlineQuantizationConfig(QuantizationConfig):
...@@ -96,7 +102,13 @@ class OnlineQuantizationConfig(QuantizationConfig): ...@@ -96,7 +102,13 @@ class OnlineQuantizationConfig(QuantizationConfig):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
linear_scheme = self.args.linear_scheme_override or self.args.global_scheme linear_scheme = self.args.linear_scheme_override or self.args.global_scheme
if linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK: if linear_scheme == OnlineQuantScheme.INT8_PER_CHANNEL_WEIGHT_ONLY:
logger.warning_once(
"INT8 online quantization only quantizes MoE expert "
"weights. linear layers remain in full precision."
)
return UnquantizedLinearMethod()
elif linear_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
return Fp8PerBlockOnlineLinearMethod() return Fp8PerBlockOnlineLinearMethod()
else: else:
return Fp8PerTensorOnlineLinearMethod() return Fp8PerTensorOnlineLinearMethod()
...@@ -109,7 +121,9 @@ class OnlineQuantizationConfig(QuantizationConfig): ...@@ -109,7 +121,9 @@ class OnlineQuantizationConfig(QuantizationConfig):
return UnquantizedFusedMoEMethod(layer.moe_config) return UnquantizedFusedMoEMethod(layer.moe_config)
moe_scheme = self.args.moe_scheme_override or self.args.global_scheme moe_scheme = self.args.moe_scheme_override or self.args.global_scheme
if moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK: if moe_scheme == OnlineQuantScheme.INT8_PER_CHANNEL_WEIGHT_ONLY:
return Int8OnlineMoEMethod(layer=layer)
elif moe_scheme == OnlineQuantScheme.FP8_PER_BLOCK:
return Fp8PerBlockOnlineMoEMethod(layer=layer) return Fp8PerBlockOnlineMoEMethod(layer=layer)
else: else:
return Fp8PerTensorOnlineMoEMethod(layer=layer) return Fp8PerTensorOnlineMoEMethod(layer=layer)
......
...@@ -10,7 +10,6 @@ if TYPE_CHECKING: ...@@ -10,7 +10,6 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend from vllm.model_executor.layers.fused_moe.oracle.fp8 import Fp8MoeBackend
...@@ -19,15 +18,15 @@ import vllm.envs as envs ...@@ -19,15 +18,15 @@ import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
select_fp8_moe_backend, select_fp8_moe_backend,
) )
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearMethodBase, LinearMethodBase,
) )
from vllm.model_executor.layers.quantization.online.moe_base import (
OnlineMoEMethodBase,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
create_fp8_quant_key, create_fp8_quant_key,
...@@ -44,7 +43,7 @@ from vllm.model_executor.model_loader.reload.layerwise import ( ...@@ -44,7 +43,7 @@ from vllm.model_executor.model_loader.reload.layerwise import (
initialize_online_processing, initialize_online_processing,
) )
from vllm.model_executor.parameter import ModelWeightParameter from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import per_block_cast_to_fp8 from vllm.utils.deep_gemm import per_block_cast_to_fp8
...@@ -268,21 +267,15 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase): ...@@ -268,21 +267,15 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class _Fp8OnlineMoEBase(FusedMoEMethodBase): class _Fp8OnlineMoEBase(OnlineMoEMethodBase):
"""Shared base for online FP8 MoE methods. Loads fp16/bf16 checkpoint """Shared base for online FP8 MoE methods. Loads fp16/bf16 checkpoint
weights onto meta device and materializes them just-in-time.""" weights onto meta device and materializes them just-in-time."""
uses_meta_device: bool = True
# Declared here for mypy; actual values are set in __init__. # Declared here for mypy; actual values are set in __init__.
fp8_backend: "Fp8MoeBackend" fp8_backend: "Fp8MoeBackend"
experts_cls: "type[mk.FusedMoEExperts] | None" experts_cls: "type[mk.FusedMoEExperts] | None"
weight_scale_name: str weight_scale_name: str
weight_block_size: list[int] | None weight_block_size: list[int] | None
moe: "FusedMoEConfig"
is_monolithic: bool
moe_quant_config: "FusedMoEQuantConfig | None"
moe_kernel: "mk.FusedMoEKernel | None"
def __init__( def __init__(
self, self,
...@@ -313,77 +306,6 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase): ...@@ -313,77 +306,6 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase):
allow_vllm_cutlass=False, allow_vllm_cutlass=False,
) )
def create_weights(
self,
layer: Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
device="meta",
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
device="meta", # materialized and processed during loading
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# BIASES (for models like GPT-OSS that have biased MoE)
if self.moe.has_bias:
w13_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
device="meta", # materialized and processed during loading
dtype=layer.orig_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
device="meta", # materialized and processed during loading
dtype=layer.orig_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
layer.w13_input_scale = None
layer.w2_input_scale = None
initialize_online_processing(layer)
def _setup_kernel( def _setup_kernel(
self, self,
layer: "FusedMoE", layer: "FusedMoE",
...@@ -430,15 +352,6 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase): ...@@ -430,15 +352,6 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase):
shared_experts=layer.shared_experts, shared_experts=layer.shared_experts,
) )
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> "mk.FusedMoEPrepareAndFinalizeModular | None":
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel "
"initialization logic. This function should not be called."
)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> "FusedMoEQuantConfig": ) -> "FusedMoEQuantConfig":
...@@ -460,68 +373,9 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase): ...@@ -460,68 +373,9 @@ class _Fp8OnlineMoEBase(FusedMoEMethodBase):
block_shape=self.weight_block_size, block_shape=self.weight_block_size,
) )
# Inject biases into the quant config if the model has them self._maybe_inject_biases(quant_config, layer)
# (e.g. GPT-OSS biased MoE)
if quant_config is not None and self.moe.has_bias:
w13_bias = getattr(layer, "w13_bias", None)
w2_bias = getattr(layer, "w2_bias", None)
if w13_bias is not None:
quant_config._w1.bias = w13_bias
if w2_bias is not None:
quant_config._w2.bias = w2_bias
return quant_config return quant_config
@property
def supports_eplb(self) -> bool:
return True
def apply_monolithic(
self,
layer: "FusedMoE",
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply_monolithic(
x,
layer.w13_weight,
layer.w2_weight,
router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
e_score_correction_bias=layer.e_score_correction_bias,
routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
self,
layer: "FusedMoE",
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
class Fp8PerTensorOnlineMoEMethod(_Fp8OnlineMoEBase): class Fp8PerTensorOnlineMoEMethod(_Fp8OnlineMoEBase):
"""Online tensorwise FP8 MoE quantization. """Online tensorwise FP8 MoE quantization.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
from torch.nn import Module
if TYPE_CHECKING:
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.oracle.int8 import (
make_int8_moe_kernel,
make_int8_moe_quant_config,
select_int8_moe_backend,
)
from vllm.model_executor.layers.quantization.online.moe_base import (
OnlineMoEMethodBase,
)
from vllm.model_executor.utils import replace_parameter
class Int8OnlineMoEMethod(OnlineMoEMethodBase):
    """Online per-channel INT8 MoE quantization.

    The expert weights registered by the base class are quantized per row to
    int8 in ``process_weights_after_loading``, and the MoE kernel is built
    right after.
    """

    def __init__(
        self,
        *,
        layer: torch.nn.Module,
    ):
        super().__init__(layer.moe_config)
        # Fails early (ValueError) if the backend rejects this configuration.
        self.experts_cls: type[mk.FusedMoEExperts] = select_int8_moe_backend(
            config=self.moe,
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Quantize the loaded weights and build the kernel, at most once.

        A flag stored on ``layer`` makes repeated calls a no-op.
        """
        if getattr(layer, "_already_called_process_weights_after_loading", False):
            return

        self._quantize_weights(layer)
        self._setup_kernel(layer)

        layer._already_called_process_weights_after_loading = True

    @staticmethod
    def _quantize_rows(
        weight: torch.Tensor, vmax: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize one expert's 2-D weight per row; return (int8, scales)."""
        scales = weight.abs().amax(dim=1) / vmax
        # An all-zero row has a zero scale. Dividing by it yields NaN, and
        # casting NaN to int8 is not well defined, so divide such rows by 1.
        # They quantize to 0 and keep a reported scale of 0.
        divisor = torch.where(scales == 0, torch.ones_like(scales), scales)
        quantized = weight.div(divisor.unsqueeze(1)).round().clamp(-vmax, vmax)
        return quantized.to(torch.int8), scales

    def _quantize_weights(self, layer: Module) -> None:
        """Replace ``w13_weight``/``w2_weight`` with int8 copies plus scales.

        Scales are stored as float32 parameters ``w13_scale`` and ``w2_scale``
        with one entry per (expert, output row).
        """
        vmax = torch.iinfo(torch.int8).max

        w13 = torch.empty_like(layer.w13_weight, dtype=torch.int8)
        w2 = torch.empty_like(layer.w2_weight, dtype=torch.int8)
        w13_scale = torch.zeros(
            layer.num_experts,
            layer.w13_weight.shape[1],
            device=w13.device,
            dtype=torch.float32,
        )
        w2_scale = torch.zeros(
            layer.num_experts,
            layer.w2_weight.shape[1],
            device=w2.device,
            dtype=torch.float32,
        )

        # NOTE(review): buffers are sized with ``layer.num_experts`` while the
        # loop runs over ``layer.local_num_experts`` -- presumably equal here;
        # confirm against the layer implementation.
        for expert in range(layer.local_num_experts):
            # w13: per-row quantization over hidden_size dim
            w13[expert, :, :], w13_scale[expert, :] = self._quantize_rows(
                layer.w13_weight[expert, :, :], vmax
            )

            # w2: per-row quantization over intermediate_size dim
            w2[expert, :, :], w2_scale[expert, :] = self._quantize_rows(
                layer.w2_weight[expert, :, :], vmax
            )

        replace_parameter(layer, "w13_weight", w13)
        replace_parameter(layer, "w2_weight", w2)
        replace_parameter(layer, "w13_scale", w13_scale)
        replace_parameter(layer, "w2_scale", w2_scale)

    def _setup_kernel(self, layer: "FusedMoE") -> None:
        """Create the quant config and the MoE kernel for ``layer``."""
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        assert self.moe_quant_config is not None
        assert self.experts_cls is not None

        self.moe_kernel = make_int8_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            experts_cls=self.experts_cls,
            routing_tables=layer._maybe_init_expert_routing_tables(),
            shared_experts=layer.shared_experts,
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> "FusedMoEQuantConfig | None":
        """Build the int8 quant config from the layer's scales and biases."""
        quant_config = make_int8_moe_quant_config(
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale,
        )
        self._maybe_inject_biases(quant_config, layer)
        return quant_config
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.model_loader.reload.layerwise import (
initialize_online_processing,
)
from vllm.model_executor.utils import set_weight_attrs
class OnlineMoEMethodBase(FusedMoEMethodBase):
    """Common base for MoE methods that quantize weights after loading.

    Expert weights (and biases, when the MoE config has them) are registered
    in the checkpoint dtype on the ``meta`` device, and the layer is handed to
    ``initialize_online_processing``. Subclasses implement
    ``process_weights_after_loading`` to do the actual quantization.
    """

    uses_meta_device: bool = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register full-precision expert parameters on the meta device.

        Args:
            layer: Module that receives the parameters and bookkeeping
                attributes (``num_experts``, ``orig_dtype``,
                ``weight_block_size``, ``w13_input_scale``,
                ``w2_input_scale``).
            num_experts: Size of the leading (expert) dimension.
            hidden_size: Hidden dimension of the expert weights.
            intermediate_size_per_partition: Intermediate dimension owned by
                this partition.
            params_dtype: Dtype of the registered parameters.
            **extra_weight_attrs: Attributes attached to every registered
                parameter via ``set_weight_attrs``.
        """
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        def _register_on_meta(name, factory, dtype, *shape):
            # Allocate on the meta device, register, then attach attributes.
            param = torch.nn.Parameter(
                factory(*shape, device="meta", dtype=dtype),
                requires_grad=False,
            )
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        fused_intermediate = 2 * intermediate_size_per_partition

        # Fused gate_up_proj (column parallel)
        _register_on_meta(
            "w13_weight",
            torch.empty,
            params_dtype,
            num_experts,
            fused_intermediate,
            hidden_size,
        )
        # down_proj (row parallel)
        _register_on_meta(
            "w2_weight",
            torch.empty,
            params_dtype,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
        )

        # Biases exist only for biased MoE models (e.g. GPT-OSS).
        if self.moe.has_bias:
            _register_on_meta(
                "w13_bias",
                torch.zeros,
                layer.orig_dtype,
                num_experts,
                fused_intermediate,
            )
            _register_on_meta(
                "w2_bias",
                torch.zeros,
                layer.orig_dtype,
                num_experts,
                hidden_size,
            )

        layer.w13_input_scale = None
        layer.w2_input_scale = None

        initialize_online_processing(layer)

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Quantize the loaded weights of ``layer``; provided by subclasses."""

    def _maybe_inject_biases(
        self,
        quant_config: FusedMoEQuantConfig,
        layer: torch.nn.Module,
    ) -> None:
        """Copy the layer's biases into ``quant_config`` for biased MoE models
        (e.g. GPT-OSS). Does nothing when the MoE config has no bias."""
        if not self.moe.has_bias:
            return

        w13_bias = getattr(layer, "w13_bias", None)
        w2_bias = getattr(layer, "w2_bias", None)
        if w13_bias is not None:
            quant_config._w1.bias = w13_bias
        if w2_bias is not None:
            quant_config._w2.bias = w2_bias

    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        """Always raise ``ValueError``; this hook is not used by these methods."""
        message = (
            f"{self.__class__.__name__} uses the new modular kernel "
            "initialization logic. This function should not be called."
        )
        raise ValueError(message)

    @property
    def supports_eplb(self) -> bool:
        """Always ``True`` for this base class."""
        return True

    def apply_monolithic(
        self,
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the monolithic kernel path.

        Requires ``self.is_monolithic`` to be true and ``self.moe_kernel`` to
        be set; both are checked with assertions.
        """
        assert self.is_monolithic
        assert self.moe_kernel is not None
        kernel = self.moe_kernel
        return kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
        self,
        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
        x: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run the non-monolithic kernel path with precomputed routing.

        Requires ``self.is_monolithic`` to be false and ``self.moe_kernel`` to
        be set; both are checked with assertions.
        """
        assert not self.is_monolithic
        assert self.moe_kernel is not None
        kernel = self.moe_kernel
        return kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment