Commit c5378602 authored by Sergey Edunov, committed by Myle Ott

Share input/output embed

parent 907ca927
@@ -135,7 +135,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
     """Convolutional decoder"""
     def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
-                 attention=True, dropout=0.1):
+                 attention=True, dropout=0.1, share_embed=False):
         super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
         self.dropout = dropout
@@ -169,7 +169,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
                           if attention[i] else None)
             in_channels = out_channels
         self.fc2 = Linear(in_channels, out_embed_dim)
-        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
+        if share_embed:
+            assert out_embed_dim == embed_dim, \
+                "Shared embed weights imply same dimensions: " \
+                "out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
+            self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
+            self.fc3.weight = self.embed_tokens.weight
+        else:
+            self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

     def forward(self, input_tokens, encoder_out):
         # split and transpose encoder outputs
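
The tying trick above works because nn.Embedding stores its weight as (num_embeddings, embed_dim) while nn.Linear stores its weight as (out_features, in_features); when out_embed_dim == embed_dim the two matrices have identical shape, so a single parameter tensor can back both modules. A minimal self-contained sketch of the pattern (dimensions are illustrative, not from the commit):

    import torch
    import torch.nn as nn

    num_embeddings, embed_dim = 1000, 512

    embed_tokens = nn.Embedding(num_embeddings, embed_dim)  # weight: (1000, 512)
    fc3 = nn.Linear(embed_dim, num_embeddings)              # weight: (1000, 512), stored as (out, in)

    # Tie the parameters: both modules now share one tensor, so gradients from
    # the output projection and the input lookup accumulate into the same weight.
    fc3.weight = embed_tokens.weight
    assert fc3.weight.data_ptr() == embed_tokens.weight.data_ptr()

    tokens = torch.randint(0, num_embeddings, (2, 7))       # (batch, seq_len)
    logits = fc3(embed_tokens(tokens))                      # shape: (2, 7, 1000)
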
@@ -372,6 +379,7 @@ def parse_arch(args):
     args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
     args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
     args.decoder_attention = getattr(args, 'decoder_attention', 'True')
+    args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
     return args
@@ -391,5 +399,6 @@ def build_model(args, src_dict, dst_dict):
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
         max_positions=args.max_target_positions,
+        share_embed=args.share_input_output_embed
     )
     return FConvModel(encoder, decoder)
@@ -173,4 +173,8 @@ def add_model_args(parser):
                        help='dropout probability')
     group.add_argument('--label-smoothing', default=0, type=float, metavar='D',
                        help='epsilon for label smoothing, 0 means no label smoothing')
+    group.add_argument('--share-input-output-embed', action='store_true',
+                       help='share input and output embeddings; requires '
+                            '--decoder-out-embed-dim and --decoder-embed-dim '
+                            'to be equal')
     return group
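
A usage sketch of the new flag in isolation (a stand-alone argparse parser, not fairseq's actual options plumbing); on a real run it would accompany equal --decoder-embed-dim and --decoder-out-embed-dim values:

    import argparse

    # Stand-alone reproduction of the store_true flag added above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--share-input-output-embed', action='store_true',
                        help='share input and output embeddings')

    # Absent by default; passing the flag flips it to True
    # (argparse maps dashes in option names to underscores).
    assert parser.parse_args([]).share_input_output_embed is False
    assert parser.parse_args(['--share-input-output-embed']).share_input_output_embed is True
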
@@ -173,6 +173,8 @@ def _upgrade_args(args):
     if not hasattr(args, 'max_source_positions'):
         args.max_source_positions = args.max_positions
         args.max_target_positions = args.max_positions
+    if not hasattr(args, 'share_input_output_embed'):
+        args.share_input_output_embed = False
     return args
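
This default exists for backward compatibility: checkpoints saved before this commit pickled an args namespace without share_input_output_embed, and model construction would otherwise fail on attribute access. A minimal sketch of the upgrade pattern (the namespace contents are illustrative):

    from argparse import Namespace

    def _upgrade_args(args):
        # Older checkpoints predate the flag; fall back to the conservative default.
        if not hasattr(args, 'share_input_output_embed'):
            args.share_input_output_embed = False
        return args

    old_args = Namespace(decoder_embed_dim=512)  # as if loaded from an old checkpoint
    assert _upgrade_args(old_args).share_input_output_embed is False
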