Unverified Commit a326e351 authored by Marks101's avatar Marks101 Committed by GitHub
Browse files

[PyTorch] Fix issues with cross attention (#1069)


Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent cc329b79
...@@ -5778,7 +5778,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5778,7 +5778,8 @@ class DotProductAttention(TransformerEngineBaseModule):
assert ( assert (
attention_mask is not None attention_mask is not None
), "Please provide attention_mask for padding!" ), "Please provide attention_mask for padding!"
if max_seqlen_q == max_seqlen_kv: if self.attention_type == "self":
assert max_seqlen_q == max_seqlen_kv
cu_seqlens_q = get_cu_seqlens(attention_mask) cu_seqlens_q = get_cu_seqlens(attention_mask)
cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv = cu_seqlens_q
else: else:
......
...@@ -652,7 +652,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -652,7 +652,7 @@ class TransformerLayer(torch.nn.Module):
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
attn_mask_type=self_attn_mask_type, attn_mask_type=self_attn_mask_type,
window_size=enc_dec_window_size, window_size=window_size,
inference_params=inference_params, inference_params=inference_params,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
...@@ -679,6 +679,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -679,6 +679,8 @@ class TransformerLayer(torch.nn.Module):
inter_attention_outputs = self.inter_attention( inter_attention_outputs = self.inter_attention(
hidden_states, hidden_states,
attention_mask=enc_dec_attn_mask, attention_mask=enc_dec_attn_mask,
attn_mask_type=enc_dec_attn_mask_type,
window_size=enc_dec_window_size,
encoder_output=encoder_output, encoder_output=encoder_output,
is_first_microbatch=is_first_microbatch, is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention, checkpoint_core_attention=checkpoint_core_attention,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment