Commit 54139f16 authored by zhuwenwen's avatar zhuwenwen
Browse files

修改增加lmslimquant_w4a8量化支持

parent bdda5719
......@@ -36,9 +36,10 @@ class ActivationMethod(IntEnum):
@cache
def is_rocm_aiter_moe_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_MOE \
and envs.VLLM_ROCM_USE_AITER
return False
# return current_platform.is_rocm() \
# and envs.VLLM_ROCM_USE_AITER_MOE \
# and envs.VLLM_ROCM_USE_AITER
def rocm_aiter_asm_moe_tkw1_impl(
......
......@@ -38,7 +38,7 @@ QuantizationMethods = Literal[
"rtn",
"inc",
"blockwise_int8",
"w8a8_int8",
"slimquant_w4a8",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
......@@ -119,7 +119,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
from .blockwise_int8 import BlockInt8Config
from .w8a8_int8 import W8A8Int8Config
from .slimquant_w4a8 import SlimQuantW4A8Int8Config
method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
......@@ -153,7 +153,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"rtn": RTNConfig,
"inc": INCConfig,
"blockwise_int8": BlockInt8Config,
"w8a8_int8":W8A8Int8Config,
"slimquant_w4a8":SlimQuantW4A8Int8Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
......@@ -652,6 +652,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
......@@ -671,6 +674,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
......
......@@ -40,7 +40,7 @@ def baseline_scaled_mm(a: torch.Tensor,
return output.to(out_dtype)
class W8A8Int8Config(QuantizationConfig):
class SlimQuantW4A8Int8Config(QuantizationConfig):
"""Config class for W8A8 Int8 Quantization.
- Weight: static, per-channel, symmetric
......@@ -60,14 +60,14 @@ class W8A8Int8Config(QuantizationConfig):
@classmethod
def get_name(self) -> str:
return "w8a8_int8"
return "slimquant_w4a8"
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
return cls()
def get_quant_method(
......@@ -77,18 +77,18 @@ class W8A8Int8Config(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self)
return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self)
return SlimQuantW4A8Int8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class W8A8Int8LinearMethod(LinearMethodBase):
class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: W8A8Int8Config):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
......@@ -218,8 +218,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
bias=bias)
class W8A8Int8MoEMethod:
"""MoE method for INT8.
class SlimQuantW4A8Int8MoEMethod:
"""MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
......@@ -355,7 +355,7 @@ class W8A8Int8MoEMethod:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `W8A8Int8MoeMethod` yet.")
"EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......
......@@ -181,7 +181,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8","awq_marlin"
"quark", "ptpc_fp8", "moe_wna16", "slimquant_w4a8","w8a8_int8","awq_marlin"
]
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment