Commit 36c58b10 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-ds' into 'v0.11.0-dev'

fix: 修复deepseek量化模型的若干问题

See merge request dcutoolkit/deeplearing/vllm!338
parents 734f52d8 d10b80ce
...@@ -287,12 +287,16 @@ class FusedMoEQuantConfig: ...@@ -287,12 +287,16 @@ class FusedMoEQuantConfig:
@property @property
def use_int8_w8a8(self) -> bool: def use_int8_w8a8(self) -> bool:
return self.quant_dtype == torch.int8 return self.quant_dtype == torch.int8 and self._w1.dtype == torch.int8
@property @property
def use_int8_w8a16(self) -> bool: def use_int8_w8a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == torch.int8) return (self._a1.dtype is None and self._w1.dtype == torch.int8)
@property
def use_int4_w4a8(self) -> bool:
return (self._a1.dtype == torch.int8 and self._w1.dtype == "int4")
@property @property
def use_int4_w4a16(self) -> bool: def use_int4_w4a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == "int4") return (self._a1.dtype is None and self._w1.dtype == "int4")
......
...@@ -1813,7 +1813,6 @@ def fused_experts( ...@@ -1813,7 +1813,6 @@ def fused_experts(
quant_config: Optional[FusedMoEQuantConfig] = None, quant_config: Optional[FusedMoEQuantConfig] = None,
allow_deep_gemm: bool = False, allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_int4_w4a8: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1873,7 +1872,7 @@ def fused_experts( ...@@ -1873,7 +1872,7 @@ def fused_experts(
use_int8_w8a8=quant_config.use_int8_w8a8, use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16, use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16, use_int4_w4a16=quant_config.use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8, use_int4_w4a8=quant_config.use_int4_w4a8,
use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4, use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
per_channel_quant=quant_config.per_act_token_quant, per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -2370,7 +2369,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2370,7 +2369,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.quant_config.use_int8_w8a8, use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16, use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16, use_int4_w4a16=self.quant_config.use_int4_w4a16,
use_int4_w4a8=self.use_int4_w4a8, use_int4_w4a8=self.quant_config.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape, block_shape=self.block_shape,
B_bias=self.w1_bias, B_bias=self.w1_bias,
...@@ -2404,7 +2403,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2404,7 +2403,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.quant_config.use_int8_w8a8, use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16, use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16, use_int4_w4a16=self.quant_config.use_int4_w4a16,
use_int4_w4a8=self.use_int4_w4a8, use_int4_w4a8=self.quant_config.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape, block_shape=self.block_shape,
B_bias=self.w2_bias, B_bias=self.w2_bias,
......
...@@ -49,6 +49,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig): ...@@ -49,6 +49,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
sparsity_ignore_list: list[str], sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None, kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None, config: Optional[dict[str, Any]] = None,
transform_config: Optional[dict[str, Any]] = None,
): ):
super().__init__( super().__init__(
target_scheme_map, target_scheme_map,
...@@ -57,7 +58,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig): ...@@ -57,7 +58,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
sparsity_scheme_map, sparsity_scheme_map,
sparsity_ignore_list, sparsity_ignore_list,
kv_cache_scheme, kv_cache_scheme,
config config,
transform_config
) )
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
......
...@@ -10,10 +10,11 @@ from vllm.logger import init_logger ...@@ -10,10 +10,11 @@ from vllm.logger import init_logger
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase, FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported, FusedMoEConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.w8a8_utils import( from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights) get_w8a8_int8_marlin_weights)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig)
try: try:
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception: except Exception:
...@@ -27,6 +28,9 @@ __all__ = [ ...@@ -27,6 +28,9 @@ __all__ = [
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase): class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
def __init_(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod @staticmethod
def get_moe_method( def get_moe_method(
quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501 quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501
...@@ -38,17 +42,17 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase): ...@@ -38,17 +42,17 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
input_quant = quant_config.target_scheme_map["Linear"].get( input_quant = quant_config.target_scheme_map["Linear"].get(
"input_activations") "input_activations")
if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config) return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config, layer.moe_config)
else: else:
raise RuntimeError( raise RuntimeError(
f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}") f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__( def __init__(self,
self, quant_config: "CompressedTensorsConfig",
quant_config: "CompressedTensorsMarlinConfig" # type: ignore # noqa E501 moe: FusedMoEConfig):
):
self.quant_config = quant_config self.quant_config = quant_config
super().__init__(moe)
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights") "weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get( self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
...@@ -69,7 +73,10 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -69,7 +73,10 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
"For INT8 Fused MoE layers, we require channelwise, " "For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.") "dynamic per token quantization. Found static input scales.")
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
...@@ -171,7 +178,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -171,7 +178,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
"`CompressedTensorsW8A8Int8MoEMethod` yet.") "`CompressedTensorsW8A8Int8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
......
...@@ -13,11 +13,13 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -13,11 +13,13 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter, ModelWeightParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from lmslim.layers.gemm.int8_utils import ( from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8, per_token_group_quant_int8,
per_token_quant_int8) per_token_quant_int8)
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
import os import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
...@@ -79,7 +81,7 @@ class SlimQuantW4A8Int8Config(QuantizationConfig): ...@@ -79,7 +81,7 @@ class SlimQuantW4A8Int8Config(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MoEMethod(self) return SlimQuantW4A8Int8MoEMethod(self,layer.moe_config)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -92,7 +94,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -92,7 +94,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0] n=layer.weight.shape[0]
k=layer.weight.shape[1] k=layer.weight.shape[1]
...@@ -254,9 +256,29 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -254,9 +256,29 @@ class SlimQuantW4A8Int8MoEMethod:
return obj return obj
return super().__new__(cls) return super().__new__(cls)
def __init__(self, quant_config): def __init__(self, quant_config, moe):
self.moe = moe
self.quant_config = quant_config self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module)-> Optional[FusedMoEQuantConfig]:
self.moe_quant_config = FusedMoEQuantConfig.make(
torch.int8,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
per_out_ch_quant=False,
block_shape=None,
)
self.moe_quant_config._w1.dtype="int4"
self.moe_quant_config._w1.dtype="int4"
return self.moe_quant_config
def create_weights( def create_weights(
self, self,
...@@ -365,7 +387,7 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -365,7 +387,7 @@ class SlimQuantW4A8Int8MoEMethod:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.") "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
...@@ -387,15 +409,10 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -387,15 +409,10 @@ class SlimQuantW4A8Int8MoEMethod:
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation, activation=activation,
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale), quant_config=self.moe_quant_config,
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
) )
...@@ -13,8 +13,9 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon ...@@ -13,8 +13,9 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
ModelWeightParameter) from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
try: try:
...@@ -110,7 +111,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): ...@@ -110,7 +111,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MarlinMoEMethod(self) return SlimQuantW4A8Int8MarlinMoEMethod(self,layer.moe_config)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -142,8 +143,16 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -142,8 +143,16 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
return obj return obj
return super().__new__(cls) return super().__new__(cls)
def __init__(self, quant_config): def __init__(self, quant_config, moe):
self.moe = moe
self.quant_config = quant_config self.quant_config = quant_config
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) :
return None
def create_weights( def create_weights(
self, self,
...@@ -235,7 +244,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -235,7 +244,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.") "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment