Unverified commit e4ee48da, authored by Kunshang Ji, committed by GitHub
Browse files

[MoE refactor] refactor GPTQMarlinMoEMethod with MK (#37990)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robertgshaw2@gmail.com>
parent 342c58bc
...@@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -40,6 +40,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym, kFp8Static128BlockSym,
kFp8StaticChannelSym, kFp8StaticChannelSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kInt4Static,
kInt8Static,
kMxfp4Static, kMxfp4Static,
kMxfp8Static, kMxfp8Static,
kNvfp4Static, kNvfp4Static,
...@@ -585,6 +587,8 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular): ...@@ -585,6 +587,8 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
kMxfp4Static, kMxfp4Static,
kMxfp8Static, kMxfp8Static,
kNvfp4Static, kNvfp4Static,
kInt4Static,
kInt8Static,
] ]
return weight_key in SUPPORTED_W return weight_key in SUPPORTED_W
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
from typing import TYPE_CHECKING
import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_act_int8_process_scales,
marlin_moe_permute_scales,
marlin_permute_bias,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
logger = init_logger(__name__)
class WNA16MoEBackend(Enum):
    """Kernel backends that can run a WNA16 (quantized-weight) MoE layer.

    Each member's value is its own name as a string; the value is what gets
    interpolated into log and error messages.
    """

    # Resolved to ``MarlinExperts`` (standard activation format).
    MARLIN = "MARLIN"

    # Resolved to ``BatchedMarlinExperts`` (batched-experts activation format).
    BATCHED_MARLIN = "BATCHED_MARLIN"
def backend_to_kernel_cls(
    backend: WNA16MoEBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """Resolve a WNA16 MoE backend to its candidate experts classes.

    The experts class is imported lazily, only for the backend requested.

    Args:
        backend: the backend to resolve.

    Returns:
        A one-element list containing the experts class for ``backend``.

    Raises:
        ValueError: if ``backend`` is not one of the handled members.
    """
    if backend == WNA16MoEBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            MarlinExperts,
        )

        return [MarlinExperts]

    if backend == WNA16MoEBackend.BATCHED_MARLIN:
        from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
            BatchedMarlinExperts,
        )

        return [BatchedMarlinExperts]

    # Anything that reaches this point has no kernel mapping.
    raise ValueError(f"Unknown WNA16 MoE backend: {backend.value}")
def _get_priority_backends() -> list[WNA16MoEBackend]:
    """Return the candidate WNA16 MoE backends, highest priority first.

    The ordering is currently fixed (MARLIN, then BATCHED_MARLIN) and a new
    list object is built on every call.
    """
    return [
        WNA16MoEBackend.MARLIN,
        WNA16MoEBackend.BATCHED_MARLIN,
    ]
def select_wna16_moe_backend(
    config: FusedMoEConfig,
    weight_key: QuantKey,
    weight_bits: int,
) -> tuple[WNA16MoEBackend, type[mk.FusedMoEExperts]]:
    """Select the WNA16 MoE backend.

    Backends are tried in the order given by ``_get_priority_backends`` and
    the first experts class whose ``is_supported_config`` accepts the
    deployment configuration wins.

    Args:
        config: the shared ``FusedMoEConfig`` for this layer.
        weight_key: ``QuantKey`` describing the weight quantization; passed
            through to each candidate's ``is_supported_config``.
        weight_bits: quantization bit-width. Not consulted by the selection
            logic (support is decided from ``weight_key``); kept so existing
            callers do not need to change.

    Returns:
        A tuple of (``WNA16MoEBackend``, experts class) for the first
        supported candidate.

    Raises:
        NotImplementedError: if no backend supports the configuration.
    """
    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if config.moe_parallel_config.use_batched_activation_format
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: WNA16MoEBackend) -> str:
        return f"Using '{backend.value}' WNA16 MoE backend."

    def _make_log_unsupported(backend: WNA16MoEBackend, reason: str | None) -> str:
        if reason:
            return (
                f"WNA16 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        return (
            f"WNA16 MoE backend '{backend.value}' does not support the "
            "deployment configuration."
        )

    # Always BF16 activation for WNA16 MoE, so no activation quant key.
    activation_key = None

    # Select kernels in order of backend.
    for backend in _get_priority_backends():
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend), scope="local")
                return backend, k_cls
            # Record why this candidate was rejected before trying the next.
            logger.debug_once(_make_log_unsupported(backend, reason), scope="local")

    raise NotImplementedError(
        "No WNA16 MoE backend supports the deployment configuration."
    )
def make_wna16_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts] | None,
    layer: torch.nn.Module,
    is_k_full: bool,
    w13_g_idx: torch.Tensor | None,
    w2_g_idx: torch.Tensor | None,
    w13_g_idx_sort_indices: torch.Tensor | None,
    w2_g_idx_sort_indices: torch.Tensor | None,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    shared_experts: torch.nn.Module | None = None,
) -> mk.FusedMoEKernel:
    """Assemble the modular ``FusedMoEKernel`` for a WNA16 MoE layer.

    A prepare/finalize object is created first; its activation format then
    decides whether ``MarlinExperts`` or ``BatchedMarlinExperts`` is
    instantiated, and ``experts_cls`` is asserted to agree with that choice.

    Args:
        moe_quant_config: quantization config handed to prepare/finalize and
            to the experts object.
        moe_config: shared MoE config; ``disable_inplace`` controls the
            kernel's ``inplace`` flag.
        experts_cls: experts class chosen at backend selection time; must be
            ``MarlinExperts`` or ``BatchedMarlinExperts``.
        layer: accepted for interface symmetry; not read by this function.
        is_k_full: forwarded to the experts object.
        w13_g_idx, w2_g_idx: act-order group indices, forwarded as-is.
        w13_g_idx_sort_indices, w2_g_idx_sort_indices: act-order sort
            indices, forwarded as-is.
        routing_tables: optional routing tables for prepare/finalize.
        shared_experts: optional shared-experts module for the kernel.

    Returns:
        The constructed ``mk.FusedMoEKernel``.
    """
    # Currently, we only support MarlinExperts and BatchedMarlinExperts
    assert experts_cls in (MarlinExperts, BatchedMarlinExperts)

    from vllm.model_executor.layers.fused_moe.all2all_utils import (
        maybe_make_prepare_finalize,
    )

    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
    )
    assert prepare_finalize is not None
    assert isinstance(prepare_finalize, mk.FusedMoEPrepareAndFinalizeModular)

    batched_format = mk.FusedMoEActivationFormat.BatchedExperts
    experts: mk.FusedMoEExperts
    if prepare_finalize.activation_format != batched_format:
        # Standard activation layout -> plain Marlin experts.
        assert experts_cls == MarlinExperts
        experts = MarlinExperts(
            moe_config=moe_config,
            quant_config=moe_quant_config,
            w13_g_idx=w13_g_idx,
            w2_g_idx=w2_g_idx,
            w13_g_idx_sort_indices=w13_g_idx_sort_indices,
            w2_g_idx_sort_indices=w2_g_idx_sort_indices,
            is_k_full=is_k_full,
        )
    else:
        # Batched layout additionally needs the per-rank token budget and
        # the dispatcher count reported by prepare/finalize.
        assert experts_cls == BatchedMarlinExperts
        tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
        assert tokens_per_rank is not None
        experts = BatchedMarlinExperts(
            max_num_tokens=tokens_per_rank,
            num_dispatchers=prepare_finalize.num_dispatchers(),
            moe_config=moe_config,
            quant_config=moe_quant_config,
            w13_g_idx=w13_g_idx,
            w2_g_idx=w2_g_idx,
            w13_g_idx_sort_indices=w13_g_idx_sort_indices,
            w2_g_idx_sort_indices=w2_g_idx_sort_indices,
            is_k_full=is_k_full,
        )

    run_inplace = not moe_config.disable_inplace
    return mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        shared_experts=shared_experts,
        inplace=run_inplace,
    )
# ---------------------------------------------------------------------------
# Per-backend weight post-processing
# ---------------------------------------------------------------------------
def _process_weights_marlin(
    layer: torch.nn.Module,
    quant_config: "GPTQMarlinConfig",
    input_dtype: torch.dtype | None,
    w13_qweight: torch.Tensor,
    w2_qweight: torch.Tensor,
    w13_scales: torch.Tensor,
    w2_scales: torch.Tensor,
    w13_g_idx: torch.Tensor,
    w2_g_idx: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> tuple[
    torch.Tensor,  # w13_qweight
    torch.Tensor,  # w2_qweight
    torch.Tensor,  # w13_scales
    torch.Tensor,  # w2_scales
    torch.Tensor,  # w13_g_idx
    torch.Tensor,  # w2_g_idx
    torch.Tensor,  # w13_g_idx_sort_indices
    torch.Tensor,  # w2_g_idx_sort_indices
    torch.Tensor | None,  # w13_input_global_scale
    torch.Tensor | None,  # w2_input_global_scale
    torch.Tensor | None,  # w13_bias
    torch.Tensor | None,  # w2_bias
]:
    """Standard Marlin weight post-processing shared by MARLIN and
    BATCHED_MARLIN backends.

    Steps
    -----
    1. Optional FP8 preprocessing of packed weights / scales.
    2. Sort / reset g_idx tensors for act-order handling.
    3. Repack weights via ``gptq_marlin_moe_repack``.
    4. Permute scales (and optionally extract INT8 global scales).
    5. Permute bias tensors.

    Args:
        layer: the MoE layer; ``intermediate_size_per_partition`` is read
            for the w13 scale permutation, and ``num_groups_w13`` /
            ``num_groups_w2`` when ``input_dtype`` is ``torch.int8``.
        quant_config: supplies ``desc_act``, ``pack_factor``, ``group_size``
            and ``quant_type.size_bits``.
        input_dtype: activation dtype, or ``None``. A 1-byte dtype enables
            the ``is_a_8bit`` paths of the repack / permute helpers.
        w13_qweight, w2_qweight: packed quantized weights.
        w13_scales, w2_scales: weight scales.
        w13_g_idx, w2_g_idx: per-expert group indices (first dim is the
            number of experts).
        w13_bias, w2_bias: optional bias tensors.

    Returns:
        The processed tensors in the order given by the return annotation.
        With ``desc_act`` the returned g_idx tensors are the per-expert
        *sorted* group indices that match the returned sort indices;
        otherwise all four g_idx-related tensors are empty
        ``(num_experts, 0)`` int32 parameters. The inputs are not modified
        in place by this function itself.
    """
    is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1

    w13_g_idx_sort_indices: torch.Tensor | None = None
    w2_g_idx_sort_indices: torch.Tensor | None = None
    w13_input_global_scale: torch.Tensor | None = None
    w2_input_global_scale: torch.Tensor | None = None
    w13_bias_out: torch.Tensor | None = None
    w2_bias_out: torch.Tensor | None = None

    # --- FP8 weight / scale adjustment ---
    if input_dtype == torch.float8_e4m3fn:
        # Out-of-place so the caller's tensors stay untouched.
        marlin_w13_qweight = ops.marlin_int4_fp8_preprocess(w13_qweight, inplace=False)
        marlin_w2_qweight = ops.marlin_int4_fp8_preprocess(w2_qweight, inplace=False)
        marlin_w13_scales = w13_scales.data * 512
        marlin_w2_scales = w2_scales.data * 512
    else:
        marlin_w13_qweight = w13_qweight
        marlin_w2_qweight = w2_qweight
        marlin_w13_scales = w13_scales
        marlin_w2_scales = w2_scales

    # --- Process act_order (g_idx) ---
    num_experts = w13_g_idx.shape[0]
    if quant_config.desc_act:
        # Get the per-expert sorting based on g_idx.
        w13_g_idx_sort_indices = torch.empty_like(w13_g_idx)
        w2_g_idx_sort_indices = torch.empty_like(w2_g_idx)
        w13_sorted_g_idx = torch.empty_like(w13_g_idx)
        w2_sorted_g_idx = torch.empty_like(w2_g_idx)
        for e in range(num_experts):
            w13_g_idx_sort_indices[e] = torch.argsort(w13_g_idx[e]).to(torch.int32)
            w2_g_idx_sort_indices[e] = torch.argsort(w2_g_idx[e]).to(torch.int32)
            w13_sorted_g_idx[e] = w13_g_idx[e][w13_g_idx_sort_indices[e]]
            w2_sorted_g_idx[e] = w2_g_idx[e][w2_g_idx_sort_indices[e]]
        # Hand back the *sorted* g_idx together with the sort indices.
        # Previously the sorted tensors were computed and then dropped, so
        # the unsorted g_idx was returned alongside the permutation.
        w13_g_idx = w13_sorted_g_idx
        w2_g_idx = w2_sorted_g_idx
    else:
        # Reset g_idx related tensors to empty placeholders.
        device = w13_g_idx.device
        w13_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        w2_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

    # --- Repack weights ---
    marlin_w13_qweight = ops.gptq_marlin_moe_repack(
        marlin_w13_qweight,
        w13_g_idx_sort_indices,
        marlin_w13_qweight.shape[1] * quant_config.pack_factor,
        marlin_w13_qweight.shape[2],
        quant_config.quant_type.size_bits,
        is_a_8bit=is_a_8bit,
    )
    marlin_w2_qweight = ops.gptq_marlin_moe_repack(
        marlin_w2_qweight,
        w2_g_idx_sort_indices,
        marlin_w2_qweight.shape[1] * quant_config.pack_factor,
        marlin_w2_qweight.shape[2],
        quant_config.quant_type.size_bits,
        is_a_8bit=is_a_8bit,
    )

    # --- Permute scales ---
    marlin_w13_scales = marlin_moe_permute_scales(
        s=marlin_w13_scales,
        size_k=layer.intermediate_size_per_partition,
        size_n=marlin_w13_scales.shape[2],
        group_size=quant_config.group_size,
        is_a_8bit=is_a_8bit,
    )
    marlin_w2_scales = marlin_moe_permute_scales(
        s=marlin_w2_scales,
        size_k=marlin_w2_scales.shape[1]
        * (
            quant_config.group_size
            if quant_config.group_size != -1
            else quant_config.pack_factor
        ),
        size_n=marlin_w2_scales.shape[2],
        group_size=quant_config.group_size,
        is_a_8bit=is_a_8bit,
    )

    # INT8 activations: split out a global input scale, but only when the
    # layer reports more than one scale group for that projection.
    if input_dtype == torch.int8:
        if layer.num_groups_w13 > 1:
            marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
                marlin_w13_scales
            )
        if layer.num_groups_w2 > 1:
            marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
                marlin_w2_scales
            )

    # --- Permute bias ---
    if w13_bias is not None:
        w13_bias_out = marlin_permute_bias(w13_bias)
    if w2_bias is not None:
        w2_bias_out = marlin_permute_bias(w2_bias)

    return (
        marlin_w13_qweight,
        marlin_w2_qweight,
        marlin_w13_scales,
        marlin_w2_scales,
        w13_g_idx,
        w2_g_idx,
        w13_g_idx_sort_indices,
        w2_g_idx_sort_indices,
        w13_input_global_scale,
        w2_input_global_scale,
        w13_bias_out,
        w2_bias_out,
    )
def convert_to_wna16_moe_kernel_format(
    backend: WNA16MoEBackend,
    layer: torch.nn.Module,
    quant_config: QuantizationConfig,
    input_dtype: torch.dtype | None,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_g_idx: torch.Tensor,
    w2_g_idx: torch.Tensor,
    w13_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
) -> tuple[
    torch.Tensor,  # w13_qweight
    torch.Tensor,  # w2_qweight
    torch.Tensor,  # w13_scales
    torch.Tensor,  # w2_scales
    torch.Tensor | None,  # w13_g_idx
    torch.Tensor | None,  # w2_g_idx
    torch.Tensor | None,  # w13_g_idx_sort_indices
    torch.Tensor | None,  # w2_g_idx_sort_indices
    torch.Tensor | None,  # w13_input_global_scale
    torch.Tensor | None,  # w2_input_global_scale
    torch.Tensor | None,  # w13_bias
    torch.Tensor | None,  # w2_bias
]:
    """Route weight post-processing to the handler for the chosen backend.

    Both Marlin backends share ``_process_weights_marlin``. Supporting a new
    backend means writing a ``_process_weights_<name>`` helper and
    dispatching to it from this function.

    Args:
        backend: the selected ``WNA16MoEBackend``.
        layer: the ``FusedMoE`` layer whose parameters are being prepared.
        quant_config: the ``QuantizationConfig`` for this layer; must be a
            ``GPTQMarlinConfig`` for the Marlin backends.
        input_dtype: optional activation dtype, usually should be 16 bit.
        w13, w2: packed quantized weights.
        w13_scale, w2_scale: weight scales.
        w13_g_idx, w2_g_idx: group index tensors.
        w13_bias, w2_bias: optional bias tensors.

    Returns:
        The 12-tuple produced by the per-backend handler, in the order given
        by the return annotation.

    Raises:
        ValueError: if ``backend`` has no weight-processing handler.
        TypeError: if a Marlin backend is paired with a config that is not a
            ``GPTQMarlinConfig``.
    """
    marlin_backends = (
        WNA16MoEBackend.MARLIN,
        WNA16MoEBackend.BATCHED_MARLIN,
    )
    if backend not in marlin_backends:
        raise ValueError(f"Unsupported wna16 MoE backend: {backend.value}")

    # Imported here, at call time, rather than at module level (the
    # top-of-file import of this name is guarded by TYPE_CHECKING).
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinConfig,
    )

    if not isinstance(quant_config, GPTQMarlinConfig):
        raise TypeError(
            "Marlin WNA16 MoE backend requires GPTQMarlinConfig, got "
            f"{type(quant_config).__name__}."
        )

    return _process_weights_marlin(
        layer,
        quant_config,
        input_dtype,
        w13,
        w2,
        w13_scale,
        w2_scale,
        w13_g_idx,
        w2_g_idx,
        w13_bias,
        w2_bias,
    )
...@@ -9,7 +9,6 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE ...@@ -9,7 +9,6 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import ( from vllm.model_executor.kernels.linear import (
MPLinearLayerConfig, MPLinearLayerConfig,
...@@ -19,13 +18,17 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -19,13 +18,17 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
FusedMoEQuantConfig, FusedMoEQuantConfig,
) )
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoE,
FusedMoEMethodBase, FusedMoEMethodBase,
FusedMoeWeightScaleSupported, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod, UnquantizedFusedMoEMethod,
) )
from vllm.model_executor.layers.fused_moe.oracle.int_wna16 import (
convert_to_wna16_moe_kernel_format,
make_wna16_moe_kernel,
select_wna16_moe_backend,
)
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -42,13 +45,15 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -42,13 +45,15 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, check_marlin_supported,
check_moe_marlin_supports_layer, check_moe_marlin_supports_layer,
get_marlin_input_dtype, get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_workspace_new, marlin_make_workspace_new,
marlin_moe_permute_scales,
marlin_permute_bias,
marlin_repeat_scales_on_all_ranks, marlin_repeat_scales_on_all_ranks,
verify_marlin_supported, verify_marlin_supported,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kInt4StaticGroupScale,
kInt8StaticGroupScale,
)
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
...@@ -500,13 +505,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -500,13 +505,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
super().__init__(moe) super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4: if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8 quant_type = scalar_types.uint4b8
scale = kInt4StaticGroupScale
elif self.quant_config.quant_type.size_bits == 8: elif self.quant_config.quant_type.size_bits == 8:
self.quant_type = scalar_types.uint8b128 quant_type = scalar_types.uint8b128
scale = kInt8StaticGroupScale
else: else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None self.input_dtype = None
self.use_marlin = True self.use_marlin = True
weight_key = QuantKey(quant_type, scale)
self.wna16_moe_backend, self.experts_cls = select_wna16_moe_backend(
moe, weight_key, quant_config.weight_bits
)
def create_weights( def create_weights(
self, self,
...@@ -521,7 +533,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -521,7 +533,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1 is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit: if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, ( assert self.quant_config.quant_type.size_bits == 8, (
"W8A8-INT8 is not supported by marlin kernel." "W8A8-INT8 is not supported by marlin kernel."
) )
...@@ -668,134 +680,100 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -668,134 +680,100 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1 is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if is_a_8bit: if is_a_8bit:
assert self.quant_type == scalar_types.uint4b8, ( assert self.quant_config.quant_type.size_bits == 8, (
"W8A8-INT8 is not supported by marlin kernel." "W8A8-INT8 is not supported by marlin kernel."
) )
if self.input_dtype == torch.float8_e4m3fn: (
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True) w13,
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True) w2,
layer.w13_scales.data = layer.w13_scales.data * 512 w13_scale,
layer.w2_scales.data = layer.w2_scales.data * 512 w2_scale,
w13_g_idx,
# Process act_order w2_g_idx,
if self.quant_config.desc_act: w13_g_idx_sort_indices,
# Get sorting based on g_idx w2_g_idx_sort_indices,
num_experts = layer.w13_g_idx.shape[0] w13_input_global_scale,
w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) w2_input_global_scale,
w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) w13_bias,
w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) w2_bias,
w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) ) = convert_to_wna16_moe_kernel_format(
for e in range(num_experts): backend=self.wna16_moe_backend,
w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to( layer=layer,
torch.int32 quant_config=self.quant_config,
) input_dtype=self.input_dtype,
w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( w13=layer.w13_qweight,
torch.int32 w2=layer.w2_qweight,
) w13_scale=layer.w13_scales,
w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]] w2_scale=layer.w2_scales,
w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]] w13_g_idx=layer.w13_g_idx,
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) w2_g_idx=layer.w2_g_idx,
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) w13_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
replace_parameter(layer, "w13_qweight", w13)
replace_parameter(layer, "w2_qweight", w2)
replace_parameter(layer, "w13_scales", w13_scale)
replace_parameter(layer, "w2_scales", w2_scale)
replace_parameter(layer, "w13_g_idx", w13_g_idx)
replace_parameter(layer, "w2_g_idx", w2_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
else: if w13_input_global_scale is not None:
# Reset g_idx related tensors if hasattr(layer, "w13_input_global_scale"):
num_experts = layer.w13_g_idx.shape[0] replace_parameter(
device = layer.w13_g_idx.device layer, "w13_input_global_scale", w13_input_global_scale
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
# Repack weights
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_qweight,
layer.w13_g_idx_sort_indices,
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
marlin_w13_scales
) )
else:
layer.register_parameter( layer.register_parameter(
"w13_input_global_scale", "w13_input_global_scale",
torch.nn.Parameter(w13_input_global_scale, requires_grad=False), torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
) )
if w2_input_global_scale is not None:
replace_parameter(layer, "w13_scales", marlin_w13_scales) if hasattr(layer, "w2_input_global_scale"):
marlin_w2_scales = marlin_moe_permute_scales( replace_parameter(layer, "w2_input_global_scale", w2_input_global_scale)
s=layer.w2_scales, else:
size_k=layer.w2_scales.shape[1]
* (
self.quant_config.group_size
if self.quant_config.group_size != -1
else self.quant_config.pack_factor
),
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
marlin_w2_scales
)
layer.register_parameter( layer.register_parameter(
"w2_input_global_scale", "w2_input_global_scale",
torch.nn.Parameter(w2_input_global_scale, requires_grad=False), torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
) )
if w13_bias is not None:
if hasattr(layer, "w13_bias"):
replace_parameter(layer, "w13_bias", w13_bias)
else:
layer.register_parameter(
"w13_bias", torch.nn.Parameter(w13_bias, requires_grad=False)
)
if w2_bias is not None:
if hasattr(layer, "w2_bias"):
replace_parameter(layer, "w2_bias", w2_bias)
else:
layer.register_parameter(
"w2_bias", torch.nn.Parameter(w2_bias, requires_grad=False)
)
replace_parameter(layer, "w2_scales", marlin_w2_scales) self._setup_kernel(layer)
if hasattr(layer, "w13_bias") and layer.w13_bias is not None: def _setup_kernel(self, layer: FusedMoE) -> None:
layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) """Build the FusedMoEKernel for this layer."""
if hasattr(layer, "w2_bias") and layer.w2_bias is not None: self.moe_quant_config = self.get_fused_moe_quant_config(layer)
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) self.moe_kernel = make_wna16_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
layer=layer,
is_k_full=self.is_k_full,
w13_g_idx=layer.w13_g_idx,
w2_g_idx=layer.w2_g_idx,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
routing_tables=layer._maybe_init_expert_routing_tables(),
shared_experts=layer.shared_experts,
)
def get_fused_moe_quant_config( def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
gptq_marlin_moe_quant_config, gptq_marlin_moe_quant_config,
) )
...@@ -820,84 +798,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -820,84 +798,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
prepare_finalize, prepare_finalize,
layer: torch.nn.Module, layer: torch.nn.Module,
): ):
""" raise ValueError(
Select the GEMM implementation for GPTQ-Marlin MoE. f"{self.__class__.__name__} uses the new modular kernel "
"initialization logic. This function should not be called."
Returns MarlinExperts configured for GPTQ quantization.
This is ONLY used when LoRA is enabled.
Without LoRA, GPTQ uses its own apply() method.
"""
# Only use modular kernels when LoRA is enabled
# Without LoRA, GPTQ's own apply() method works fine and is more efficient
if not self.moe.is_lora_enabled:
raise NotImplementedError(
"GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
"Modular kernels are only used for LoRA support."
)
# The modular marlin kernels do not support 8-bit weights.
if self.quant_config.weight_bits == 8:
raise NotImplementedError(
"GPTQ-Marlin kernel does not support 8-bit weights."
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
)
# Ensure quant config is initialized
assert self.moe_quant_config is not None, (
"moe_quant_config must be initialized before select_gemm_impl"
)
w13_g_idx = (
getattr(layer, "w13_g_idx", None) if self.quant_config.desc_act else None
)
w2_g_idx = (
getattr(layer, "w2_g_idx", None) if self.quant_config.desc_act else None
)
w13_g_idx_sort_indices = (
getattr(layer, "w13_g_idx_sort_indices", None)
if self.quant_config.desc_act
else None
)
w2_g_idx_sort_indices = (
getattr(layer, "w2_g_idx_sort_indices", None)
if self.quant_config.desc_act
else None
)
# Check if using batched expert format (for Expert Parallelism)
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
# For batched format, use BatchedMarlinExperts
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
else:
# Standard Marlin experts for GPTQ
return MarlinExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
) )
def apply( def apply(
...@@ -908,28 +811,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -908,28 +811,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None, shared_experts_input: torch.Tensor | None,
) -> torch.Tensor: ) -> torch.Tensor:
return fused_marlin_moe( assert not self.is_monolithic
x, assert self.moe_kernel is not None
layer.w13_qweight, return self.moe_kernel.apply(
layer.w2_qweight, hidden_states=x,
getattr(layer, "w13_bias", None), w1=layer.w13_qweight,
getattr(layer, "w2_bias", None), w2=layer.w2_qweight,
layer.w13_scales, topk_weights=topk_weights,
layer.w2_scales, topk_ids=topk_ids,
topk_weights, activation=layer.activation,
topk_ids,
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
quant_type_id=self.quant_type.id,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
expert_map=layer.expert_map, expert_map=layer.expert_map,
g_idx1=layer.w13_g_idx, shared_experts_input=shared_experts_input,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
workspace=layer.workspace,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
inplace=not self.moe.disable_inplace,
) )
...@@ -20,6 +20,8 @@ if TYPE_CHECKING: ...@@ -20,6 +20,8 @@ if TYPE_CHECKING:
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
MXFP_SCALE_DTYPE = torch.uint8 MXFP_SCALE_DTYPE = torch.uint8
INT4_DTYPE = scalar_types.uint4b8
INT8_DTYPE = scalar_types.uint8b128
def get_fp8_min_max() -> tuple[float, float]: def get_fp8_min_max() -> tuple[float, float]:
...@@ -170,6 +172,12 @@ kMxfp8Dynamic = QuantKey(FP8_DTYPE, scale=kMxfp8DynamicGroupScale, symmetric=Tru ...@@ -170,6 +172,12 @@ kMxfp8Dynamic = QuantKey(FP8_DTYPE, scale=kMxfp8DynamicGroupScale, symmetric=Tru
kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32)) kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True) kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
# TODO: convert this to use SCALAR_TYPE. This is not right.
kInt4StaticGroupScale = ScaleDesc(torch.float16, True, GroupShape(1, -1))
kInt4Static = QuantKey(INT4_DTYPE, scale=kInt4StaticGroupScale, symmetric=True)
kInt8StaticGroupScale = ScaleDesc(torch.float16, True, GroupShape(1, -1))
kInt8Static = QuantKey(INT8_DTYPE, scale=kInt8StaticGroupScale, symmetric=True)
kInt8StaticChannelSym = QuantKey(torch.int8, kStaticChannelScale, symmetric=True) kInt8StaticChannelSym = QuantKey(torch.int8, kStaticChannelScale, symmetric=True)
kInt8DynamicTokenSym = QuantKey(torch.int8, kDynamicTokenScale, symmetric=True) kInt8DynamicTokenSym = QuantKey(torch.int8, kDynamicTokenScale, symmetric=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment