"vscode:/vscode.git/clone" did not exist on "349a499f339fd0e758ee8bd54c09ee073cfbaf3b"
Commit ee8bcb17 authored by Sergey Edunov, committed by Facebook Github Bot

Fix of MHA for TPUs (#636)

Summary:
Multi-head attention is currently not TPU-friendly; specifically, .data_ptr() is not supported and should not be used. There are also potential correctness issues with the existing code (e.g. data_ptr() can point to the same storage for different tensors). Rather than relying on data_ptr(), we should explicitly set the self_attention or encoder_decoder_attention flag.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/636

Reviewed By: myleott

Differential Revision: D15709898

Pulled By: edunov

fbshipit-source-id: f931713193c51be848a5de20da730ac3a3ce0187
parent 4868c182
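
To make the motivation concrete, here is a minimal, standalone sketch in plain PyTorch (not fairseq code): it shows that .data_ptr() equality cannot reliably identify "the same tensor", and how an explicit constructor flag, mirroring the self_attention / encoder_decoder_attention arguments added below, removes the need for that check. The TinyAttention class and all names in it are hypothetical illustrations only.

import torch
import torch.nn as nn

# (1) Pitfall: two distinct tensor objects can report the same storage
#     pointer, and XLA/TPU tensors do not support .data_ptr() at all.
a = torch.zeros(4, 8)
b = a[:2]                            # a different tensor object ...
assert a.data_ptr() == b.data_ptr()  # ... with the same storage pointer

# (2) Alternative: declare the attention type up front with a flag, so
#     forward() never has to inspect tensor storage.
class TinyAttention(nn.Module):
    def __init__(self, embed_dim, self_attention=False, encoder_decoder_attention=False):
        super().__init__()
        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.scaling = embed_dim ** -0.5

    def forward(self, query, key=None, value=None):
        if self.self_attention:
            # The flag, not a data_ptr() comparison, says that key and
            # value are the query sequence itself.
            key = value = query
        q = self.q_proj(query) * self.scaling
        attn = torch.softmax(q @ self.k_proj(key).transpose(-2, -1), dim=-1)
        return attn @ self.v_proj(value)

x, mem = torch.randn(5, 8), torch.randn(7, 8)
print(TinyAttention(8, self_attention=True)(x).shape)                       # [5, 8]
print(TinyAttention(8, encoder_decoder_attention=True)(x, mem, mem).shape)  # [5, 8]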
@@ -563,7 +563,7 @@ class LightConvDecoderLayer(nn.Module):
         else:
             self.encoder_attn = MultiheadAttention(
                 self.embed_dim, args.decoder_attention_heads,
-                dropout=args.attention_dropout,
+                dropout=args.attention_dropout, encoder_decoder_attention=True
             )
             self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
......
@@ -512,7 +512,7 @@ class TransformerEncoderLayer(nn.Module):
         self.embed_dim = args.encoder_embed_dim
         self.self_attn = MultiheadAttention(
             self.embed_dim, args.encoder_attention_heads,
-            dropout=args.attention_dropout,
+            dropout=args.attention_dropout, self_attention=True
         )
         self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = args.dropout
@@ -608,6 +608,7 @@ class TransformerDecoderLayer(nn.Module):
             dropout=args.attention_dropout,
             add_bias_kv=add_bias_kv,
             add_zero_attn=add_zero_attn,
+            self_attention=True
         )
         self.dropout = args.dropout
         self.activation_fn = utils.get_activation_fn(
@@ -631,7 +632,7 @@ class TransformerDecoderLayer(nn.Module):
         else:
             self.encoder_attn = MultiheadAttention(
                 self.embed_dim, args.decoder_attention_heads,
-                dropout=args.attention_dropout,
+                dropout=args.attention_dropout, encoder_decoder_attention=True
             )
             self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
......
@@ -19,7 +19,9 @@ class MultiheadAttention(nn.Module):
     See "Attention Is All You Need" for more details.
     """
-    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
+    def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+                 add_bias_kv=False, add_zero_attn=False, self_attention=False,
+                 encoder_decoder_attention=False):
         super().__init__()
         self.embed_dim = embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
@@ -32,6 +34,13 @@ class MultiheadAttention(nn.Module):
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
         self.scaling = self.head_dim ** -0.5
+        self.self_attention = self_attention
+        self.encoder_decoder_attention = encoder_decoder_attention
+
+        assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+                                                             'value to be of the same size'
+
         if self.qkv_same_dim:
             self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
         else:
@@ -82,16 +91,12 @@ class MultiheadAttention(nn.Module):
                 need_weights=True, static_kv=False, attn_mask=None):
         """Input shape: Time x Batch x Channel

-        Self-attention can be implemented by passing in the same arguments for
-        query, key and value. Timesteps can be masked by supplying a T x T mask in the
+        Timesteps can be masked by supplying a T x T mask in the
         `attn_mask` argument. Padding elements can be excluded from
         the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
         batch x src_len, where padding elements are indicated by 1s.
         """
-        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
-        kv_same = key.data_ptr() == value.data_ptr()
-
         tgt_len, bsz, embed_dim = query.size()
         assert embed_dim == self.embed_dim
         assert list(query.size()) == [tgt_len, bsz, embed_dim]
@@ -102,15 +107,15 @@ class MultiheadAttention(nn.Module):
                 # previous time steps are cached - no need to recompute
                 # key and value if they are static
                 if static_kv:
-                    assert kv_same and not qkv_same
+                    assert self.encoder_decoder_attention and not self.self_attention
                     key = value = None
         else:
             saved_state = None

-        if qkv_same:
+        if self.self_attention:
             # self-attention
             q, k, v = self.in_proj_qkv(query)
-        elif kv_same:
+        elif self.encoder_decoder_attention:
             # encoder-decoder attention
             q = self.in_proj_q(query)
             if key is None:
......
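
For context, here is a hedged usage sketch of the updated module. It assumes fairseq at this revision is installed and that MultiheadAttention is importable from fairseq.modules; the constructor flags, forward signature, and the Time x Batch x Channel / key_padding_mask conventions follow the diff above, while the dimensions and variable names are illustrative only.

import torch
from fairseq.modules import MultiheadAttention

embed_dim, num_heads = 16, 4
tgt_len, src_len, bsz = 5, 7, 2
x = torch.randn(tgt_len, bsz, embed_dim)  # queries, shaped Time x Batch x Channel

# Self-attention: the attention type is declared at construction time,
# so forward() no longer compares data_ptr() values.
self_attn = MultiheadAttention(embed_dim, num_heads, self_attention=True)
out, attn_weights = self_attn(query=x, key=x, value=x)

# Encoder-decoder attention: keys/values come from the encoder output;
# source padding is excluded with a batch x src_len ByteTensor of 1s.
enc_dec_attn = MultiheadAttention(embed_dim, num_heads, encoder_decoder_attention=True)
memory = torch.randn(src_len, bsz, embed_dim)
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.uint8)
out, attn_weights = enc_dec_attn(query=x, key=memory, value=memory,
                                 key_padding_mask=key_padding_mask)
print(out.shape)  # torch.Size([5, 2, 16])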
@@ -50,6 +50,7 @@ class TransformerSentenceEncoderLayer(nn.Module):
             dropout=attention_dropout,
             add_bias_kv=add_bias_kv,
             add_zero_attn=add_zero_attn,
+            self_attention=True
         )

         # layer norm associated with the self attention layer
......