Unverified Commit 83f2d9d4 authored by PGFLMG, committed by GitHub

[QuickFix] fix gptq model initialize (#6429)

parent 6317c5c6
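In short: GPTQMarlinMoEMethod is no longer imported from vllm's gptq_marlin module. It is now defined directly in sglang.srt.layers.quantization.gptq on top of vllm's FusedMoEMethodBase (aliased to sglang's own QuantizeMethodBase when vllm is unavailable), and the quantization package imports it from there, so GPTQ MoE models initialize against sglang's own implementation.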
@@ -25,7 +25,6 @@ try:
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
    )
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
@@ -58,7 +57,11 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
    CompressedTensorsConfig,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+from sglang.srt.layers.quantization.gptq import (
+    GPTQConfig,
+    GPTQMarlinConfig,
+    GPTQMarlinMoEMethod,
+)
from sglang.srt.layers.quantization.modelopt_quant import (
    ModelOptFp4Config,
    ModelOptFp8Config,
...
import logging
from fractions import Fraction
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

import torch

-from sglang.srt.layers.linear import LinearBase
-from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.linear import LinearBase, set_weight_attrs
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.utils import replace_parameter
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

try:
-    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+    from vllm import _custom_ops as ops
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
        GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
+        marlin_moe_permute_scales,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -27,7 +34,9 @@ try:
except ImportError:
    VLLM_AVAILABLE = False

-    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+    GPTQLinearMethod = MarlinLinearMethod = Any
+
+    FusedMoEMethodBase = QuantizeMethodBase

    class scalar_types:
        uint4b8 = "uint4b8"
@@ -437,3 +446,286 @@ class MarlinConfig(QuantizationConfig):
        ):
            return MarlinLinearMethod(self)
        return None

class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    """MoE Marlin method with quantization."""

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        intermediate_size = extra_weight_attrs.pop("intermediate_size")

        self.is_k_full = (not self.quant_config.desc_act) or (
            intermediate_size_per_partition == intermediate_size
        )
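        # With act-order (desc_act) enabled, the down_proj reduction dimension may
        # be sharded across tensor-parallel ranks; is_k_full records whether this
        # rank holds the full K dimension, which the Marlin kernel needs to know.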
        if self.quant_config.group_size != -1:
            scales_size13 = hidden_size // self.quant_config.group_size
            w2_scales_size = (
                intermediate_size
                if self.quant_config.desc_act
                else intermediate_size_per_partition
            )
            scales_size2 = w2_scales_size // self.quant_config.group_size
            strategy = FusedMoeWeightScaleSupported.GROUP.value
        else:
            scales_size13 = 1
            scales_size2 = 1
            strategy = FusedMoeWeightScaleSupported.CHANNEL.value

        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
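        # quant_method / is_transposed are consumed by the FusedMoE weight loader:
        # GPTQ checkpoints store these tensors transposed, and the scale granularity
        # is per-group or per-channel depending on group_size.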
        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size // self.quant_config.pack_factor,
                2 * intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition // self.quant_config.pack_factor,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)
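        # Shape note: qweight packs pack_factor quantized values into each int32
        # (8 values per int32 for 4-bit weights), so the packed input dimension is
        # the original size divided by pack_factor.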
        # up_proj scales
        w13_scales = torch.nn.Parameter(
            torch.empty(
                num_experts,
                scales_size13,
                2 * intermediate_size_per_partition,
                dtype=torch.half,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)
        # down_proj scales
        w2_scales = torch.nn.Parameter(
            torch.empty(num_experts, scales_size2, hidden_size, dtype=torch.half),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)
        # don't shard the w2 scales when running act order
        set_weight_attrs(w2_scales, {"load_full_w2": self.quant_config.desc_act})
        # up_proj qzeros
        w13_qzeros = torch.nn.Parameter(
            torch.empty(
                num_experts,
                scales_size13,
                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)
        # down_proj qzeros
        w2_qzeros = torch.nn.Parameter(
            torch.empty(
                num_experts,
                scales_size2,
                hidden_size // self.quant_config.pack_factor,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
        # don't shard the w2 qzeros when running act order
        set_weight_attrs(w2_qzeros, {"load_full_w2": self.quant_config.desc_act})
        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)
        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)
        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            num_experts = layer.w13_g_idx.shape[0]
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
                    torch.int32
                )
                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
                    torch.int32
                )
                w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
        else:
            # Reset g_idx related tensors
            num_experts = layer.w13_g_idx.shape[0]
            device = layer.w13_g_idx.device
            layer.w13_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w2_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
        # Repack weights
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1]
            * (
                self.quant_config.group_size
                if self.quant_config.group_size != -1
                else self.quant_config.pack_factor
            ),
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert activation == "silu", "Only SiLU activation is supported."

        # The input must currently be float16
        orig_dtype = x.dtype
        x = x.half()

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
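        # topk_weights / topk_ids hold the per-token routing weights and the ids of
        # the experts selected from router_logits.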
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.quant_config.quant_type.size_bits,
            is_k_full=self.is_k_full,
        ).to(orig_dtype)
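For intuition about the act-order bookkeeping in process_weights_after_loading, the toy snippet below (not part of the commit; plain PyTorch with made-up sizes) shows what the per-expert argsort computes: a permutation that reorders input channels so that channels sharing a quantization group become contiguous, which is the ordering the Marlin repack expects.

import torch

# Toy example: 8 input channels, group_size = 4 -> two quantization groups.
# g_idx maps each input channel to its group, in GPTQ act-order (shuffled) order.
g_idx = torch.tensor([1, 0, 1, 0, 0, 1, 0, 1], dtype=torch.int32)

# Same operation as the per-expert loop above (the kernel later stores the
# permutation as int32).
sort_indices = torch.argsort(g_idx)
sorted_g_idx = g_idx[sort_indices]

print(sort_indices)  # permutation that groups channels by quantization group
print(sorted_g_idx)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)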