Unverified Commit 6ba98d43 authored by jomitchellnv, committed by GitHub

fix: fixes multi head attention for context parallel: rotary embedding to use padded cu_seq_lens (#2077)

fix: fixes MHA to use padded cu_seq_lens during CP
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
parent c654e4fe
@@ -907,12 +907,19 @@ class MultiheadAttention(torch.nn.Module):
                 q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                 k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
 
+            if pad_between_seqs:
+                rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded
+                rotary_pos_cu_seq_lens_kv = cu_seqlens_kv_padded
+            else:
+                rotary_pos_cu_seq_lens_q = cu_seqlens_q
+                rotary_pos_cu_seq_lens_kv = cu_seqlens_kv
+
             query_layer = apply_rotary_pos_emb(
                 query_layer,
                 q_pos_emb,
                 self.qkv_format,
                 fused=True,
-                cu_seqlens=cu_seqlens_q,
+                cu_seqlens=rotary_pos_cu_seq_lens_q,
                 cp_size=self.cp_size,
                 cp_rank=self.cp_rank,
                 interleaved=self.rotary_pos_interleaved,
@@ -922,7 +929,7 @@ class MultiheadAttention(torch.nn.Module):
                 k_pos_emb,
                 self.qkv_format,
                 fused=True,
-                cu_seqlens=cu_seqlens_kv,
+                cu_seqlens=rotary_pos_cu_seq_lens_kv,
                 cp_size=self.cp_size,
                 cp_rank=self.cp_rank,
                 interleaved=self.rotary_pos_interleaved,
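The distinction the fix hinges on: when sequences are packed with padding between them (`pad_between_seqs`), `cu_seqlens` counts only real tokens, while `cu_seqlens_padded` gives the physical offsets of sequence boundaries in the packed buffer. The fused rotary kernel locates each sequence (and each context-parallel rank's slice of it) by these offsets, so it must receive the padded variant. Below is a minimal sketch of the two offset vectors; the tensor names and the pad-to-a-multiple-of-`2 * cp_size` convention are illustrative assumptions, not TransformerEngine internals.

```python
import torch

# Hypothetical example: three packed sequences, each padded to a multiple
# of 2 * cp_size (context-parallel load balancing typically splits every
# sequence into 2 * cp_size chunks).
seq_lens = torch.tensor([5, 3, 6])   # real tokens per sequence
cp_size = 2
chunk = 2 * cp_size
padded_lens = ((seq_lens + chunk - 1) // chunk) * chunk       # [8, 4, 8]

zero = torch.zeros(1, dtype=torch.long)
cu_seqlens = torch.cat([zero, seq_lens.cumsum(0)])            # [0, 5, 8, 14]
cu_seqlens_padded = torch.cat([zero, padded_lens.cumsum(0)])  # [0, 8, 12, 20]

# In the packed buffer [seq0 | pad | seq1 | pad | seq2 | pad], sequence i
# physically starts at cu_seqlens_padded[i]. Handing the rotary kernel the
# unpadded offsets would place every boundary after the first sequence too
# early, shifting the rotary phases of all later sequences by the
# accumulated padding.
print(cu_seqlens.tolist())           # boundaries by real-token count
print(cu_seqlens_padded.tolist())    # boundaries in the physical buffer
```

Mirroring the patch, selecting between the two vectors before the `apply_rotary_pos_emb` calls keeps the non-padded path unchanged: without `pad_between_seqs` the two coincide, so the `else` branch preserves the previous behavior.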