Commit ee58c1bf authored by 王敏's avatar 王敏
Browse files

[feat]W8A8适配deepseek以及mtp

parent bc387d5a
...@@ -913,6 +913,7 @@ class ModelConfig: ...@@ -913,6 +913,7 @@ class ModelConfig:
"mxfp4", "mxfp4",
"cpu_awq", "cpu_awq",
"slimquant_w4a8_marlin", "slimquant_w4a8_marlin",
"slimquant_marlin",
"slimquant_compressed_tensors_marlin", "slimquant_compressed_tensors_marlin",
] ]
quantization_methods = [ quantization_methods = [
......
...@@ -280,6 +280,7 @@ if TYPE_CHECKING: ...@@ -280,6 +280,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False VLLM_USE_PIECEWISE: bool = False
...@@ -1788,6 +1789,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1788,6 +1789,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
# vLLM will split prefill and decode, not mix up # vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in
......
...@@ -872,7 +872,7 @@ class FusedMoEParallelConfig: ...@@ -872,7 +872,7 @@ class FusedMoEParallelConfig:
use_ep: bool # whether to use EP or not use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication all2all_backend: str # all2all backend for MoE communication
is_sequence_parallel: bool # whether sequence parallelism is used #is_sequence_parallel: bool # whether sequence parallelism is used
enable_eplb: bool # whether to enable expert load balancing enable_eplb: bool # whether to enable expert load balancing
@property @property
......
...@@ -672,7 +672,8 @@ class FusedMoE(CustomOp): ...@@ -672,7 +672,8 @@ class FusedMoE(CustomOp):
def maybe_init_modular_kernel(self) -> None: def maybe_init_modular_kernel(self) -> None:
# NOTE(rob): WIP refactor. For quant methods that own the MK # NOTE(rob): WIP refactor. For quant methods that own the MK
# we create the MK during process_weights_after_loading. # we create the MK during process_weights_after_loading.
if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic: #if self.quant_method.supports_internal_mk or self.quant_method.is_monolithic:
if self.quant_method.is_monolithic:
return None return None
self.ensure_moe_quant_config_init() self.ensure_moe_quant_config_init()
...@@ -1930,7 +1931,6 @@ class FusedMoE(CustomOp): ...@@ -1930,7 +1931,6 @@ class FusedMoE(CustomOp):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
use_fused_gate=self.use_fused_gate,
) )
if has_separate_shared_experts: if has_separate_shared_experts:
......
...@@ -38,6 +38,7 @@ QuantizationMethods = Literal[ ...@@ -38,6 +38,7 @@ QuantizationMethods = Literal[
"blockwise_int8", "blockwise_int8",
"slimquant_w4a8", "slimquant_w4a8",
"slimquant_w4a8_marlin", "slimquant_w4a8_marlin",
"slimquant_marlin",
"slimquant_compressed_tensors_marlin", "slimquant_compressed_tensors_marlin",
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -177,6 +178,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -177,6 +178,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"slimquant_w4a8":SlimQuantW4A8Int8Config, "slimquant_w4a8":SlimQuantW4A8Int8Config,
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig, "slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
"slimquant_marlin":SlimQuantCompressedTensorsMarlinConfig,
"slimquant_compressed_tensors_marlin":SlimQuantCompressedTensorsMarlinConfig, "slimquant_compressed_tensors_marlin":SlimQuantCompressedTensorsMarlinConfig,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
......
...@@ -50,6 +50,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig): ...@@ -50,6 +50,8 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
kv_cache_scheme: Optional[dict[str, Any]] = None, kv_cache_scheme: Optional[dict[str, Any]] = None,
config: Optional[dict[str, Any]] = None, config: Optional[dict[str, Any]] = None,
transform_config: Optional[dict[str, Any]] = None, transform_config: Optional[dict[str, Any]] = None,
total_num_heads: int | None = None,
total_num_kv_heads: int | None = None,
): ):
super().__init__( super().__init__(
target_scheme_map, target_scheme_map,
...@@ -61,6 +63,9 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig): ...@@ -61,6 +63,9 @@ class SlimQuantCompressedTensorsMarlinConfig(CompressedTensorsConfig):
config, config,
transform_config transform_config
) )
self.total_num_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
@classmethod @classmethod
def override_quantization_method( def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
......
...@@ -147,52 +147,15 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -147,52 +147,15 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False) layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False) layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_weights: torch.Tensor,
top_k: int, topk_ids: torch.Tensor,
renormalize: bool, use_nn_moe: bool | None = False,
use_grouped_topk: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_group: Optional[int] = None, from vllm.model_executor.layers.fused_moe import fused_experts
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet.")
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias)
return fused_experts_impl_int8_marlin( return fused_experts_impl_int8_marlin(
hidden_states=x, hidden_states=x,
...@@ -201,16 +164,16 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -201,16 +164,16 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=layer.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
use_int8_w8a8=True, use_int8_w8a8=True,
per_channel_quant=True, per_channel_quant=True,
global_num_experts=global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False, use_nn_moe=False,
shared_output=shared_output, )
routed_scaling_factor=routed_scaling_factor) \ No newline at end of file
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment