"server/vscode:/vscode.git/clone" did not exist on "521d0d990f1624d4821d6a1763805df312306fa9"
Unverified commit 84f2e4a0 authored by AniZpZ, committed by GitHub

fix awq and dsv3 fused gemm compatible (#7735)

parent 8f844db6
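
Background for the change (a hedged summary, not part of the commit message): AWQ and moe_wna16 quantization keep weights in packed integer buffers (for example an int32 qweight plus scales/zeros) rather than a plain `weight` tensor, so detecting int8/FP8 GEMM paths purely from `weight.dtype` can misclassify such layers or touch an attribute they do not expose. The diff below therefore checks `quant_method.quant_config.get_name()` first and short-circuits the dtype checks when the weight is packed. The sketch below illustrates that guard with stand-in classes; `FakeQuantConfig`, `FakeQuantMethod`, `FakeLinear`, and the quant names passed to them are illustrative, not taken from the repository.

    import torch


    class FakeQuantConfig:
        """Stand-in for a quantization config exposing get_name()."""

        def __init__(self, name: str):
            self._name = name

        def get_name(self) -> str:
            return self._name


    class FakeQuantMethod:
        """Stand-in for a linear layer's quant_method attribute."""

        def __init__(self, name: str):
            self.quant_config = FakeQuantConfig(name)


    class FakeLinear:
        """Stand-in linear layer; real AWQ layers keep packed qweight/scales."""

        def __init__(self, name: str, dtype: torch.dtype):
            self.quant_method = FakeQuantMethod(name)
            self.weight = torch.zeros(1, dtype=dtype)


    def is_int8_gemm(layer: FakeLinear) -> bool:
        # Mirror of the guard added in this commit: packed formats are excluded
        # before the dtype check, so they never take the int8 fused-GEMM path.
        is_packed_weight = layer.quant_method.quant_config.get_name() in [
            "awq",
            "moe_wna16",
        ]
        return not is_packed_weight and layer.weight.dtype == torch.int8


    print(is_int8_gemm(FakeLinear("w8a8_int8", torch.int8)))  # True  -> real int8 GEMM
    print(is_int8_gemm(FakeLinear("awq", torch.int8)))        # False -> packed AWQ layer
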
@@ -336,11 +336,17 @@ class DeepseekV2MoE(nn.Module):
                 else {}
             ),
         )
+        is_packed_weight = (
+            self.shared_experts.gate_up_proj.quant_method.quant_config.get_name()
+            in ["awq", "moe_wna16"]
+        )
         self.shared_experts_is_int8 = (
-            self.shared_experts.gate_up_proj.weight.dtype == torch.int8
+            not is_packed_weight
+            and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
         )
         self.shared_experts_is_fp8 = (
-            self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
+            not is_packed_weight
+            and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
         )
         if self.shared_experts_is_fp8:
             assert (
@@ -894,8 +900,13 @@ class DeepseekV2AttentionMLA(nn.Module):
             weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
         )

+        is_packed_weight = (
+            self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in ["awq", "moe_wna16"]
+        )
         self.use_min_latency_fused_a_gemm = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
             and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
@@ -905,10 +916,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.qkv_proj_with_rope_is_int8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
         )
         self.qkv_proj_with_rope_is_fp8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
         )
......
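
Note on the attention-side hunks (again a reading of the diff, not author commentary): `is_packed_weight` is computed once from the quant config of `fused_qkv_a_proj_with_mqa` and reused in the three feature flags (`use_min_latency_fused_a_gemm`, `qkv_proj_with_rope_is_int8`, `qkv_proj_with_rope_is_fp8`), so AWQ/moe_wna16 checkpoints fall back to the generic path instead of the bf16/int8/FP8 fused GEMM kernels that assume an unpacked weight layout.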