Unverified commit a326e351, authored by Marks101, committed by GitHub

[PyTorch] Fix issues with cross attention (#1069)


Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent cc329b79
@@ -5778,7 +5778,8 @@ class DotProductAttention(TransformerEngineBaseModule):
 assert (
     attention_mask is not None
 ), "Please provide attention_mask for padding!"
-if max_seqlen_q == max_seqlen_kv:
+if self.attention_type == "self":
+    assert max_seqlen_q == max_seqlen_kv
     cu_seqlens_q = get_cu_seqlens(attention_mask)
     cu_seqlens_kv = cu_seqlens_q
 else:
......
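For context on this first hunk: with a padding mask, DotProductAttention now requires matching query and key/value lengths only for self attention, so cross attention can derive separate cumulative sequence lengths for the two sides. The sketch below illustrates the idea; the mask layout (True marking padded positions) and the behaviour of the local get_cu_seqlens helper are assumptions based on the diff, not the library's exact implementation.

```python
import torch

def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor:
    # mask: [batch, 1, 1, seqlen] bool, True marking padded positions (assumption).
    # Count valid tokens per sequence, then build cumulative offsets.
    seqlens = (~mask).squeeze(1).squeeze(1).sum(dim=1, dtype=torch.int32)
    cu_seqlens = torch.zeros(mask.size(0) + 1, dtype=torch.int32, device=mask.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    return cu_seqlens

# Self attention can share one set of offsets; cross attention needs separate
# offsets because query and key/value sequence lengths generally differ.
q_mask = torch.zeros(2, 1, 1, 8, dtype=torch.bool)    # decoder (query) side
kv_mask = torch.zeros(2, 1, 1, 12, dtype=torch.bool)  # encoder (key/value) side
kv_mask[1, ..., 10:] = True                           # pad the tail of sample 1
cu_seqlens_q = get_cu_seqlens(q_mask)                 # tensor([0, 8, 16])
cu_seqlens_kv = get_cu_seqlens(kv_mask)               # tensor([0, 12, 22])
```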
@@ -652,7 +652,7 @@ class TransformerLayer(torch.nn.Module):
     hidden_states,
     attention_mask=attention_mask,
     attn_mask_type=self_attn_mask_type,
-    window_size=enc_dec_window_size,
+    window_size=window_size,
     inference_params=inference_params,
     is_first_microbatch=is_first_microbatch,
     checkpoint_core_attention=checkpoint_core_attention,
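This hunk fixes an argument mix-up in TransformerLayer.forward: the decoder's self attention was being handed the cross-attention sliding window. A small sketch of why the two settings must stay separate; the values are illustrative only.

```python
# window_size governs self attention over the decoder's own sequence;
# enc_dec_window_size governs cross attention over the encoder output.
window_size = (128, 0)          # e.g. causal sliding window for self attention
enc_dec_window_size = (-1, -1)  # e.g. unrestricted cross attention

# Before the fix, self attention received enc_dec_window_size, so a sliding
# window intended for the decoder sequence was silently misapplied.
```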
@@ -679,6 +679,8 @@ class TransformerLayer(torch.nn.Module):
 inter_attention_outputs = self.inter_attention(
     hidden_states,
     attention_mask=enc_dec_attn_mask,
+    attn_mask_type=enc_dec_attn_mask_type,
+    window_size=enc_dec_window_size,
     encoder_output=encoder_output,
     is_first_microbatch=is_first_microbatch,
     checkpoint_core_attention=checkpoint_core_attention,
......
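Taken together, the two TransformerLayer hunks mean the cross-attention branch now receives its own mask type and window. Below is a hedged usage sketch of the configuration this enables; the hyperparameters are placeholders, and passing these options as forward() arguments (as well as the padding-mask layout) is an assumption that may vary across Transformer Engine versions.

```python
import torch
import transformer_engine.pytorch as te

# Placeholder decoder layer; sizes are illustrative, not from the commit.
layer = te.TransformerLayer(
    hidden_size=512,
    ffn_hidden_size=2048,
    num_attention_heads=8,
    layer_type="decoder",
)

seq_q, seq_kv, batch, hidden = 16, 24, 2, 512
hidden_states = torch.randn(seq_q, batch, hidden, device="cuda")
encoder_output = torch.randn(seq_kv, batch, hidden, device="cuda")
# True marks padded key/value positions on the encoder side (assumed layout).
enc_dec_attn_mask = torch.zeros(batch, 1, 1, seq_kv, dtype=torch.bool, device="cuda")

out = layer(
    hidden_states,
    encoder_output=encoder_output,
    enc_dec_attn_mask=enc_dec_attn_mask,
    enc_dec_attn_mask_type="padding",  # now forwarded to inter_attention
    enc_dec_window_size=(-1, -1),      # now forwarded to inter_attention
)
```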