Unverified Commit 33b7532e authored by Wei Fang, committed by GitHub

Fix longformer attention mask type casting when using apex (#4574)

* Fix longformer attention mask casting when using apex

* remove extra type casting
parent 56ee2560
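
For context, a minimal sketch of the failure this fix targets (shapes and dtypes are invented for illustration, not taken from the commit): with apex fp16 training the value tensors come out in half precision, while attention probabilities computed from a float32 attention mask stay float32, and `torch.matmul` rejects mixed-dtype operands. Once the attention mask is cast to the model's dtype up front, the per-call `.type_as(...)` casts removed below are redundant.

```python
import torch

# Hypothetical shapes: (bsz, seq_len, heads, extra) and (bsz, extra, heads, head_dim).
probs = torch.randn(2, 8, 4, 3)                    # float32, e.g. from a masked softmax
v = torch.randn(2, 3, 4, 16, dtype=torch.float16)  # float16, e.g. after apex casts the model

try:
    torch.matmul(probs.transpose(1, 2), v.transpose(1, 2))
except RuntimeError as err:
    print(err)  # mixed float/half operands are rejected

# Casting once so both operands share a dtype resolves it; with the mask cast
# to the model dtype up front, no per-call cast is needed.
attn = torch.matmul(probs.transpose(1, 2), v.to(probs.dtype).transpose(1, 2)).transpose(1, 2)
print(attn.shape)  # torch.Size([2, 8, 4, 16])
```
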
@@ -348,9 +348,7 @@ class LongformerSelfAttention(nn.Module):
             selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
             # use `matmul` because `einsum` crashes sometimes with fp16
             # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
-            attn = torch.matmul(
-                selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)
-            ).transpose(1, 2)
+            attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2)
             attn_probs = attn_probs.narrow(
                 -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch
             ).contiguous()
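
A quick sanity check, with hypothetical sizes, that the single-line `matmul` kept above is equivalent to the commented-out `einsum('blhs,bshd->blhd')`:

```python
import torch

# Hypothetical sizes: b=batch, l=seq_len, h=heads, s=extra indices, d=head_dim.
b, l, h, s, d = 2, 8, 4, 3, 16
selected_attn_probs = torch.randn(b, l, h, s)
selected_v = torch.randn(b, s, h, d)

expected = torch.einsum("blhs,bshd->blhd", selected_attn_probs, selected_v)
# matmul route: move heads next to batch, contract over s, move heads back.
attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2)

assert torch.allclose(attn, expected, atol=1e-5)
```
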
@@ -414,7 +412,7 @@ class LongformerSelfAttention(nn.Module):
             ]
             attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
                 len(selection_padding_mask_nonzeros[0]), -1
-            ).type_as(hidden_states)
+            )
         context_layer = attn.transpose(0, 1)
         if self.output_attentions:
...