Unverified commit 8616357a authored by Liangsheng Yin, committed by GitHub

Fix deepseek awq v3 (#3450)

parent 8adbc78b
@@ -421,11 +421,18 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
+
+        from sglang.srt.layers.parameter import _ColumnvLLMParameter
+
+        if isinstance(param, _ColumnvLLMParameter):
+            # FIXME: why would we need this special case?
             param.load_column_parallel_weight(
                 loaded_weight,
                 tp_rank=self.tp_rank,
                 use_presharded_weights=self.use_presharded_weights,
             )
+        else:
+            param.load_column_parallel_weight(loaded_weight)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
...
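Note on the change above: the loader now dispatches on the parameter type, because only vLLM-style column parameters accept the tp_rank / use_presharded_weights keywords; other parameter classes only expose the single-argument form. A minimal sketch of that dispatch; load_column_weight is a made-up standalone name, not the actual sglang method:

from sglang.srt.layers.parameter import _ColumnvLLMParameter

def load_column_weight(param, loaded_weight, tp_rank, use_presharded_weights):
    # 0-d tensors from the checkpoint are normalized to shape (1,) first,
    # mirroring the diff above.
    if loaded_weight.dim() == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)

    if isinstance(param, _ColumnvLLMParameter):
        # vLLM-style column parameters shard themselves across TP ranks and
        # accept the extra keyword arguments.
        param.load_column_parallel_weight(
            loaded_weight,
            tp_rank=tp_rank,
            use_presharded_weights=use_presharded_weights,
        )
    else:
        # Other parameter classes only accept the weight itself.
        param.load_column_parallel_weight(loaded_weight)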
@@ -298,7 +298,9 @@ class FusedMoE(torch.nn.Module):
             layer=self,
             num_experts=num_experts,
             hidden_size=hidden_size,
+            # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
+            intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
...
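A note on the duplicated keyword above: create_weights implementations differ in whether they read intermediate_size or intermediate_size_per_partition, and unknown keywords are typically absorbed by a trailing **extra_weight_attrs, so passing both spellings keeps either signature working (hence the FIXME). A hedged sketch of that idea; the two functions below are illustrative only and are not vLLM's actual signatures:

import torch

def create_weights_old_style(layer, num_experts, hidden_size,
                             intermediate_size, params_dtype,
                             **extra_weight_attrs):
    # Reads `intermediate_size`; the extra `intermediate_size_per_partition`
    # keyword simply lands in **extra_weight_attrs.
    return torch.empty(num_experts, intermediate_size, hidden_size,
                       dtype=params_dtype)

def create_weights_new_style(layer, num_experts, hidden_size,
                             intermediate_size_per_partition, params_dtype,
                             **extra_weight_attrs):
    # Reads `intermediate_size_per_partition` instead; the other spelling is
    # absorbed by **extra_weight_attrs.
    return torch.empty(num_experts, intermediate_size_per_partition, hidden_size,
                       dtype=params_dtype)

# Passing both keywords, as the diff does, satisfies either signature:
for fn in (create_weights_old_style, create_weights_new_style):
    w = fn(layer=None, num_experts=8, hidden_size=16,
           intermediate_size=32, intermediate_size_per_partition=32,
           params_dtype=torch.float16)
    print(w.shape)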
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-from typing import Dict, Type
+from typing import Callable, Dict, Optional, Type
+
+import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+from vllm.model_executor.layers.quantization.awq_marlin import (
+    AWQMarlinConfig,
+    AWQMoEMethod,
+)
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,

@@ -73,21 +76,61 @@ def gptq_get_quant_method(self, layer, prefix):

 def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
     from vllm.model_executor.layers.quantization.awq_marlin import (
         AWQMarlinLinearMethod,
         AWQMoEMethod,
     )

-    from sglang.srt.layers.linear import LinearBase
+    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

-    if isinstance(layer, LinearBase):
+    if isinstance(layer, LinearBase) or (
+        isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+    ):
+        if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            return UnquantizedLinearMethod()
         return AWQMarlinLinearMethod(self)
     elif isinstance(layer, FusedMoE):
         return AWQMoEMethod(self)
     return None
+
+
+original_awq_moe_method_apply = AWQMoEMethod.apply
+
+
+def awq_moe_method_apply(
+    self,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    renormalize: bool,
+    use_grouped_topk: bool = False,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    scoring_func: str = "softmax",
+    e_score_correction_bias: Optional[torch.Tensor] = None,
+    **kwargs,
+):
+    return original_awq_moe_method_apply(
+        self,
+        layer,
+        x,
+        router_logits,
+        top_k,
+        renormalize,
+        use_grouped_topk,
+        topk_group,
+        num_expert_group,
+        custom_routing_function,
+        scoring_func,
+        e_score_correction_bias,
+    )


 def patch_vllm_linear_base_isinstance():
     import builtins

@@ -107,8 +150,11 @@ def patch_vllm_linear_base_isinstance():

 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
+    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
+    setattr(AWQMoEMethod, "apply", awq_moe_method_apply)

     patch_vllm_linear_base_isinstance()
...
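The AWQMoEMethod.apply patch above follows a common adapter pattern: keep a reference to the original method, install a wrapper whose signature tolerates extra keyword arguments (**kwargs), and forward only the arguments the original understands. A self-contained sketch of that pattern with made-up classes and arguments, not the vLLM API:

# ThirdPartyMethod and its arguments are made up for illustration.
class ThirdPartyMethod:
    def apply(self, x, top_k, renormalize=False):
        return x * top_k if renormalize else x

original_apply = ThirdPartyMethod.apply  # keep a handle on the unpatched method

def patched_apply(self, x, top_k, renormalize=False, **kwargs):
    # Newer callers may pass keyword arguments the third-party method does not
    # know about; they are collected in **kwargs and dropped before delegating.
    return original_apply(self, x, top_k, renormalize)

ThirdPartyMethod.apply = patched_apply  # analogous to the setattr(...) call above

m = ThirdPartyMethod()
print(m.apply(2.0, top_k=3, renormalize=True, newer_kwarg="ignored"))  # -> 6.0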
@@ -255,6 +255,8 @@ class DeepseekV2Attention(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            # FIXME: quick fix for skip quantization
+            prefix=f"self_attn.kv_a_proj_with_mqa",
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(

@@ -455,6 +457,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.kv_lora_rank + self.qk_rope_head_dim,
             bias=False,
             quant_config=quant_config,
+            # FIXME: quick fix for skip quantization
+            prefix=f"self_attn.kv_a_proj_with_mqa",
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
...
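The explicit prefix passed above is what lets the AWQ config leave kv_a_proj_with_mqa unquantized: the patched awq_get_quant_method earlier in this commit checks the layer prefix against modules_to_not_convert via is_layer_skipped_awq and falls back to an unquantized linear method on a match. A rough, simplified stand-in for that check (not vLLM's actual is_layer_skipped_awq implementation):

from typing import List, Optional

def is_layer_skipped(prefix: str,
                     modules_to_not_convert: Optional[List[str]]) -> bool:
    # A layer stays unquantized when any entry of the skip list appears in its
    # dotted module prefix.
    if not modules_to_not_convert:
        return False
    return any(name in prefix for name in modules_to_not_convert)

# With the prefix set in the diff, a checkpoint that lists this projection in
# modules_to_not_convert gets the unquantized linear method:
print(is_layer_skipped("self_attn.kv_a_proj_with_mqa", ["kv_a_proj_with_mqa"]))  # True
print(is_layer_skipped("self_attn.kv_b_proj", ["kv_a_proj_with_mqa"]))           # False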