"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "8079d98617694b6fc896dedf4c83e7b25dc716cb"
Commit b18a3126 authored by Myle Ott, committed by Facebook Github Bot

Faster masking in MultiheadAttention

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/612

Differential Revision: D15541377

Pulled By: myleott

fbshipit-source-id: 4762516a3b545d03bc81d3660f47827e15466dce
parent c97978a2
```diff
@@ -201,10 +201,10 @@ class MultiheadAttention(nn.Module):
                     attn_weights.float()
                 ).type_as(attn_weights)
             else:
-                attn_weights = attn_weights.float().masked_fill(
+                attn_weights = attn_weights.masked_fill(
                     key_padding_mask.unsqueeze(1).unsqueeze(2),
                     float('-inf'),
-                ).type_as(attn_weights)  # FP16 support: cast to float and back
+                )
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
             attn_weights = utils.softmax(
```
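The change removes the fp32 round-trip when applying the key padding mask: masked_fill now writes float('-inf') directly in the weights' own dtype, which is safe in half precision because -inf is representable in fp16, and it saves two full-tensor casts on the attention hot path. Below is a minimal standalone sketch of the before/after behavior; the sizes are hypothetical illustration values, and only the masked_fill pattern itself comes from the diff:

```python
import torch

# Hypothetical sizes, for illustration only.
bsz, num_heads, tgt_len, src_len = 2, 4, 3, 5

# Attention weights in half precision, as in mixed-precision training.
attn_weights = torch.randn(bsz, num_heads, tgt_len, src_len).half()

# True marks padded source positions; here the second sequence pads its tail.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[1, 3:] = True

# (bsz, 1, 1, src_len) broadcasts across heads and target positions.
mask = key_padding_mask.unsqueeze(1).unsqueeze(2)

# Old pattern: cast to fp32, fill, cast back -- two extra full-tensor copies.
old = attn_weights.float().masked_fill(mask, float('-inf')).type_as(attn_weights)

# New pattern: fill directly in the tensor's own dtype, no extra copies.
new = attn_weights.masked_fill(mask, float('-inf'))

# Same result: -inf is exactly representable in fp16, and the fp16 -> fp32
# -> fp16 round-trip is lossless for values that started out in fp16.
assert torch.equal(old, new)
```

In the surrounding code the masked weights are then reshaped to (bsz * self.num_heads, tgt_len, src_len) and handed to utils.softmax, which is where the normalization happens.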