Unverified Commit b5822651 authored by AniZpZ, committed by GitHub

fix dsv3 fused proj check (#7738)

parent 2c4feaf3
@@ -336,10 +336,12 @@ class DeepseekV2MoE(nn.Module):
                     else {}
                 ),
             )
-            is_packed_weight = (
-                self.shared_experts.gate_up_proj.quant_method.quant_config.get_name()
-                in ["awq", "moe_wna16"]
-            )
+            is_packed_weight = hasattr(
+                self.shared_experts.gate_up_proj.quant_method, "quant_config"
+            ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
+                "awq",
+                "moe_wna16",
+            }
             self.shared_experts_is_int8 = (
                 not is_packed_weight
                 and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
@@ -891,21 +893,20 @@ class DeepseekV2AttentionMLA(nn.Module):
         # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
         # which requires self.w_kc and self.w_vc to be packed.
         # If not, we will use torch.bmm and weight shouldn't be packed in this case
-        if (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
-            and _is_cpu
-            and _is_cpu_amx_available
-        ):
+        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
+        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(
                 weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
             )
         is_packed_weight = (
-            self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
-            in ["awq", "moe_wna16"]
+            has_fused_proj
+            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
+            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
+            in {"awq", "moe_wna16"}
         )
         self.use_min_latency_fused_a_gemm = (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            has_fused_proj
             and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
             and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
@@ -915,12 +916,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.qkv_proj_with_rope_is_int8 = (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            has_fused_proj
             and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
         )
         self.qkv_proj_with_rope_is_fp8 = (
-            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            has_fused_proj
             and not is_packed_weight
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
         )
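The common thread in all three hunks is that a layer's quant_method only exposes a quant_config attribute when the checkpoint is actually quantized, so probing it with hasattr before calling get_name() avoids an AttributeError on unquantized weights. Below is a minimal, self-contained sketch of that guarded check; the UnquantizedMethod/AWQMethod classes are hypothetical stand-ins for illustration, not sglang's real quant methods.

    # Hypothetical stand-ins, not sglang classes.
    class UnquantizedMethod:
        """Mimics a quant method that carries no quant_config attribute."""


    class AWQConfig:
        def get_name(self) -> str:
            return "awq"


    class AWQMethod:
        def __init__(self) -> None:
            self.quant_config = AWQConfig()


    def is_packed_weight(quant_method) -> bool:
        # The old check called quant_method.quant_config.get_name() unconditionally,
        # which raises AttributeError for unquantized layers; the hasattr guard
        # added in this commit makes the check safe.
        return hasattr(quant_method, "quant_config") and quant_method.quant_config.get_name() in {
            "awq",
            "moe_wna16",
        }


    assert is_packed_weight(AWQMethod())
    assert not is_packed_weight(UnquantizedMethod())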