Commit 9196c0b6 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

LSTM improvements (fixes #414)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/470

Differential Revision: D13803964

Pulled By: myleott

fbshipit-source-id: 91b66599e9a539833fcedea07c608b349ba3b449
parent d0ebcec4
...@@ -32,6 +32,8 @@ class LSTMModel(FairseqModel): ...@@ -32,6 +32,8 @@ class LSTMModel(FairseqModel):
help='encoder embedding dimension') help='encoder embedding dimension')
parser.add_argument('--encoder-embed-path', type=str, metavar='STR', parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
help='path to pre-trained encoder embedding') help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-freeze-embed', action='store_true',
help='freeze encoder embeddings')
parser.add_argument('--encoder-hidden-size', type=int, metavar='N', parser.add_argument('--encoder-hidden-size', type=int, metavar='N',
help='encoder hidden size') help='encoder hidden size')
parser.add_argument('--encoder-layers', type=int, metavar='N', parser.add_argument('--encoder-layers', type=int, metavar='N',
...@@ -42,6 +44,8 @@ class LSTMModel(FairseqModel): ...@@ -42,6 +44,8 @@ class LSTMModel(FairseqModel):
help='decoder embedding dimension') help='decoder embedding dimension')
parser.add_argument('--decoder-embed-path', type=str, metavar='STR', parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
help='path to pre-trained decoder embedding') help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-freeze-embed', action='store_true',
help='freeze decoder embeddings')
parser.add_argument('--decoder-hidden-size', type=int, metavar='N', parser.add_argument('--decoder-hidden-size', type=int, metavar='N',
help='decoder hidden size') help='decoder hidden size')
parser.add_argument('--decoder-layers', type=int, metavar='N', parser.add_argument('--decoder-layers', type=int, metavar='N',
...@@ -53,6 +57,12 @@ class LSTMModel(FairseqModel): ...@@ -53,6 +57,12 @@ class LSTMModel(FairseqModel):
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. ' help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion') 'Must be used with adaptive_loss criterion')
parser.add_argument('--share-decoder-input-output-embed', default=False,
action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
# Granular dropout settings (if not specified these default to --dropout) # Granular dropout settings (if not specified these default to --dropout)
parser.add_argument('--encoder-dropout-in', type=float, metavar='D', parser.add_argument('--encoder-dropout-in', type=float, metavar='D',
...@@ -63,12 +73,6 @@ class LSTMModel(FairseqModel): ...@@ -63,12 +73,6 @@ class LSTMModel(FairseqModel):
help='dropout probability for decoder input embedding') help='dropout probability for decoder input embedding')
parser.add_argument('--decoder-dropout-out', type=float, metavar='D', parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
help='dropout probability for decoder output') help='dropout probability for decoder output')
parser.add_argument('--share-decoder-input-output-embed', default=False,
action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
# fmt: on # fmt: on
@classmethod @classmethod
...@@ -130,6 +134,11 @@ class LSTMModel(FairseqModel): ...@@ -130,6 +134,11 @@ class LSTMModel(FairseqModel):
'--decoder-embed-dim to match --decoder-out-embed-dim' '--decoder-embed-dim to match --decoder-out-embed-dim'
) )
if args.encoder_freeze_embed:
pretrained_encoder_embed.weight.requires_grad = False
if args.decoder_freeze_embed:
pretrained_decoder_embed.weight.requires_grad = False
encoder = LSTMEncoder( encoder = LSTMEncoder(
dictionary=task.source_dictionary, dictionary=task.source_dictionary,
embed_dim=args.encoder_embed_dim, embed_dim=args.encoder_embed_dim,
...@@ -149,7 +158,6 @@ class LSTMModel(FairseqModel): ...@@ -149,7 +158,6 @@ class LSTMModel(FairseqModel):
dropout_in=args.decoder_dropout_in, dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out, dropout_out=args.decoder_dropout_out,
attention=options.eval_bool(args.decoder_attention), attention=options.eval_bool(args.decoder_attention),
encoder_embed_dim=args.encoder_embed_dim,
encoder_output_units=encoder.output_units, encoder_output_units=encoder.output_units,
pretrained_embed=pretrained_decoder_embed, pretrained_embed=pretrained_decoder_embed,
share_input_output_embed=args.share_decoder_input_output_embed, share_input_output_embed=args.share_decoder_input_output_embed,
...@@ -222,8 +230,8 @@ class LSTMEncoder(FairseqEncoder): ...@@ -222,8 +230,8 @@ class LSTMEncoder(FairseqEncoder):
state_size = 2 * self.num_layers, bsz, self.hidden_size state_size = 2 * self.num_layers, bsz, self.hidden_size
else: else:
state_size = self.num_layers, bsz, self.hidden_size state_size = self.num_layers, bsz, self.hidden_size
h0 = x.data.new(*state_size).zero_() h0 = x.new_zeros(*state_size)
c0 = x.data.new(*state_size).zero_() c0 = x.new_zeros(*state_size)
packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0)) packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))
# unpack outputs and apply dropout # unpack outputs and apply dropout
...@@ -263,11 +271,11 @@ class LSTMEncoder(FairseqEncoder): ...@@ -263,11 +271,11 @@ class LSTMEncoder(FairseqEncoder):
class AttentionLayer(nn.Module): class AttentionLayer(nn.Module):
def __init__(self, input_embed_dim, output_embed_dim): def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False):
super().__init__() super().__init__()
self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False) self.input_proj = Linear(input_embed_dim, source_embed_dim, bias=bias)
self.output_proj = Linear(input_embed_dim + output_embed_dim, output_embed_dim, bias=False) self.output_proj = Linear(input_embed_dim + source_embed_dim, output_embed_dim, bias=bias)
def forward(self, input, source_hids, encoder_padding_mask): def forward(self, input, source_hids, encoder_padding_mask):
# input: bsz x input_embed_dim # input: bsz x input_embed_dim
...@@ -300,7 +308,7 @@ class LSTMDecoder(FairseqIncrementalDecoder): ...@@ -300,7 +308,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
def __init__( def __init__(
self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512, self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True, num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None, encoder_output_units=512, pretrained_embed=None,
share_input_output_embed=False, adaptive_softmax_cutoff=None, share_input_output_embed=False, adaptive_softmax_cutoff=None,
): ):
super().__init__(dictionary) super().__init__(dictionary)
...@@ -319,18 +327,23 @@ class LSTMDecoder(FairseqIncrementalDecoder): ...@@ -319,18 +327,23 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self.embed_tokens = pretrained_embed self.embed_tokens = pretrained_embed
self.encoder_output_units = encoder_output_units self.encoder_output_units = encoder_output_units
assert encoder_output_units == hidden_size, \ if encoder_output_units != hidden_size:
'encoder_output_units ({}) != hidden_size ({})'.format(encoder_output_units, hidden_size) self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size)
# TODO another Linear layer if not equal self.encoder_cell_proj = Linear(encoder_output_units, hidden_size)
else:
self.encoder_hidden_proj = self.encoder_cell_proj = None
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LSTMCell( LSTMCell(
input_size=encoder_output_units + embed_dim if layer == 0 else hidden_size, input_size=hidden_size + embed_dim if layer == 0 else hidden_size,
hidden_size=hidden_size, hidden_size=hidden_size,
) )
for layer in range(num_layers) for layer in range(num_layers)
]) ])
self.attention = AttentionLayer(encoder_output_units, hidden_size) if attention else None if attention:
# TODO make bias configurable
self.attention = AttentionLayer(hidden_size, encoder_output_units, hidden_size, bias=False)
else:
self.attention = None
if hidden_size != out_embed_dim: if hidden_size != out_embed_dim:
self.additional_fc = Linear(hidden_size, out_embed_dim) self.additional_fc = Linear(hidden_size, out_embed_dim)
if adaptive_softmax_cutoff is not None: if adaptive_softmax_cutoff is not None:
...@@ -349,7 +362,7 @@ class LSTMDecoder(FairseqIncrementalDecoder): ...@@ -349,7 +362,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
bsz, seqlen = prev_output_tokens.size() bsz, seqlen = prev_output_tokens.size()
# get outputs from encoder # get outputs from encoder
encoder_outs, _, _ = encoder_out[:3] encoder_outs, encoder_hiddens, encoder_cells = encoder_out[:3]
srclen = encoder_outs.size(0) srclen = encoder_outs.size(0)
# embed tokens # embed tokens
...@@ -364,13 +377,15 @@ class LSTMDecoder(FairseqIncrementalDecoder): ...@@ -364,13 +377,15 @@ class LSTMDecoder(FairseqIncrementalDecoder):
if cached_state is not None: if cached_state is not None:
prev_hiddens, prev_cells, input_feed = cached_state prev_hiddens, prev_cells, input_feed = cached_state
else: else:
_, encoder_hiddens, encoder_cells = encoder_out[:3]
num_layers = len(self.layers) num_layers = len(self.layers)
prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)] prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
prev_cells = [encoder_cells[i] for i in range(num_layers)] prev_cells = [encoder_cells[i] for i in range(num_layers)]
input_feed = x.data.new(bsz, self.encoder_output_units).zero_() if self.encoder_hidden_proj is not None:
prev_hiddens = [self.encoder_hidden_proj(x) for x in prev_hiddens]
prev_cells = [self.encoder_cell_proj(x) for x in prev_cells]
input_feed = x.new_zeros(bsz, self.hidden_size)
attn_scores = x.data.new(srclen, seqlen, bsz).zero_() attn_scores = x.new_zeros(srclen, seqlen, bsz)
outs = [] outs = []
for j in range(seqlen): for j in range(seqlen):
# input feeding: concatenate context vector from previous time step # input feeding: concatenate context vector from previous time step
...@@ -402,7 +417,9 @@ class LSTMDecoder(FairseqIncrementalDecoder): ...@@ -402,7 +417,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
# cache previous states (no-op except during incremental generation) # cache previous states (no-op except during incremental generation)
utils.set_incremental_state( utils.set_incremental_state(
self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed)) self, incremental_state, 'cached_state',
(prev_hiddens, prev_cells, input_feed),
)
# collect outputs across time steps # collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size) x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)
...@@ -486,6 +503,7 @@ def base_architecture(args): ...@@ -486,6 +503,7 @@ def base_architecture(args):
args.dropout = getattr(args, 'dropout', 0.1) args.dropout = getattr(args, 'dropout', 0.1)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
args.encoder_freeze_embed = getattr(args, 'encoder_freeze_embed', False)
args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', args.encoder_embed_dim) args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', args.encoder_embed_dim)
args.encoder_layers = getattr(args, 'encoder_layers', 1) args.encoder_layers = getattr(args, 'encoder_layers', 1)
args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False) args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False)
...@@ -493,6 +511,7 @@ def base_architecture(args): ...@@ -493,6 +511,7 @@ def base_architecture(args):
args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout) args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
args.decoder_freeze_embed = getattr(args, 'decoder_freeze_embed', False)
args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim) args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 1) args.decoder_layers = getattr(args, 'decoder_layers', 1)
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512) args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment