Unverified Commit e4ee48da authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[MoE refactor] refactor GPTQMarlinMoEMethod with MK (#37990)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
Signed-off-by: default avatarRobert Shaw <robertgshaw2@gmail.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: default avatarRobert Shaw <robertgshaw2@gmail.com>
parent 342c58bc
......@@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
kInt4Static,
kInt8Static,
kMxfp4Static,
kMxfp8Static,
kNvfp4Static,
......@@ -585,6 +587,8 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
kMxfp4Static,
kMxfp8Static,
kNvfp4Static,
kInt4Static,
kInt8Static,
]
return weight_key in SUPPORTED_W
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import TYPE_CHECKING
import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_act_int8_process_scales,
marlin_moe_permute_scales,
marlin_permute_bias,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
logger = init_logger(__name__)
class WNA16MoEBackend(Enum):
MARLIN = "MARLIN"
BATCHED_MARLIN = "BATCHED_MARLIN"
def backend_to_kernel_cls(
backend: WNA16MoEBackend,
) -> list[type[mk.FusedMoEExperts]]:
"""Return the experts class for the given backend, or None for NONE."""
if backend == WNA16MoEBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
return [MarlinExperts]
elif backend == WNA16MoEBackend.BATCHED_MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
)
return [BatchedMarlinExperts]
else:
raise ValueError(f"Unknown WNA16 MoE backend: {backend.value}")
def _get_priority_backends() -> list[WNA16MoEBackend]:
"""
Get available backends in priority order based on platform and config.
"""
_AVAILABLE_BACKENDS = [
WNA16MoEBackend.MARLIN,
WNA16MoEBackend.BATCHED_MARLIN,
]
return _AVAILABLE_BACKENDS
def select_wna16_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey,
weight_bits: int,
) -> tuple[WNA16MoEBackend, type[mk.FusedMoEExperts]]:
"""Select the WNA16 MoE backend.
Args:
config: the shared ``FusedMoEConfig`` for this layer.
weight_bits: quantization bit-width (4 or 8). 8-bit weights are not
supported by the modular Marlin kernel, so ``NONE`` is returned.
Returns:
A tuple of (``WNA16MoEBackend``, experts class or ``None``).
"""
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
def _make_log_backend(backend: WNA16MoEBackend):
return f"Using '{backend.value}' WNA16 MoE backend."
def _make_log_unsupported(backend: WNA16MoEBackend, reason: str | None) -> str:
if reason:
return (
f"WNA16 MoE backend '{backend.value}' does not support the "
f"deployment configuration since {reason}."
)
return (
f"WNA16 MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
def _return_or_raise(
backend: WNA16MoEBackend,
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[WNA16MoEBackend, type[mk.FusedMoEExperts]]:
reason: str | None = None
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Select kernels in order of backend.
AVAILABLE_BACKENDS = _get_priority_backends()
for backend in AVAILABLE_BACKENDS:
activation_key = None # always BF16 activation for WNA16 MoE
for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No WNA16 MoE backend supports the deployment configuration."
)
def make_wna16_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
experts_cls: type[mk.FusedMoEExperts] | None,
layer: torch.nn.Module,
is_k_full: bool,
w13_g_idx: torch.Tensor | None,
w2_g_idx: torch.Tensor | None,
w13_g_idx_sort_indices: torch.Tensor | None,
w2_g_idx_sort_indices: torch.Tensor | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
# Currently, we only support MarlinExperts and BatchedMarlinExperts
assert experts_cls in (MarlinExperts, BatchedMarlinExperts)
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
)
assert prepare_finalize is not None
assert isinstance(prepare_finalize, mk.FusedMoEPrepareAndFinalizeModular)
if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts:
assert experts_cls == BatchedMarlinExperts
max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens is not None
experts: mk.FusedMoEExperts = BatchedMarlinExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=moe_config,
quant_config=moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=is_k_full,
)
else:
assert experts_cls == MarlinExperts
experts = MarlinExperts(
moe_config=moe_config,
quant_config=moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=is_k_full,
)
return mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=shared_experts,
inplace=not moe_config.disable_inplace,
)
# ---------------------------------------------------------------------------
# Per-backend weight post-processing
# ---------------------------------------------------------------------------
def _process_weights_marlin(
layer: torch.nn.Module,
quant_config: "GPTQMarlinConfig",
input_dtype: torch.dtype | None,
w13_qweight: torch.Tensor,
w2_qweight: torch.Tensor,
w13_scales: torch.Tensor,
w2_scales: torch.Tensor,
w13_g_idx: torch.Tensor,
w2_g_idx: torch.Tensor,
w13_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> tuple[
torch.Tensor, # w13_qweight
torch.Tensor, # w2_qweight
torch.Tensor, # w13_scales
torch.Tensor, # w2_scales
torch.Tensor, # w13_g_idx
torch.Tensor, # w2_g_idx
torch.Tensor, # w13_g_idx_sort_indices
torch.Tensor, # w2_g_idx_sort_indices
torch.Tensor | None, # w13_input_global_scale
torch.Tensor | None, # w2_input_global_scale
torch.Tensor | None, # w13_bias
torch.Tensor | None, # w2_bias
]:
"""Standard Marlin weight post-processing shared by MARLIN and
BATCHED_MARLIN backends.
Steps
-----
1. Optional FP8 preprocessing of packed weights / scales.
2. Sort / reset g_idx tensors for act-order handling.
3. Repack weights via ``gptq_marlin_moe_repack``.
4. Permute scales (and optionally extract INT8 global scales).
5. Permute bias tensors.
"""
is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
marlin_w13_qweight: torch.Tensor
marlin_w2_qweight: torch.Tensor
marlin_w13_scales: torch.Tensor
marlin_w2_scales: torch.Tensor
w13_g_idx_sort_indices: torch.Tensor | None = None
w2_g_idx_sort_indices: torch.Tensor | None = None
w13_input_global_scale: torch.Tensor | None = None
w2_input_global_scale: torch.Tensor | None = None
w13_bias_out: torch.Tensor | None = None
w2_bias_out: torch.Tensor | None = None
# --- FP8 weight / scale adjustment ---
if input_dtype == torch.float8_e4m3fn:
marlin_w13_qweight = ops.marlin_int4_fp8_preprocess(w13_qweight, inplace=False)
marlin_w2_qweight = ops.marlin_int4_fp8_preprocess(w2_qweight, inplace=False)
marlin_w13_scales = w13_scales.data * 512
marlin_w2_scales = w2_scales.data * 512
else:
marlin_w13_qweight = w13_qweight
marlin_w2_qweight = w2_qweight
marlin_w13_scales = w13_scales
marlin_w2_scales = w2_scales
# --- Process act_order (g_idx) ---
if quant_config.desc_act:
num_experts = w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(w2_g_idx)
w13_sorted_g_idx = torch.empty_like(w13_g_idx)
w2_sorted_g_idx = torch.empty_like(w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(w13_g_idx[e]).to(torch.int32)
w2_g_idx_sort_indices[e] = torch.argsort(w2_g_idx[e]).to(torch.int32)
w13_sorted_g_idx[e] = w13_g_idx[e][w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = w2_g_idx[e][w2_g_idx_sort_indices[e]]
else:
num_experts = w13_g_idx.shape[0]
device = w13_g_idx.device
w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
# --- Repack weights ---
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
marlin_w13_qweight,
w13_g_idx_sort_indices,
marlin_w13_qweight.shape[1] * quant_config.pack_factor,
marlin_w13_qweight.shape[2],
quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
marlin_w2_qweight,
w2_g_idx_sort_indices,
marlin_w2_qweight.shape[1] * quant_config.pack_factor,
marlin_w2_qweight.shape[2],
quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
# --- Permute scales ---
marlin_w13_scales = marlin_moe_permute_scales(
s=marlin_w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=marlin_w13_scales.shape[2],
group_size=quant_config.group_size,
is_a_8bit=is_a_8bit,
)
marlin_w2_scales = marlin_moe_permute_scales(
s=marlin_w2_scales,
size_k=marlin_w2_scales.shape[1]
* (
quant_config.group_size
if quant_config.group_size != -1
else quant_config.pack_factor
),
size_n=marlin_w2_scales.shape[2],
group_size=quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if input_dtype == torch.int8:
if layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
if layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
# --- Permute bias ---
if w13_bias is not None:
w13_bias_out = marlin_permute_bias(w13_bias)
if w2_bias is not None:
w2_bias_out = marlin_permute_bias(w2_bias)
return (
marlin_w13_qweight,
marlin_w2_qweight,
marlin_w13_scales,
marlin_w2_scales,
w13_g_idx,
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
w13_input_global_scale,
w2_input_global_scale,
w13_bias_out,
w2_bias_out,
)
def convert_to_wna16_moe_kernel_format(
backend: WNA16MoEBackend,
layer: torch.nn.Module,
quant_config: QuantizationConfig,
input_dtype: torch.dtype | None,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
w13_g_idx: torch.Tensor,
w2_g_idx: torch.Tensor,
w13_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> tuple[
torch.Tensor, # w13_qweight
torch.Tensor, # w2_qweight
torch.Tensor, # w13_scales
torch.Tensor, # w2_scales
torch.Tensor | None, # w13_g_idx
torch.Tensor | None, # w2_g_idx
torch.Tensor | None, # w13_g_idx_sort_indices
torch.Tensor | None, # w2_g_idx_sort_indices
torch.Tensor | None, # w13_input_global_scale
torch.Tensor | None, # w2_input_global_scale
torch.Tensor | None, # w13_bias
torch.Tensor | None, # w2_bias
]:
"""Dispatch weight post-processing to the appropriate per-backend handler.
To add a new backend, implement a ``_process_weights_<name>`` helper and
add a branch here.
Args:
backend: the selected ``WNA16MoEBackend``.
layer: the ``FusedMoE`` layer whose parameters are being prepared.
quant_config: the ``QuantizationConfig`` for this layer.
input_dtype: optional activation dtype, usually should be 16 bit.
"""
if backend in (
WNA16MoEBackend.MARLIN,
WNA16MoEBackend.BATCHED_MARLIN,
):
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig,
)
if not isinstance(quant_config, GPTQMarlinConfig):
raise TypeError(
"Marlin WNA16 MoE backend requires GPTQMarlinConfig, got "
f"{type(quant_config).__name__}."
)
return _process_weights_marlin(
layer,
quant_config,
input_dtype,
w13,
w2,
w13_scale,
w2_scale,
w13_g_idx,
w2_g_idx,
w13_bias,
w2_bias,
)
else:
raise ValueError(f"Unsupported wna16 MoE backend: {backend.value}")
......@@ -9,7 +9,6 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
MPLinearLayerConfig,
......@@ -19,13 +18,17 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.oracle.int_wna16 import (
convert_to_wna16_moe_kernel_format,
make_wna16_moe_kernel,
select_wna16_moe_backend,
)
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
......@@ -42,13 +45,15 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new,
marlin_moe_permute_scales,
marlin_permute_bias,
marlin_repeat_scales_on_all_ranks,
verify_marlin_supported,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kInt4StaticGroupScale,
kInt8StaticGroupScale,
)
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
......@@ -500,13 +505,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
quant_type = scalar_types.uint4b8
scale = kInt4StaticGroupScale
elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128
quant_type = scalar_types.uint8b128
scale = kInt8StaticGroupScale
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
self.use_marlin = True
weight_key = QuantKey(quant_type, scale)
self.wna16_moe_backend, self.experts_cls = select_wna16_moe_backend(
moe, weight_key, quant_config.weight_bits
)
def create_weights(
self,
......@@ -521,7 +533,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
assert self.quant_config.quant_type.size_bits == 8, (
"W8A8-INT8 is not supported by marlin kernel."
)
......@@ -668,134 +680,100 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, (
assert self.quant_config.quant_type.size_bits == 8, (
"W8A8-INT8 is not supported by marlin kernel."
)
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
layer.w13_scales.data = layer.w13_scales.data * 512
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
num_experts = layer.w13_g_idx.shape[0]
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
torch.int32
)
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
torch.int32
)
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
(
w13,
w2,
w13_scale,
w2_scale,
w13_g_idx,
w2_g_idx,
w13_g_idx_sort_indices,
w2_g_idx_sort_indices,
w13_input_global_scale,
w2_input_global_scale,
w13_bias,
w2_bias,
) = convert_to_wna16_moe_kernel_format(
backend=self.wna16_moe_backend,
layer=layer,
quant_config=self.quant_config,
input_dtype=self.input_dtype,
w13=layer.w13_qweight,
w2=layer.w2_qweight,
w13_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w13_g_idx=layer.w13_g_idx,
w2_g_idx=layer.w2_g_idx,
w13_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
)
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1]
* (
self.quant_config.group_size
if self.quant_config.group_size != -1
else self.quant_config.pack_factor
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
replace_parameter(layer, "w13_qweight", w13)
replace_parameter(layer, "w2_qweight", w2)
replace_parameter(layer, "w13_scales", w13_scale)
replace_parameter(layer, "w2_scales", w2_scale)
replace_parameter(layer, "w13_g_idx", w13_g_idx)
replace_parameter(layer, "w2_g_idx", w2_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
if w13_input_global_scale is not None:
if hasattr(layer, "w13_input_global_scale"):
replace_parameter(
layer, "w13_input_global_scale", w13_input_global_scale
)
else:
layer.register_parameter(
"w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
)
if w2_input_global_scale is not None:
if hasattr(layer, "w2_input_global_scale"):
replace_parameter(layer, "w2_input_global_scale", w2_input_global_scale)
else:
layer.register_parameter(
"w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
)
if w13_bias is not None:
if hasattr(layer, "w13_bias"):
replace_parameter(layer, "w13_bias", w13_bias)
else:
layer.register_parameter(
"w13_bias", torch.nn.Parameter(w13_bias, requires_grad=False)
)
if w2_bias is not None:
if hasattr(layer, "w2_bias"):
replace_parameter(layer, "w2_bias", w2_bias)
else:
layer.register_parameter(
"w2_bias", torch.nn.Parameter(w2_bias, requires_grad=False)
)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
self._setup_kernel(layer)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
def _setup_kernel(self, layer: FusedMoE) -> None:
"""Build the FusedMoEKernel for this layer."""
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
self.moe_kernel = make_wna16_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
layer=layer,
is_k_full=self.is_k_full,
w13_g_idx=layer.w13_g_idx,
w2_g_idx=layer.w2_g_idx,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
from vllm.model_executor.layers.fused_moe.config import (
gptq_marlin_moe_quant_config,
)
......@@ -820,86 +798,11 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
prepare_finalize,
layer: torch.nn.Module,
):
"""
Select the GEMM implementation for GPTQ-Marlin MoE.
Returns MarlinExperts configured for GPTQ quantization.
This is ONLY used when LoRA is enabled.
Without LoRA, GPTQ uses its own apply() method.
"""
# Only use modular kernels when LoRA is enabled
# Without LoRA, GPTQ's own apply() method works fine and is more efficient
if not self.moe.is_lora_enabled:
raise NotImplementedError(
"GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
"Modular kernels are only used for LoRA support."
)
# The modular marlin kernels do not support 8-bit weights.
if self.quant_config.weight_bits == 8:
raise NotImplementedError(
"GPTQ-Marlin kernel does not support 8-bit weights."
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
)
# Ensure quant config is initialized
assert self.moe_quant_config is not None, (
"moe_quant_config must be initialized before select_gemm_impl"
)
w13_g_idx = (
getattr(layer, "w13_g_idx", None) if self.quant_config.desc_act else None
)
w2_g_idx = (
getattr(layer, "w2_g_idx", None) if self.quant_config.desc_act else None
)
w13_g_idx_sort_indices = (
getattr(layer, "w13_g_idx_sort_indices", None)
if self.quant_config.desc_act
else None
)
w2_g_idx_sort_indices = (
getattr(layer, "w2_g_idx_sort_indices", None)
if self.quant_config.desc_act
else None
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel "
"initialization logic. This function should not be called."
)
# Check if using batched expert format (for Expert Parallelism)
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
# For batched format, use BatchedMarlinExperts
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
else:
# Standard Marlin experts for GPTQ
return MarlinExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
def apply(
self,
layer: FusedMoE,
......@@ -908,28 +811,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor:
return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
getattr(layer, "w13_bias", None),
getattr(layer, "w2_bias", None),
layer.w13_scales,
layer.w2_scales,
topk_weights,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
assert not self.is_monolithic
assert self.moe_kernel is not None
return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_qweight,
w2=layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
shared_experts_input=shared_experts_input,
)
......@@ -20,6 +20,8 @@ if TYPE_CHECKING:
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
MXFP_SCALE_DTYPE = torch.uint8
INT4_DTYPE = scalar_types.uint4b8
INT8_DTYPE = scalar_types.uint8b128
def get_fp8_min_max() -> tuple[float, float]:
......@@ -170,6 +172,12 @@ kMxfp8Dynamic = QuantKey(FP8_DTYPE, scale=kMxfp8DynamicGroupScale, symmetric=Tru
kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
# TODO: convert this to use SCALAR_TYPE. This is not right.
kInt4StaticGroupScale = ScaleDesc(torch.float16, True, GroupShape(1, -1))
kInt4Static = QuantKey(INT4_DTYPE, scale=kInt4StaticGroupScale, symmetric=True)
kInt8StaticGroupScale = ScaleDesc(torch.float16, True, GroupShape(1, -1))
kInt8Static = QuantKey(INT8_DTYPE, scale=kInt8StaticGroupScale, symmetric=True)
kInt8StaticChannelSym = QuantKey(torch.int8, kStaticChannelScale, symmetric=True)
kInt8DynamicTokenSym = QuantKey(torch.int8, kDynamicTokenScale, symmetric=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment