Unverified commit bd0873af authored by Marks101, committed by GitHub

[PyTorch] fix attn_mask_type for inter_attention (#565)


Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
parent acd811aa
```diff
@@ -619,7 +619,6 @@ class TransformerLayer(torch.nn.Module):
         inter_attention_outputs = self.inter_attention(
             hidden_states,
             attention_mask=enc_dec_attn_mask,
-            attn_mask_type=self_attn_mask_type,
             encoder_output=encoder_output,
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
```
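To illustrate why forwarding `self_attn_mask_type` into `inter_attention` was a bug: self-attention in a decoder is typically causal, but cross-attention over encoder output should see the whole encoder sequence. A minimal sketch, assuming a "causal" vs "no_mask" distinction like the one Transformer Engine uses; `build_mask` and its shapes are illustrative, not the library's API:

```python
def build_mask(mask_type, q_len, kv_len):
    """Return a q_len x kv_len boolean mask; True means 'may attend'.

    Hypothetical helper for illustration only.
    """
    if mask_type == "causal":
        # Position q may only attend to key/value positions <= q.
        return [[kv <= q for kv in range(kv_len)] for q in range(q_len)]
    if mask_type == "no_mask":
        return [[True] * kv_len for _ in range(q_len)]
    raise ValueError(f"unknown mask type: {mask_type!r}")

# Decoder of length 2 cross-attending over an encoder sequence of length 4.
causal = build_mask("causal", 2, 4)   # wrong type for cross-attention
full = build_mask("no_mask", 2, 4)    # what cross-attention should use

# With the causal mask, the first decoder position is blocked from
# most of the encoder output, silently dropping encoder context:
assert causal[0] == [True, False, False, False]
assert full[0] == [True, True, True, True]
```

Removing the `attn_mask_type=self_attn_mask_type` argument lets `inter_attention` fall back to its own (non-causal) default instead of inheriting the decoder's causal setting.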