Unverified Commit 00aec6ad authored by Ke Bao, committed by GitHub

Apply dsv3_fused_a_gemm kernel (#7635)

parent 1a08358a
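
This commit adds a min-latency path for the DeepSeek-V3 fused QKV-A projection: when the batch holds at most 16 tokens, the `fused_qkv_a_proj_with_mqa` weight is bfloat16 with shape 2112 x 7168, and the GPU is SM90+ (Hopper), the projection is computed with the `dsv3_fused_a_gemm` kernel from `sgl_kernel` instead of the generic linear layer. A minimal sketch of that gating logic as a free-standing helper (hypothetical, for illustration only; the commit folds the check into `DeepseekV2AttentionMLA`):

```python
# Hedged sketch of the dispatch condition introduced by this commit.
# `should_use_dsv3_fused_a_gemm` is a hypothetical helper, not part of the diff.
import torch


def should_use_dsv3_fused_a_gemm(weight: torch.Tensor, num_tokens: int) -> bool:
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    sm = major * 10 + minor  # e.g. 90 on Hopper, mirroring get_device_sm()
    return (
        weight.dtype == torch.bfloat16
        and tuple(weight.shape) == (2112, 7168)  # DeepSeek-V3 fused q_a / kv_a projection
        and sm >= 90
        and num_tokens <= 16  # min-latency path targets small decode batches
    )
```
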
@@ -96,6 +96,7 @@ from sglang.srt.utils import (
     bind_or_assign,
     cpu_has_amx_support,
     get_bool_env_var,
+    get_device_sm,
     get_int_env_var,
     is_cpu,
     is_cuda,
@@ -112,7 +113,7 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
+    from sgl_kernel import awq_dequantize, bmm_fp8, dsv3_fused_a_gemm, merge_state_v2
 elif _is_cpu and _is_cpu_amx_available:
     pass
 else:
@@ -875,6 +876,15 @@ class DeepseekV2AttentionMLA(nn.Module):
             weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
         )
 
+        self.use_min_latency_fused_a_gemm = (
+            hasattr(self, "fused_qkv_a_proj_with_mqa")
+            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
+            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
+            and _is_cuda
+            and get_device_sm() >= 90
+        )
+
         self.qkv_proj_with_rope_is_int8 = (
             hasattr(self, "fused_qkv_a_proj_with_mqa")
             and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
@@ -1114,7 +1124,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
             k_nope = latent_cache[..., : self.kv_lora_rank]
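
For reference, a sketch of the equivalence the new branch relies on: on the small-batch path, `dsv3_fused_a_gemm(hidden_states, W.T)` is expected to match the generic projection `hidden_states @ W.T` (assuming the projection has no bias, as in the DeepSeek-V3 checkpoint). The shapes and tolerances below are illustrative, and the snippet assumes an SM90+ GPU with `sgl_kernel` installed:

```python
import torch
from sgl_kernel import dsv3_fused_a_gemm

num_tokens = 8  # must be <= 16 to take the min-latency path
hidden_states = torch.randn(num_tokens, 7168, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(2112, 7168, dtype=torch.bfloat16, device="cuda")

out_fused = dsv3_fused_a_gemm(hidden_states, weight.T)  # min-latency kernel
out_ref = hidden_states @ weight.T                      # generic linear path (no bias)

# Loose tolerance: bf16 accumulation differences between kernels are expected.
torch.testing.assert_close(out_fused, out_ref, rtol=2e-2, atol=2e-2)
```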