Unverified Commit d8ab6011 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Overlap qk norm with two streams (#5977)

parent 6579cd7d
......@@ -421,6 +421,7 @@ class DeepseekV2AttentionMLA(nn.Module):
reduce_results: bool = True,
layer_id: int = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.layer_id = layer_id
......@@ -543,6 +544,8 @@ class DeepseekV2AttentionMLA(nn.Module):
prefix=add_prefix("attn_mha", prefix),
)
self.alt_stream = alt_stream
self.w_kc = None
self.w_vc = None
self.w_scale = None
......@@ -706,14 +709,32 @@ class DeepseekV2AttentionMLA(nn.Module):
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q = self.q_a_layernorm(q)
k_nope = latent_cache[..., : self.kv_lora_rank]
# overlap qk norm
if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q = self.q_a_layernorm(q)
with torch.cuda.stream(self.alt_stream):
k_nope = self.kv_a_layernorm(k_nope)
current_stream.wait_stream(self.alt_stream)
else:
q = self.q_a_layernorm(q)
k_nope = self.kv_a_layernorm(k_nope)
k_nope = k_nope.unsqueeze(1)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
k_nope = latent_cache[..., : self.kv_lora_rank]
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
......@@ -750,11 +771,6 @@ class DeepseekV2AttentionMLA(nn.Module):
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_nope_out = q_nope_out.transpose(0, 1)
k_nope = latent_cache[..., : self.kv_lora_rank]
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
if self.attention_backend == "fa3":
......@@ -1104,6 +1120,7 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
is_nextn: bool = False,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -1133,6 +1150,7 @@ class DeepseekV2DecoderLayer(nn.Module):
layer_id=layer_id,
reduce_results=False,
prefix=add_prefix("self_attn", prefix),
alt_stream=alt_stream,
)
self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
......@@ -1376,6 +1394,7 @@ class DeepseekV2Model(nn.Module):
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
)
self.alt_stream = torch.cuda.Stream()
self.layers = nn.ModuleList(
[
DeepseekV2DecoderLayer(
......@@ -1383,6 +1402,7 @@ class DeepseekV2Model(nn.Module):
layer_id,
quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_id}", prefix),
alt_stream=self.alt_stream,
)
for layer_id in range(config.num_hidden_layers)
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment