Commit 5852d3a0 authored by Li Zhao, committed by Myle Ott

Add adaptive softmax support to the LSTM model

parent 343819f9
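
This change lets the LSTM decoder swap its full vocabulary projection for an adaptive softmax (the factored softmax of Grave et al., 2017) whenever training uses the adaptive_loss criterion. A hypothetical training invocation under these assumptions (the dataset path is a placeholder):

    python train.py data-bin/iwslt14.tokenized.de-en \
        --arch lstm --criterion adaptive_loss \
        --adaptive-softmax-cutoff '10000,50000,200000'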
@@ -10,7 +10,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from fairseq import options, utils
+from fairseq.modules import AdaptiveSoftmax
 from . import (
     FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model,
     register_model_architecture,
@@ -49,6 +49,9 @@ class LSTMModel(FairseqModel):
                             help='decoder output embedding dimension')
         parser.add_argument('--decoder-attention', type=str, metavar='BOOL',
                             help='decoder attention')
+        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
+                            help='comma separated list of adaptive softmax cutoff points. '
+                                 'Must be used with adaptive_loss criterion')
         # Granular dropout settings (if not specified these default to --dropout)
         parser.add_argument('--encoder-dropout-in', type=float, metavar='D',
@@ -145,6 +148,10 @@ class LSTMModel(FairseqModel):
             encoder_output_units=encoder.output_units,
             pretrained_embed=pretrained_decoder_embed,
             share_input_output_embed=args.share_decoder_input_output_embed,
+            adaptive_softmax_cutoff=(
+                options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
+                if args.criterion == 'adaptive_loss' else None
+            ),
         )
         return cls(encoder, decoder)
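
For context, options.eval_str_list turns the flag's comma-separated string into a list of ints, and the cutoff is passed through only when the adaptive_loss criterion is active. A rough standalone sketch of that helper's behavior (not fairseq's exact implementation):

    def eval_str_list(x, type=int):
        # None passes through; a string like '10000,50000,200000'
        # evaluates to a tuple, then each element is coerced to `type`.
        if x is None:
            return None
        if isinstance(x, str):
            x = eval(x)
        return list(map(type, x))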
@@ -184,6 +191,7 @@ class LSTMEncoder(FairseqEncoder):
         if bidirectional:
             self.output_units *= 2
+
     def forward(self, src_tokens, src_lengths):
         if self.left_pad:
             # convert left-padding to right-padding
@@ -288,7 +296,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
         num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
         encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None,
-        share_input_output_embed=False,
+        share_input_output_embed=False, adaptive_softmax_cutoff=None,
     ):
         super().__init__(dictionary)
         self.dropout_in = dropout_in
@@ -297,6 +305,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.share_input_output_embed = share_input_output_embed
         self.need_attn = True
+        self.adaptive_softmax = None
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         if pretrained_embed is None:
@@ -319,9 +328,14 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.attention = AttentionLayer(encoder_output_units, hidden_size) if attention else None
         if hidden_size != out_embed_dim:
             self.additional_fc = Linear(hidden_size, out_embed_dim)
-        if not self.share_input_output_embed:
+        if adaptive_softmax_cutoff is not None:
+            # setting adaptive_softmax dropout to dropout_out for now but can be redefined
+            self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, embed_dim, adaptive_softmax_cutoff,
+                                                    dropout=dropout_out)
+        elif not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

     def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
         encoder_out = encoder_out_dict['encoder_out']
         encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
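
The AdaptiveSoftmax module built here splits the vocabulary at the cutoff points into a frequent "head" and progressively smaller tail clusters, so the expensive full projection is avoided for most target words. The idea can be illustrated with PyTorch's built-in variant (an illustration with assumed sizes, not the fairseq module itself):

    import torch
    import torch.nn as nn

    vocab_size, embed_dim = 250000, 512
    asm = nn.AdaptiveLogSoftmaxWithLoss(
        in_features=embed_dim, n_classes=vocab_size,
        cutoffs=[10000, 50000, 200000],  # same style as the new flag
    )
    features = torch.randn(32, embed_dim)          # decoder states
    targets = torch.randint(0, vocab_size, (32,))
    out = asm(features, targets)                   # cluster-wise loss
    print(out.loss)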
@@ -399,14 +413,14 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             attn_scores = None

         # project back to size of vocabulary
-        if hasattr(self, 'additional_fc'):
-            x = self.additional_fc(x)
-            x = F.dropout(x, p=self.dropout_out, training=self.training)
-        if self.share_input_output_embed:
-            x = F.linear(x, self.embed_tokens.weight)
-        else:
-            x = self.fc_out(x)
+        if self.adaptive_softmax is None:
+            if hasattr(self, 'additional_fc'):
+                x = self.additional_fc(x)
+                x = F.dropout(x, p=self.dropout_out, training=self.training)
+            if self.share_input_output_embed:
+                x = F.linear(x, self.embed_tokens.weight)
+            else:
+                x = self.fc_out(x)
         return x, attn_scores

     def reorder_incremental_state(self, incremental_state, new_order):
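
With an adaptive softmax attached, forward now returns the pre-projection features; the matching adaptive_loss criterion is expected to pick up decoder.adaptive_softmax and compute the factored loss itself. A sketch of that consuming side (simplified, with an assumed call signature based on the constructor call above):

    # hidden features straight from the decoder, no fc_out applied
    features, _ = decoder(prev_output_tokens, encoder_out_dict)
    # the criterion, not the decoder, applies the adaptive softmax
    logits, new_target = decoder.adaptive_softmax(
        features.view(-1, features.size(-1)), target.view(-1))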
@@ -483,7 +497,7 @@ def base_architecture(args):
     args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
     args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
+    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,50000,200000')

 @register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
 def lstm_wiseman_iwslt_de_en(args):
...
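
The default cutoff '10000,50000,200000' set in base_architecture keeps the 10k most frequent words in the head and buckets the rest into three tail clusters; for illustration:

    cutoffs = [int(c) for c in '10000,50000,200000'.split(',')]
    # head: ranks [0, 10000); tails: [10000, 50000),
    # [50000, 200000), and [200000, len(dictionary))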