Commit 461a366d authored by Myle Ott, committed by Facebook GitHub Bot

Support different embed dim in Transformer decoder

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/811

Differential Revision: D15880880

Pulled By: myleott

fbshipit-source-id: c47e09a90c945aca82b26edb4a8af93e063d5b00
parent 00ac823e
@@ -639,8 +639,12 @@ class TransformerDecoderLayer(nn.Module):
             self.encoder_attn_layer_norm = None
         else:
             self.encoder_attn = MultiheadAttention(
-                self.embed_dim, args.decoder_attention_heads,
-                dropout=args.attention_dropout, encoder_decoder_attention=True
+                self.embed_dim,
+                args.decoder_attention_heads,
+                kdim=getattr(args, 'encoder_embed_dim', None),
+                vdim=getattr(args, 'encoder_embed_dim', None),
+                dropout=args.attention_dropout,
+                encoder_decoder_attention=True,
             )
             self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)
...
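For context: passing kdim/vdim lets the cross-attention key and value projections accept encoder states whose dimensionality differs from the decoder's embed dim, while queries and outputs stay in the decoder dimension. Below is a minimal sketch of the same idea using torch.nn.MultiheadAttention, which exposes equivalent kdim/vdim arguments; the dimensions and tensor shapes are made up for illustration and are not part of this commit.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: decoder (query) dim differs from encoder (key/value) dim.
decoder_embed_dim = 512
encoder_embed_dim = 768
num_heads = 8

# kdim/vdim tell the attention module to project keys and values
# from encoder_embed_dim down to the decoder's embed dim internally.
cross_attn = nn.MultiheadAttention(
    decoder_embed_dim,
    num_heads,
    kdim=encoder_embed_dim,
    vdim=encoder_embed_dim,
)

tgt_len, src_len, batch = 10, 20, 2
query = torch.randn(tgt_len, batch, decoder_embed_dim)  # decoder states
key = torch.randn(src_len, batch, encoder_embed_dim)    # encoder output
value = key

out, attn_weights = cross_attn(query, key, value)
# Output stays in the decoder's embedding dimension.
assert out.shape == (tgt_len, batch, decoder_embed_dim)
```

Without kdim/vdim, the module assumes keys and values share embed_dim with the queries, which is why a decoder with a different embed dim previously could not attend over the encoder output.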