Commit f472d141 authored by Stephen Roller, committed by Myle Ott

Support tied embeddings in LSTM encoder/decoder

parent a7d0bd0e
@@ -59,6 +59,12 @@ class LSTMModel(FairseqModel):
                             help='dropout probability for decoder input embedding')
         parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
                             help='dropout probability for decoder output')
+        parser.add_argument('--share-decoder-input-output-embed', default=False,
+                            action='store_true',
+                            help='share decoder input and output embeddings')
+        parser.add_argument('--share-all-embeddings', default=False, action='store_true',
+                            help='share encoder, decoder and output embeddings'
+                                 ' (requires shared dictionary and embed dim)')

     @classmethod
     def build_model(cls, args, task):
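For reference, both new options are plain boolean switches: they default to False and flip to True when passed, with argparse mapping the dashed names to underscored attributes. A minimal sketch (not part of the commit) of how they parse:

```python
import argparse

# Mirror the two flags added above; argparse converts dashes to underscores,
# so they surface as args.share_decoder_input_output_embed and
# args.share_all_embeddings.
parser = argparse.ArgumentParser()
parser.add_argument('--share-decoder-input-output-embed', default=False,
                    action='store_true')
parser.add_argument('--share-all-embeddings', default=False, action='store_true')

args = parser.parse_args(['--share-all-embeddings'])
assert args.share_all_embeddings is True
assert args.share_decoder_input_output_embed is False
```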
@@ -78,10 +84,39 @@ class LSTMModel(FairseqModel):
         if args.encoder_embed_path:
             pretrained_encoder_embed = load_pretrained_embedding_from_file(
                 args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
-        pretrained_decoder_embed = None
-        if args.decoder_embed_path:
-            pretrained_decoder_embed = load_pretrained_embedding_from_file(
-                args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim)
+
+        if args.share_all_embeddings:
+            # double-check that all parameter combinations are valid
+            if task.source_dictionary != task.target_dictionary:
+                raise RuntimeError('--share-all-embeddings requires a joint dictionary')
+            if args.decoder_embed_path and (
+                    args.decoder_embed_path != args.encoder_embed_path):
+                raise RuntimeError(
+                    '--share-all-embeddings not compatible with --decoder-embed-path'
+                )
+            if args.encoder_embed_dim != args.decoder_embed_dim:
+                raise RuntimeError(
+                    '--share-all-embeddings requires --encoder-embed-dim to '
+                    'match --decoder-embed-dim'
+                )
+            pretrained_decoder_embed = pretrained_encoder_embed
+            args.share_decoder_input_output_embed = True
+        else:
+            # separate decoder input embeddings
+            pretrained_decoder_embed = None
+            if args.decoder_embed_path:
+                pretrained_decoder_embed = load_pretrained_embedding_from_file(
+                    args.decoder_embed_path,
+                    task.target_dictionary,
+                    args.decoder_embed_dim
+                )
+
+        # one last double-check of parameter combinations
+        if args.share_decoder_input_output_embed and (
+                args.decoder_embed_dim != args.decoder_out_embed_dim):
+            raise RuntimeError(
+                '--share-decoder-input-output-embed requires '
+                '--decoder-embed-dim to match --decoder-out-embed-dim'
+            )
+
         encoder = LSTMEncoder(
             dictionary=task.source_dictionary,
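The effect of the branch above, in isolation: with --share-all-embeddings, a single Embedding instance serves as the encoder input, decoder input, and (because share_decoder_input_output_embed is forced to True) decoder output matrix. A minimal sketch, not part of the commit, with a toy vocabulary standing in for the joint dictionary:

```python
import torch.nn as nn

# A joint source/target vocabulary and matching embed dims are exactly what
# the RuntimeError checks above enforce.
joint_vocab = ['<pad>', 'the', 'cat']
embed_dim = 512

shared_embed = nn.Embedding(len(joint_vocab), embed_dim, padding_idx=0)
pretrained_encoder_embed = shared_embed
pretrained_decoder_embed = pretrained_encoder_embed  # same object, as above

# Tying means one parameter tensor, updated by gradients from both sides.
assert pretrained_decoder_embed.weight is pretrained_encoder_embed.weight
```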
@@ -105,6 +140,7 @@ class LSTMModel(FairseqModel):
             encoder_embed_dim=args.encoder_embed_dim,
             encoder_output_units=encoder.output_units,
             pretrained_embed=pretrained_decoder_embed,
+            share_input_output_embed=args.share_decoder_input_output_embed,
         )
         return cls(encoder, decoder)
@@ -251,11 +287,13 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
         num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
         encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None,
+        share_input_output_embed=False,
     ):
         super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
         self.hidden_size = hidden_size
+        self.share_input_output_embed = share_input_output_embed
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
@@ -279,7 +317,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.attention = AttentionLayer(encoder_output_units, hidden_size) if attention else None
         if hidden_size != out_embed_dim:
             self.additional_fc = Linear(hidden_size, out_embed_dim)
-        self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
+        if not self.share_input_output_embed:
+            self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

     def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
         encoder_out = encoder_out_dict['encoder_out']
@@ -358,7 +397,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         if hasattr(self, 'additional_fc'):
             x = self.additional_fc(x)
             x = F.dropout(x, p=self.dropout_out, training=self.training)
-        x = self.fc_out(x)
+        if self.share_input_output_embed:
+            x = F.linear(x, self.embed_tokens.weight)
+        else:
+            x = self.fc_out(x)
         return x, attn_scores
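This is where the tying pays off at the output: instead of a separate fc_out projection, the decoder reuses the (num_embeddings x embed_dim) input-embedding matrix, and F.linear(x, W) computes x @ W.T, mapping hidden states to vocabulary logits. It also explains the earlier check that --decoder-embed-dim match --decoder-out-embed-dim: the shared matrix has to fit both roles. A standalone sketch with toy sizes, not from the commit:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, embed_dim = 1000, 512
embed_tokens = nn.Embedding(vocab_size, embed_dim)

x = torch.randn(8, 20, embed_dim)           # decoder outputs (batch, time, dim)
logits = F.linear(x, embed_tokens.weight)   # weight is (vocab, dim) -> x @ W.T
assert logits.shape == (8, 20, vocab_size)
```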
@@ -431,6 +473,8 @@ def base_architecture(args):
     args.decoder_attention = getattr(args, 'decoder_attention', '1')
     args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
     args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
+    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
+    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)


 @register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
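The getattr defaults keep older or minimal configurations working: an args namespace that predates these flags gets them backfilled as False rather than raising AttributeError. A short illustration (hypothetical bare namespace, not from the commit):

```python
from argparse import Namespace

args = Namespace()  # e.g. a config saved before this commit
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
assert args.share_all_embeddings is False
```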