"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "ea9b8584bcd925f1fc7fc9dc2f480b0f5febfdb7"
Unverified Commit 6a6fc41c authored by Bhanu Prakash Voutharoja's avatar Bhanu Prakash Voutharoja Committed by GitHub
Browse files

gptq marlin quantization support for fused moe with lora (#30254)


Signed-off-by: default avatarBhanu068 <voutharoja.bhanu06@gmail.com>
parent f355ad54
...@@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm( ...@@ -860,4 +860,4 @@ torch::Tensor moe_wna16_marlin_gemm(
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm); m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
} }
\ No newline at end of file
...@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config( ...@@ -543,6 +543,42 @@ def int8_w8a8_moe_quant_config(
) )
def gptq_marlin_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
weight_bits: int,
group_size: int,
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
):
"""
Construct a quant config for gptq marlin quantization.
"""
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
w_shape = None if group_size == -1 else GroupShape(row=1, col=group_size)
# Activations are NOT quantized for GPTQ (fp16/bf16)
a_shape = w_shape # Same as weight shape for alignment
# Determine weight dtype
if weight_bits == 4:
weight_dtype = "int4"
elif weight_bits == 8:
weight_dtype = torch.int8
else:
raise ValueError(f"Unsupported weight_bits: {weight_bits}")
return FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(dtype=None, shape=a_shape),
_a2=FusedMoEQuantDesc(dtype=None, shape=a_shape),
_w1=FusedMoEQuantDesc(weight_dtype, w_shape, w1_scale, None, w1_zp, w1_bias),
_w2=FusedMoEQuantDesc(weight_dtype, w_shape, w2_scale, None, w2_zp, w2_bias),
)
def mxfp4_w4a16_moe_quant_config( def mxfp4_w4a16_moe_quant_config(
w1_scale: Union[torch.Tensor, "PrecisionConfig"], w1_scale: Union[torch.Tensor, "PrecisionConfig"],
w2_scale: Union[torch.Tensor, "PrecisionConfig"], w2_scale: Union[torch.Tensor, "PrecisionConfig"],
......
...@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -732,6 +732,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
is_a_8bit=is_a_8bit, is_a_8bit=is_a_8bit,
) )
replace_parameter(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# The modular kernel expects w13_weight and w2_weight,
# but GPTQ uses w13_qweight and w2_qweight
# Alias for modular kernel
layer.w13_weight = layer.w13_qweight
# Alias for modular kernel
layer.w2_weight = layer.w2_qweight
# Repack scales # Repack scales
marlin_w13_scales = marlin_moe_permute_scales( marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales, s=layer.w13_scales,
...@@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -782,7 +790,107 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None: ) -> FusedMoEQuantConfig | None:
return None from vllm.model_executor.layers.fused_moe.config import (
gptq_marlin_moe_quant_config,
)
return gptq_marlin_moe_quant_config(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
weight_bits=self.quant_config.weight_bits,
group_size=self.quant_config.group_size,
w1_zp=getattr(layer, "w13_qzeros", None)
if not self.quant_config.is_sym
else None,
w2_zp=getattr(layer, "w2_qzeros", None)
if not self.quant_config.is_sym
else None,
w1_bias=getattr(layer, "w13_bias", None),
w2_bias=getattr(layer, "w2_bias", None),
)
def select_gemm_impl(
self,
prepare_finalize,
layer: torch.nn.Module,
):
"""
Select the GEMM implementation for GPTQ-Marlin MoE.
Returns MarlinExperts configured for GPTQ quantization.
This is ONLY used when LoRA is enabled.
Without LoRA, GPTQ uses its own apply() method.
"""
# Only use modular kernels when LoRA is enabled
# Without LoRA, GPTQ's own apply() method works fine and is more efficient
if not self.moe.is_lora_enabled:
raise NotImplementedError(
"GPTQ-Marlin uses its own apply() method when LoRA is not enabled. "
"Modular kernels are only used for LoRA support."
)
# The modular marlin kernels do not support 8-bit weights.
if self.quant_config.weight_bits == 8:
raise NotImplementedError(
"GPTQ-Marlin kernel does not support 8-bit weights."
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
)
# Ensure quant config is initialized
assert self.moe_quant_config is not None, (
"moe_quant_config must be initialized before select_gemm_impl"
)
w13_g_idx = (
getattr(layer, "w13_g_idx", None) if self.quant_config.desc_act else None
)
w2_g_idx = (
getattr(layer, "w2_g_idx", None) if self.quant_config.desc_act else None
)
w13_g_idx_sort_indices = (
getattr(layer, "w13_g_idx_sort_indices", None)
if self.quant_config.desc_act
else None
)
w2_g_idx_sort_indices = (
getattr(layer, "w2_g_idx_sort_indices", None)
if self.quant_config.desc_act
else None
)
# Check if using batched expert format (for Expert Parallelism)
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
# For batched format, use BatchedMarlinExperts
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
else:
# Standard Marlin experts for GPTQ
return MarlinExperts(
quant_config=self.moe_quant_config,
w13_g_idx=w13_g_idx,
w2_g_idx=w2_g_idx,
w13_g_idx_sort_indices=w13_g_idx_sort_indices,
w2_g_idx_sort_indices=w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
)
def apply( def apply(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment