Unverified Commit 33b7532e authored by Wei Fang, committed by GitHub

Fix longformer attention mask type casting when using apex (#4574)

* Fix longformer attention mask casting when using apex

* remove extra type casting
parent 56ee2560
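For context, a minimal sketch of the dtype mismatch this change works around. It is illustrative only and not code from this PR; the tensor names are made up for the example. The idea, assuming apex AMP (O1/O2) mixed precision: activations run in fp16 while a user-supplied attention mask can stay fp32, so casting the mask once, where it is turned into additive scores, keeps the attention math in a single dtype and makes the per-op `.type_as(...)` casts removed below redundant.

```python
import torch

# Hypothetical shapes and names, for illustration only.
hidden_states = torch.randn(1, 8, 16, dtype=torch.float16)  # fp16 under apex AMP
attention_mask = torch.ones(1, 8, dtype=torch.float32)      # user mask often stays fp32

# Cast the mask once, up front, to the dtype of the tensors it will combine with ...
additive_mask = (1.0 - attention_mask) * -10000.0
additive_mask = additive_mask.to(dtype=hidden_states.dtype)

# ... then downstream attention math stays in one dtype and no longer needs
# per-op `.type_as(...)` casts like the ones removed in this diff.
scores = torch.randn(1, 8, 8, dtype=torch.float16)
scores = scores + additive_mask.unsqueeze(1)
```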
@@ -348,9 +348,7 @@ class LongformerSelfAttention(nn.Module):
             selected_v[selection_padding_mask_nonzeros] = v[extra_attention_mask_nonzeros]
             # use `matmul` because `einsum` crashes sometimes with fp16
             # attn = torch.einsum('blhs,bshd->blhd', (selected_attn_probs, selected_v))
-            attn = torch.matmul(
-                selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)
-            ).transpose(1, 2)
+            attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2)).transpose(1, 2)
             attn_probs = attn_probs.narrow(
                 -1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch
             ).contiguous()
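As an aside on the `matmul`-instead-of-`einsum` comment kept in the hunk above: the two formulations compute the same contraction. The check below is not part of the PR and uses arbitrary small shapes; it only demonstrates the equivalence.

```python
import torch

# 'blhs,bshd->blhd': contract over the selected dimension s.
probs = torch.randn(2, 8, 4, 3)    # (batch, seq_len, heads, selected)
values = torch.randn(2, 3, 4, 16)  # (batch, selected, heads, head_dim)

via_einsum = torch.einsum("blhs,bshd->blhd", probs, values)
via_matmul = torch.matmul(probs.transpose(1, 2), values.transpose(1, 2)).transpose(1, 2)

assert torch.allclose(via_einsum, via_matmul, atol=1e-6)
```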
@@ -414,7 +412,7 @@ class LongformerSelfAttention(nn.Module):
             ]
             attn[extra_attention_mask_nonzeros[::-1]] = nonzero_selected_attn.view(
                 len(selection_padding_mask_nonzeros[0]), -1
-            ).type_as(hidden_states)
+            )
         context_layer = attn.transpose(0, 1)
         if self.output_attentions: