Commit 8b1e4ef0 authored by gaoqiong's avatar gaoqiong
Browse files

修改增加lmslimquant_w4a8量化支持

parent cc6f327a
...@@ -37,7 +37,7 @@ QuantizationMethods = Literal[ ...@@ -37,7 +37,7 @@ QuantizationMethods = Literal[
"auto-round", "auto-round",
"rtn", "rtn",
"blockwise_int8", "blockwise_int8",
"w8a8_int8" "slimquant_w4a8"
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -117,7 +117,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -117,7 +117,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig from .tpu_int8 import Int8TpuConfig
from .blockwise_int8 import BlockInt8Config from .blockwise_int8 import BlockInt8Config
from .w8a8_int8 import W8A8Int8Config from .slimquant_w4a8 import SlimQuantW4A8Int8Config
method_to_config: dict[str, type[QuantizationConfig]] = { method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
...@@ -150,7 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -150,7 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round": AutoRoundConfig, "auto-round": AutoRoundConfig,
"rtn": RTNConfig, "rtn": RTNConfig,
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"w8a8_int8":W8A8Int8Config, "slimquant_w4a8":SlimQuantW4A8Int8Config,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
...@@ -1000,6 +1000,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1000,6 +1000,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
raise ValueError( raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, " "For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.") "dynamic per token quantization. Found static input scales.")
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
...@@ -1089,6 +1091,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1089,6 +1091,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False, enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
...@@ -1111,6 +1116,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1111,6 +1116,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return fused_experts( return fused_experts(
......
...@@ -40,7 +40,7 @@ def baseline_scaled_mm(a: torch.Tensor, ...@@ -40,7 +40,7 @@ def baseline_scaled_mm(a: torch.Tensor,
return output.to(out_dtype) return output.to(out_dtype)
class W8A8Int8Config(QuantizationConfig): class SlimQuantW4A8Int8Config(QuantizationConfig):
"""Config class for W8A8 Int8 Quantization. """Config class for W8A8 Int8 Quantization.
- Weight: static, per-channel, symmetric - Weight: static, per-channel, symmetric
...@@ -60,14 +60,14 @@ class W8A8Int8Config(QuantizationConfig): ...@@ -60,14 +60,14 @@ class W8A8Int8Config(QuantizationConfig):
@classmethod @classmethod
def get_name(self) -> str: def get_name(self) -> str:
return "w8a8_int8" return "slimquant_w4a8"
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config": def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
return cls() return cls()
def get_quant_method( def get_quant_method(
...@@ -77,18 +77,18 @@ class W8A8Int8Config(QuantizationConfig): ...@@ -77,18 +77,18 @@ class W8A8Int8Config(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self) return SlimQuantW4A8Int8MoEMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
class W8A8Int8LinearMethod(LinearMethodBase): class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: W8A8Int8Config): def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
...@@ -218,8 +218,8 @@ class W8A8Int8LinearMethod(LinearMethodBase): ...@@ -218,8 +218,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
bias=bias) bias=bias)
class W8A8Int8MoEMethod: class SlimQuantW4A8Int8MoEMethod:
"""MoE method for INT8. """MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale. dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic Also supports loading quantized FP16/BF16 model checkpoints with dynamic
...@@ -354,7 +354,7 @@ class W8A8Int8MoEMethod: ...@@ -354,7 +354,7 @@ class W8A8Int8MoEMethod:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `W8A8Int8MoeMethod` yet.") "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
......
...@@ -180,7 +180,7 @@ class RocmPlatform(Platform): ...@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8","awq_marlin" "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","slimquant_w4a8","awq_marlin"
] ]
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment