Unverified commit d8ab6011, authored by Ke Bao, committed by GitHub

Overlap qk norm with two streams (#5977)

parent 6579cd7d
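
The diff below threads an alternate CUDA stream (alt_stream) from DeepseekV2Model down to DeepseekV2AttentionMLA so that, while a CUDA graph is being captured, q_a_layernorm runs on the current stream and kv_a_layernorm runs concurrently on the alternate stream, with stream waits on both sides of the fork to preserve ordering. A minimal, self-contained sketch of that fork/join pattern (tensor shapes and module names are illustrative, not the PR's code):

import torch

# Minimal sketch of the two-stream overlap (hypothetical shapes and modules,
# not the PR's code): two independent layernorms are launched on two streams
# so their kernels can execute concurrently on the GPU.
q_norm = torch.nn.LayerNorm(1536, device="cuda")
k_norm = torch.nn.LayerNorm(512, device="cuda")
alt_stream = torch.cuda.Stream()

def overlapped_norms(q: torch.Tensor, k: torch.Tensor):
    current_stream = torch.cuda.current_stream()
    # Fork: the side stream must wait for the inputs, which were produced on
    # the current stream (the fused qkv_a projection in the real model).
    alt_stream.wait_stream(current_stream)
    q = q_norm(q)  # runs on the current stream
    with torch.cuda.stream(alt_stream):
        k = k_norm(k)  # runs concurrently on the side stream
    # Join: later work on the current stream must wait for the side stream.
    current_stream.wait_stream(alt_stream)
    return q, k

q_out, k_out = overlapped_norms(
    torch.randn(8, 1536, device="cuda"), torch.randn(8, 512, device="cuda")
)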
@@ -421,6 +421,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -543,6 +544,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             prefix=add_prefix("attn_mha", prefix),
         )
 
+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,14 +709,32 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+            k_nope = k_nope.unsqueeze(1)
+
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
 
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
@@ -750,11 +771,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
             q_nope_out = q_nope_out.transpose(0, 1)
 
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
         if self.attention_backend == "fa3":
@@ -1104,6 +1120,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1133,6 +1150,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             layer_id=layer_id,
             reduce_results=False,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
 
         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1376,6 +1394,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
@@ -1383,6 +1402,7 @@ class DeepseekV2Model(nn.Module):
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
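
Note that the overlapped branch is gated on torch.cuda.is_current_stream_capturing(), so the fork/join is only taken while a CUDA graph is being captured; it is then recorded into the graph once and replayed with no per-step stream management, while eager (non-captured) execution keeps the original sequential order. A hedged, self-contained sketch of that capture path, using the same pattern with illustrative shapes (not the PR's code):

import torch

# Illustrative sketch (not the PR's code): the two-norm fork/join is captured
# inside a CUDA graph, which is the case the diff enables via
# torch.cuda.is_current_stream_capturing().
q_norm = torch.nn.LayerNorm(1536, device="cuda")
k_norm = torch.nn.LayerNorm(512, device="cuda")
alt_stream = torch.cuda.Stream()

def norms(q, k):
    if alt_stream is not None and torch.cuda.is_current_stream_capturing():
        current_stream = torch.cuda.current_stream()
        alt_stream.wait_stream(current_stream)
        q = q_norm(q)
        with torch.cuda.stream(alt_stream):
            k = k_norm(k)
        current_stream.wait_stream(alt_stream)
    else:  # eager path: unchanged sequential behavior
        q, k = q_norm(q), k_norm(k)
    return q, k

# Static buffers that the captured graph reads on every replay.
q_static = torch.randn(8, 1536, device="cuda")
k_static = torch.randn(8, 512, device="cuda")

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        norms(q_static, k_static)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):  # is_current_stream_capturing() is True in here
    q_out, k_out = norms(q_static, k_static)

graph.replay()  # re-runs both layernorms, overlapped, on q_static / k_static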