Commit 64a2aa19 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.11.0-dev

parents 2cbda743 36c58b10
......@@ -287,12 +287,16 @@ class FusedMoEQuantConfig:
@property
def use_int8_w8a8(self) -> bool:
return self.quant_dtype == torch.int8
return self.quant_dtype == torch.int8 and self._w1.dtype == torch.int8
@property
def use_int8_w8a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == torch.int8)
@property
def use_int4_w4a8(self) -> bool:
return (self._a1.dtype == torch.int8 and self._w1.dtype == "int4")
@property
def use_int4_w4a16(self) -> bool:
return (self._a1.dtype is None and self._w1.dtype == "int4")
......
......@@ -1813,7 +1813,6 @@ def fused_experts(
quant_config: Optional[FusedMoEQuantConfig] = None,
allow_deep_gemm: bool = False,
allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_int4_w4a8: bool = False,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
......@@ -1873,7 +1872,7 @@ def fused_experts(
use_int8_w8a8=quant_config.use_int8_w8a8,
use_int8_w8a16=quant_config.use_int8_w8a16,
use_int4_w4a16=quant_config.use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_int4_w4a8=quant_config.use_int4_w4a8,
use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
per_channel_quant=quant_config.per_act_token_quant,
global_num_experts=global_num_experts,
......@@ -2370,7 +2369,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
use_int4_w4a8=self.use_int4_w4a8,
use_int4_w4a8=self.quant_config.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w1_bias,
......@@ -2404,7 +2403,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.quant_config.use_int8_w8a8,
use_int8_w8a16=self.quant_config.use_int8_w8a16,
use_int4_w4a16=self.quant_config.use_int4_w4a16,
use_int4_w4a8=self.use_int4_w4a8,
use_int4_w4a8=self.quant_config.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=self.w2_bias,
......
......@@ -49,6 +49,7 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
sparsity_ignore_list: list[str],
kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None,
transform_config: Optional[dict[str, Any]] = None,
):
super().__init__(
target_scheme_map,
......@@ -57,7 +58,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
sparsity_scheme_map,
sparsity_ignore_list,
kv_cache_scheme,
config
config,
transform_config
)
@classmethod
def override_quantization_method(
......
......@@ -10,10 +10,11 @@ from vllm.logger import init_logger
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
FusedMoeWeightScaleSupported, FusedMoEConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig)
try:
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
......@@ -27,6 +28,9 @@ __all__ = [
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
def __init_(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod
def get_moe_method(
quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501
......@@ -38,17 +42,17 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
input_quant = quant_config.target_scheme_map["Linear"].get(
"input_activations")
if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config, layer.moe_config)
else:
raise RuntimeError(
f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsMarlinConfig" # type: ignore # noqa E501
):
def __init__(self,
quant_config: "CompressedTensorsConfig",
moe: FusedMoEConfig):
self.quant_config = quant_config
super().__init__(moe)
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
......@@ -69,7 +73,10 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
......@@ -171,7 +178,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
"`CompressedTensorsW8A8Int8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
......@@ -13,11 +13,13 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
from vllm import envs
......@@ -79,7 +81,7 @@ class SlimQuantW4A8Int8Config(QuantizationConfig):
if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MoEMethod(self)
return SlimQuantW4A8Int8MoEMethod(self,layer.moe_config)
return None
def get_scaled_act_names(self) -> List[str]:
......@@ -92,7 +94,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
k=layer.weight.shape[1]
......@@ -254,9 +256,29 @@ class SlimQuantW4A8Int8MoEMethod:
return obj
return super().__new__(cls)
def __init__(self, quant_config):
def __init__(self, quant_config, moe):
self.moe = moe
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module)-> Optional[FusedMoEQuantConfig]:
self.moe_quant_config = FusedMoEQuantConfig.make(
torch.int8,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
per_out_ch_quant=False,
block_shape=None,
)
self.moe_quant_config._w1.dtype="int4"
self.moe_quant_config._w1.dtype="int4"
return self.moe_quant_config
def create_weights(
self,
......@@ -365,7 +387,7 @@ class SlimQuantW4A8Int8MoEMethod:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......@@ -387,15 +409,10 @@ class SlimQuantW4A8Int8MoEMethod:
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
quant_config=self.moe_quant_config,
use_nn_moe=use_nn_moe,
)
......@@ -13,8 +13,9 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
try:
......@@ -110,7 +111,7 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MarlinMoEMethod(self)
return SlimQuantW4A8Int8MarlinMoEMethod(self,layer.moe_config)
return None
def get_scaled_act_names(self) -> List[str]:
......@@ -142,8 +143,16 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
return obj
return super().__new__(cls)
def __init__(self, quant_config):
def __init__(self, quant_config, moe):
self.moe = moe
self.quant_config = quant_config
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.fused_experts: Optional[FusedMoEModularKernel] = None
self.topk_indices_dtype = None
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) :
return None
def create_weights(
self,
......@@ -235,7 +244,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment