Unverified commit 21b8fb5c authored by Sergey Edunov, committed by GitHub

Merge pull request #107 from facebookresearch/oss-merge-internal

Oss merge internal changes
parents 66415206 2f976aae
@@ -9,7 +9,7 @@ import math
 import torch.nn.functional as F
 from . import FairseqCriterion, register_criterion
+from fairseq import utils

 @register_criterion('cross_entropy')
 class CrossEntropyCriterion(FairseqCriterion):
@@ -33,7 +33,7 @@ class CrossEntropyCriterion(FairseqCriterion):
             reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0] if reduce else loss.data,
+            'loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
......
@@ -7,6 +7,7 @@
 import math
 import torch
+from torch.autograd import Variable
 import torch.nn.functional as F

 from fairseq import utils
@@ -24,6 +25,8 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
         norm = grad_input.size(-1)
         if weights is not None:
+            if isinstance(grad_input, Variable) and not isinstance(weights, Variable):
+                weights = Variable(weights, requires_grad=False)
             norm = weights.sum()
             grad_input.mul(weights.view(1, weights.size(0)).expand_as(grad_input))
@@ -41,7 +44,10 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad):
-        return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None
+        grad_input = ctx.grad_input
+        if not isinstance(grad_input, torch.autograd.Variable):
+            grad_input = utils.volatile_variable(grad_input)
+        return grad_input * grad, None, None, None, None, None

 @register_criterion('label_smoothed_cross_entropy')
@@ -73,8 +79,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0] if reduce else loss.data,
-            'nll_loss': nll_loss.data[0] if reduce else loss.data,
+            'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(nll_loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
......
@@ -116,7 +116,7 @@ def all_gather_list(data, max_size=4096):
     if len(enc) >= max_size:
         raise ValueError('encoded data exceeds max_size: {}'.format(len(enc)))
     in_buffer[0] = len(enc)
-    in_buffer[1:len(enc)+1] = torch.ByteTensor(enc)
+    in_buffer[1:len(enc)+1] = torch.ByteTensor(list(enc))
     torch.distributed.all_gather(out_buffers, in_buffer.cuda())
......
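
The `list(enc)` change above matters under Python 3, where `pickle.dumps` returns a `bytes` object that some torch versions could not feed to `torch.ByteTensor` directly; converting to a list of ints works on both Python 2 and 3. A minimal sketch of the round trip, with a made-up payload and assuming the pickle is small enough for the single-byte length header used here:

```python
import pickle
import torch

data = {'rank': 0, 'ntokens': 17}           # made-up payload
enc = pickle.dumps(data)                     # bytes under Python 3
buf = torch.ByteTensor(4096).zero_()         # stand-in for in_buffer
buf[0] = len(enc)                            # length header (assumes len(enc) < 256)
buf[1:len(enc) + 1] = torch.ByteTensor(list(enc))  # list() yields ints 0..255

# receiver side: recover the bytes and unpickle
size = int(buf[0])
assert pickle.loads(bytes(buf[1:size + 1].tolist())) == data
```
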
@@ -13,100 +13,21 @@ class FairseqIncrementalDecoder(FairseqDecoder):

     def __init__(self, dictionary):
         super().__init__(dictionary)
-        self._is_incremental_eval = False
-        self._incremental_state = {}

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
-            raise NotImplementedError
-        else:
-            raise NotImplementedError
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        raise NotImplementedError

-    def incremental_inference(self):
-        """Context manager for incremental inference.
-
-        This provides an optimized forward pass for incremental inference
-        (i.e., it predicts one time step at a time). If the input order changes
-        between time steps, call reorder_incremental_state to update the
-        relevant buffers. To generate a fresh sequence, first call
-        clear_incremental_state.
-
-        Usage:
-        ```
-        with model.decoder.incremental_inference():
-            for step in range(maxlen):
-                out, _ = model.decoder(tokens[:, :step], encoder_out)
-                probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
-        ```
-        """
-        class IncrementalInference(object):
-
-            def __init__(self, decoder):
-                self.decoder = decoder
-
-            def __enter__(self):
-                self.decoder.incremental_eval(True)
-
-            def __exit__(self, *args):
-                self.decoder.incremental_eval(False)
-
-        return IncrementalInference(self)
-
-    def incremental_eval(self, mode=True):
-        """Sets the decoder and all children in incremental evaluation mode."""
-        assert self._is_incremental_eval != mode, \
-            'incremental_eval already set to mode {}'.format(mode)
-        self._is_incremental_eval = mode
-        if mode:
-            self.clear_incremental_state()
-
-        def apply_incremental_eval(module):
-            if module != self and hasattr(module, 'incremental_eval'):
-                module.incremental_eval(mode)
-        self.apply(apply_incremental_eval)
-
-    def get_incremental_state(self, key):
-        """Return cached state or None if not in incremental inference mode."""
-        if self._is_incremental_eval and key in self._incremental_state:
-            return self._incremental_state[key]
-        return None
-
-    def set_incremental_state(self, key, value):
-        """Cache state needed for incremental inference mode."""
-        if self._is_incremental_eval:
-            self._incremental_state[key] = value
-        return value
-
-    def clear_incremental_state(self):
-        """Clear all state used for incremental generation.
-
-        **For incremental inference only**
-
-        This should be called before generating a fresh sequence.
-        beam_size is required if using BeamableMM.
-        """
-        if self._is_incremental_eval:
-            del self._incremental_state
-            self._incremental_state = {}
-
-            def apply_clear_incremental_state(module):
-                if module != self and hasattr(module, 'clear_incremental_state'):
-                    module.clear_incremental_state()
-            self.apply(apply_clear_incremental_state)
-
-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation).
-
-        **For incremental inference only**
+    def reorder_incremental_state(self, incremental_state, new_order):
+        """Reorder incremental state.

         This should be called when the order of the input has changed from the
         previous time step. A typical use case is beam search, where the input
-        order changes between time steps based on the choice of beams.
+        order changes between time steps based on the selection of beams.
         """
-        if self._is_incremental_eval:
-            def apply_reorder_incremental_state(module):
-                if module != self and hasattr(module, 'reorder_incremental_state'):
-                    module.reorder_incremental_state(new_order)
-            self.apply(apply_reorder_incremental_state)
+        def apply_reorder_incremental_state(module):
+            if module != self and hasattr(module, 'reorder_incremental_state'):
+                module.reorder_incremental_state(incremental_state, new_order)
+        self.apply(apply_reorder_incremental_state)

     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
......
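
With the context manager, `incremental_eval()`, and the per-module `_incremental_state` dict gone, callers now own the cache: a fresh dict starts a fresh sequence, and that same dict is threaded through every decoder call. A rough sketch of the new calling convention, updating the usage snippet from the removed docstring (`model`, `tokens`, `encoder_out`, and `maxlen` are placeholders, not names defined in this diff):

```python
# replaces: with model.decoder.incremental_inference(): ...
incremental_state = {}  # a new dict plays the role of clear_incremental_state()
for step in range(maxlen):
    out, _ = model.decoder(tokens[:, :step + 1], encoder_out, incremental_state)
    probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
    # if beam search reorders hypotheses between steps:
    # model.decoder.reorder_incremental_state(incremental_state, reorder_state)
```

Passing `incremental_state=None` (the default) recovers ordinary full-sequence behavior, which is why training code needs no changes.
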
@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from fairseq import utils
 from fairseq.data import LanguagePairDataset
 from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
@@ -229,19 +230,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

-    def forward(self, prev_output_tokens, encoder_out):
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # split and transpose encoder outputs
-        encoder_a, encoder_b = self._split_encoder_out(encoder_out)
+        encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)

-        # embed positions
-        positions = self.embed_positions(prev_output_tokens)
-
-        if self._is_incremental_eval:
-            # keep only the last token for incremental forward pass
-            prev_output_tokens = prev_output_tokens[:, -1:]
-
-        # embed tokens and positions
-        x = self.embed_tokens(prev_output_tokens) + positions
+        # embed tokens and combine with positional embeddings
+        x = self._embed_tokens(prev_output_tokens, incremental_state)
+        x += self.embed_positions(prev_output_tokens, incremental_state)
         x = F.dropout(x, p=self.dropout, training=self.training)
         target_embedding = x
@@ -249,7 +244,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         x = self.fc1(x)

         # B x T x C -> T x B x C
-        x = self._transpose_unless_incremental_eval(x)
+        x = self._transpose_if_training(x, incremental_state)

         # temporal convolutions
         avg_attn_scores = None
@@ -258,13 +253,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
             residual = x if proj is None else proj(x)

             x = F.dropout(x, p=self.dropout, training=self.training)
-            x = conv(x)
-            x = conv.remove_future_timesteps(x)
+            x = conv(x, incremental_state)
+            if incremental_state is None:
+                x = conv.remove_future_timesteps(x)
             x = F.glu(x, dim=2)

             # attention
             if attention is not None:
-                x = self._transpose_unless_incremental_eval(x)
+                x = self._transpose_if_training(x, incremental_state)
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
                 attn_scores = attn_scores / num_attn_layers
@@ -273,13 +269,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
                 else:
                     avg_attn_scores.add_(attn_scores)

-                x = self._transpose_unless_incremental_eval(x)
+                x = self._transpose_if_training(x, incremental_state)

             # residual
             x = (x + residual) * math.sqrt(0.5)

         # T x B x C -> B x T x C
-        x = self._transpose_unless_incremental_eval(x)
+        x = self._transpose_if_training(x, incremental_state)

         # project back to size of vocabulary
         x = self.fc2(x)
@@ -288,10 +284,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
         return x, avg_attn_scores

-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(new_order)
-
     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
@@ -306,13 +298,19 @@ class FConvDecoder(FairseqIncrementalDecoder):
             state_dict['decoder.version'] = torch.Tensor([1])
         return state_dict

-    def _split_encoder_out(self, encoder_out):
+    def _embed_tokens(self, tokens, incremental_state):
+        if incremental_state is not None:
+            # keep only the last token for incremental forward pass
+            tokens = tokens[:, -1:]
+        return self.embed_tokens(tokens)
+
+    def _split_encoder_out(self, encoder_out, incremental_state):
         """Split and transpose encoder outputs.

         This is cached when doing incremental inference.
         """
-        cached_result = self.get_incremental_state('encoder_out')
-        if cached_result:
+        cached_result = utils.get_incremental_state(self, incremental_state, 'encoder_out')
+        if cached_result is not None:
             return cached_result

         # transpose only once to speed up attention layers
@@ -320,12 +318,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
         encoder_a = encoder_a.transpose(1, 2).contiguous()
         result = (encoder_a, encoder_b)

-        return self.set_incremental_state('encoder_out', result)
+        if incremental_state is not None:
+            utils.set_incremental_state(self, incremental_state, 'encoder_out', result)
+        return result

-    def _transpose_unless_incremental_eval(self, x):
-        if self._is_incremental_eval:
-            return x
-        return x.transpose(0, 1)
+    def _transpose_if_training(self, x, incremental_state):
+        if incremental_state is None:
+            x = x.transpose(0, 1)
+        return x

 def Embedding(num_embeddings, embedding_dim, padding_idx):
......
@@ -183,12 +183,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             self.additional_fc = Linear(embed_dim, out_embed_dim)
         self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
-        return self._forward(prev_output_tokens, encoder_out)
-
-    def _forward(self, prev_output_tokens, encoder_out):
         bsz, seqlen = prev_output_tokens.size()

         # get outputs from encoder
@@ -204,15 +201,15 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         x = x.transpose(0, 1)

         # initialize previous states (or get from cache during incremental generation)
-        prev_hiddens = self.get_incremental_state('prev_hiddens')
-        if not prev_hiddens:
-            # first time step, initialize previous states
-            prev_hiddens, prev_cells = self._init_prev_states(encoder_out)
-            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
+        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
+        if cached_state is not None:
+            prev_hiddens, prev_cells, input_feed = cached_state
         else:
-            # previous states are cached
-            prev_cells = self.get_incremental_state('prev_cells')
-            input_feed = self.get_incremental_state('input_feed')
+            _, encoder_hiddens, encoder_cells = encoder_out
+            num_layers = len(self.layers)
+            prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
+            prev_cells = [encoder_cells[i] for i in range(num_layers)]
+            input_feed = Variable(x.data.new(bsz, embed_dim).zero_())

         attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
         outs = []
@@ -242,9 +239,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             outs.append(out)

         # cache previous states (no-op except during incremental generation)
-        self.set_incremental_state('prev_hiddens', prev_hiddens)
-        self.set_incremental_state('prev_cells', prev_cells)
-        self.set_incremental_state('input_feed', input_feed)
+        utils.set_incremental_state(
+            self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed))

         # collect outputs across time steps
         x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
@@ -263,34 +259,25 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         return x, attn_scores

-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(new_order)
-        new_order = Variable(new_order)
+    def reorder_incremental_state(self, incremental_state, new_order):
+        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
+        if cached_state is None:
+            return

-        def reorder_state(key):
-            old = self.get_incremental_state(key)
-            if isinstance(old, list):
-                new = [old_i.index_select(0, new_order) for old_i in old]
-            else:
-                new = old.index_select(0, new_order)
-            self.set_incremental_state(key, new)
+        def reorder_state(state):
+            if isinstance(state, list):
+                return [reorder_state(state_i) for state_i in state]
+            return state.index_select(0, new_order)

-        reorder_state('prev_hiddens')
-        reorder_state('prev_cells')
-        reorder_state('input_feed')
+        if not isinstance(new_order, Variable):
+            new_order = Variable(new_order)
+        new_state = tuple(map(reorder_state, cached_state))
+        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

     def max_positions(self):
         """Maximum output length supported by the decoder."""
         return int(1e5)  # an arbitrary large number

-    def _init_prev_states(self, encoder_out):
-        _, encoder_hiddens, encoder_cells = encoder_out
-        num_layers = len(self.layers)
-        prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
-        prev_cells = [encoder_cells[i] for i in range(num_layers)]
-        return prev_hiddens, prev_cells
-

 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
......
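
The rewritten `reorder_state` recurses into lists so per-layer hidden and cell states get reordered tensor by tensor along the beam dimension. A toy illustration of that helper in isolation (values made up; the closure in the diff captures `new_order` instead of taking it as an argument):

```python
import torch

def reorder_state(state, new_order):
    if isinstance(state, list):
        return [reorder_state(s, new_order) for s in state]
    return state.index_select(0, new_order)  # beam dimension comes first

prev_hiddens = [torch.arange(4.).view(4, 1) for _ in range(2)]  # 2 layers, beam of 4
new_order = torch.tensor([1, 1, 3, 2])  # surviving beams, one duplicated
print(reorder_state(prev_hiddens, new_order)[0].view(-1).tolist())  # [1.0, 1.0, 3.0, 2.0]
```

Bundling `(prev_hiddens, prev_cells, input_feed)` into a single `cached_state` tuple also turns three cache lookups per step into one.
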
@@ -20,14 +20,10 @@ class LearnedPositionalEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
         super().__init__(num_embeddings, embedding_dim, padding_idx)
         self.left_pad = left_pad
-        self._is_incremental_eval = False

-    def incremental_eval(self, mode=True):
-        self._is_incremental_eval = mode
-
-    def forward(self, input):
+    def forward(self, input, incremental_state=None):
         """Input is expected to be of size [bsz x seqlen]."""
-        if self._is_incremental_eval:
+        if incremental_state is not None:
             # positions is the same for every token when decoding a single step
             positions = Variable(
                 input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
......
@@ -22,35 +22,20 @@ class LinearizedConvolution(ConvTBC):
     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
         super().__init__(in_channels, out_channels, kernel_size, **kwargs)
-        self._is_incremental_eval = False
         self._linearized_weight = None
         self.register_backward_hook(self._clear_linearized_weight)

-    def remove_future_timesteps(self, x):
-        """Remove future time steps created by padding."""
-        if not self._is_incremental_eval and self.kernel_size[0] > 1 and self.padding[0] > 0:
-            x = x[:-self.padding[0], :, :]
-        return x
-
-    def incremental_eval(self, mode=True):
-        self._is_incremental_eval = mode
-        if mode:
-            self.clear_incremental_state()
-
-    def forward(self, input):
-        if self._is_incremental_eval:
-            return self.incremental_forward(input)
-        else:
+    def forward(self, input, incremental_state=None):
+        """
+        Input: Time x Batch x Channel.
+
+        Args:
+            incremental_state: Used to buffer signal; if not None, then input is
+                expected to contain a single frame. If the input order changes
+                between time steps, call reorder_incremental_state.
+        """
+        if incremental_state is None:
             return super().forward(input)

-    def incremental_forward(self, input):
-        """Forward convolution one time step at a time.
-
-        This function maintains an internal state to buffer signal and accepts
-        a single frame as input. If the input order changes between time steps,
-        call reorder_incremental_state. To apply to fresh inputs, call
-        clear_incremental_state.
-        """
         # reshape weight
         weight = self._get_linearized_weight()
         kw = self.kernel_size[0]
@@ -58,25 +43,37 @@ class LinearizedConvolution(ConvTBC):
         bsz = input.size(0)  # input: bsz x len x dim
         if kw > 1:
             input = input.data
-            if self.input_buffer is None:
-                self.input_buffer = input.new(bsz, kw, input.size(2))
-                self.input_buffer.zero_()
+            input_buffer = self._get_input_buffer(incremental_state)
+            if input_buffer is None:
+                input_buffer = input.new(bsz, kw, input.size(2)).zero_()
+                self._set_input_buffer(incremental_state, input_buffer)
             else:
                 # shift buffer
-                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
+                input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
             # append next input
-            self.input_buffer[:, -1, :] = input[:, -1, :]
-            input = utils.volatile_variable(self.input_buffer)
+            input_buffer[:, -1, :] = input[:, -1, :]
+            input = utils.volatile_variable(input_buffer)
         with utils.maybe_no_grad():
             output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

-    def clear_incremental_state(self):
-        self.input_buffer = None
+    def remove_future_timesteps(self, x):
+        """Remove future time steps created by padding."""
+        if self.kernel_size[0] > 1 and self.padding[0] > 0:
+            x = x[:-self.padding[0], :, :]
+        return x

-    def reorder_incremental_state(self, new_order):
-        if self.input_buffer is not None:
-            self.input_buffer = self.input_buffer.index_select(0, new_order)
+    def reorder_incremental_state(self, incremental_state, new_order):
+        input_buffer = self._get_input_buffer(incremental_state)
+        if input_buffer is not None:
+            input_buffer = input_buffer.index_select(0, new_order)
+            self._set_input_buffer(incremental_state, input_buffer)
+
+    def _get_input_buffer(self, incremental_state):
+        return utils.get_incremental_state(self, incremental_state, 'input_buffer')
+
+    def _set_input_buffer(self, incremental_state, new_buffer):
+        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)

     def _get_linearized_weight(self):
         if self._linearized_weight is None:
......
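
The buffering scheme above keeps a `bsz x kw x dim` window of the last `kw` input frames; each incremental step shifts the window left and writes the newest frame into the last slot, so the linearized convolution only ever multiplies one window per step. A standalone sketch of just the shift-and-append logic (toy sizes, no convolution):

```python
import torch

bsz, kw, dim = 2, 3, 4
input_buffer = torch.zeros(bsz, kw, dim)            # zero-initialized, as above
for step in range(5):
    frame = torch.full((bsz, 1, dim), float(step))  # newest input frame
    input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()  # shift window left
    input_buffer[:, -1, :] = frame[:, -1, :]                  # append newest frame
print(input_buffer[0, :, 0].tolist())  # [2.0, 3.0, 4.0]: the last kw frames
```

Storing this buffer in `incremental_state` via `_set_input_buffer`, rather than on `self` as before, is what lets `reorder_incremental_state` maintain it per generation run and per beam order.
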
@@ -7,6 +7,8 @@
 import argparse
+import torch
 from fairseq.criterions import CRITERION_REGISTRY
 from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
 from fairseq.optim import OPTIMIZER_REGISTRY
@@ -117,8 +119,9 @@ def add_dataset_args(parser, train=False, gen=False):
 def add_distributed_training_args(parser):
     group = parser.add_argument_group('Distributed training')
-    group.add_argument('--distributed-world-size', default=1, type=int, metavar='N',
-                       help='total number of GPUs across all nodes, default: 1 GPU')
+    group.add_argument('--distributed-world-size', type=int, metavar='N',
+                       default=torch.cuda.device_count(),
+                       help='total number of GPUs across all nodes (default: all visible GPUs)')
     group.add_argument('--distributed-rank', default=0, type=int,
                        help='rank of the current worker')
     group.add_argument('--distributed-backend', default='nccl', type=str,
@@ -223,6 +226,8 @@ def add_generation_args(parser):
                        help='only print final scores')
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
+    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
+                       help=('initialize generation by target prefix of given length'))
     return group
......
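
One subtlety of the new `--distributed-world-size` default: `torch.cuda.device_count()` is evaluated once, when the argument is registered, so the default reflects the GPUs visible at that moment (e.g., as restricted by `CUDA_VISIBLE_DEVICES`). A minimal sketch:

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--distributed-world-size', type=int, metavar='N',
                    default=torch.cuda.device_count(),  # evaluated at parser build time
                    help='total number of GPUs across all nodes (default: all visible GPUs)')
print(parser.parse_args([]).distributed_world_size)  # 0 on a CPU-only machine
```

This is also why the test change further down passes `--distributed-world-size 1` explicitly: without it, a multi-GPU machine would route the tiny test job through multiprocessing training.
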
@@ -5,7 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-from contextlib import ExitStack
 import math
 import torch
@@ -51,7 +50,7 @@ class SequenceGenerator(object):
         return self

     def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-                             cuda=False, timer=None):
+                             cuda=False, timer=None, prefix_size=0):
         """Iterate over a batched dataset and yield individual translations.

         Args:
@@ -75,6 +74,7 @@ class SequenceGenerator(object):
                     input['src_lengths'],
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
+                    prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
                 )
             if timer is not None:
                 timer.stop(s['ntokens'])
@@ -84,15 +84,12 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad)
                 yield id, src, ref, hypos[i]

-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None):
+    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
-        with ExitStack() as stack:
-            for model in self.models:
-                if isinstance(model.decoder, FairseqIncrementalDecoder):
-                    stack.enter_context(model.decoder.incremental_inference())
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen)
+        with utils.maybe_no_grad():
+            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None):
+    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -101,11 +98,14 @@ class SequenceGenerator(object):
         beam_size = min(beam_size, self.vocab_size - 1)

         encoder_outs = []
+        incremental_states = {}
         for model in self.models:
             if not self.retain_dropout:
                 model.eval()
             if isinstance(model.decoder, FairseqIncrementalDecoder):
-                model.decoder.set_beam_size(beam_size)
+                incremental_states[model] = {}
+            else:
+                incremental_states[model] = None

             # compute the encoder output for each beam
             encoder_out = model.encoder(
@@ -243,9 +243,11 @@ class SequenceGenerator(object):
             if reorder_state is not None:
                 for model in self.models:
                     if isinstance(model.decoder, FairseqIncrementalDecoder):
-                        model.decoder.reorder_incremental_state(reorder_state)
+                        model.decoder.reorder_incremental_state(
+                            incremental_states[model], reorder_state)

-            probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
+            probs, avg_attn_scores = self._decode(
+                tokens[:, :step+1], encoder_outs, incremental_states)
             if step == 0:
                 # at the first step all hypotheses are equally likely, so use
                 # only the first beam
@@ -267,15 +269,24 @@ class SequenceGenerator(object):
             eos_bbsz_idx = buffer('eos_bbsz_idx')
             eos_scores = buffer('eos_scores', type_of=scores)
             if step < maxlen:
-                # take the best 2 x beam_size predictions. We'll choose the first
-                # beam_size of these which don't predict eos to continue with.
-                torch.topk(
-                    probs.view(bsz, -1),
-                    k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
-                    out=(cand_scores, cand_indices),
-                )
-                torch.div(cand_indices, self.vocab_size, out=cand_beams)
-                cand_indices.fmod_(self.vocab_size)
+                if prefix_tokens is not None and step < prefix_tokens.size(1):
+                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
+                    cand_scores = probs_slice.gather(
+                        dim=1,
+                        index=prefix_tokens[:, step].view(-1, 1).data
+                    ).expand(-1, cand_size)
+                    cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
+                    cand_beams.resize_as_(cand_indices).fill_(0)
+                else:
+                    # take the best 2 x beam_size predictions. We'll choose the first
+                    # beam_size of these which don't predict eos to continue with.
+                    torch.topk(
+                        probs.view(bsz, -1),
+                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
+                        out=(cand_scores, cand_indices),
+                    )
+                    torch.div(cand_indices, self.vocab_size, out=cand_beams)
+                    cand_indices.fmod_(self.vocab_size)
             else:
                 # finalize all active hypotheses once we hit maxlen
                 # pick the hypothesis with the highest prob of EOS right now
@@ -391,7 +402,7 @@ class SequenceGenerator(object):
         return finalized

-    def _decode(self, tokens, encoder_outs):
+    def _decode(self, tokens, encoder_outs, incremental_states):
         # wrap in Variable
         tokens = utils.volatile_variable(tokens)
@@ -399,7 +410,7 @@ class SequenceGenerator(object):
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
             with utils.maybe_no_grad():
-                decoder_out, attn = model.decoder(tokens, encoder_out)
+                decoder_out, attn = model.decoder(tokens, encoder_out, incremental_states[model])
             probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
             if avg_probs is None:
                 avg_probs = probs
......
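
The new `prefix_tokens` branch forces decoding through a given target prefix: while `step < prefix_tokens.size(1)`, every candidate continues with the reference token, its score is read from beam 0's distribution with `gather`, and all candidates are attributed to beam 0. A self-contained sketch of that selection step (toy sizes; in the real code `probs` comes from `_decode` and the `.data` calls reflect the Variable-era API):

```python
import torch

bsz, beam_size, vocab = 2, 3, 10
cand_size = 2 * beam_size
probs = torch.rand(bsz * beam_size, vocab).log()  # stand-in decoder output
prefix_tokens = torch.tensor([[4], [7]])          # forced first token per sentence
step = 0

probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]        # beam 0 only
cand_scores = probs_slice.gather(1, prefix_tokens[:, step].view(-1, 1)).expand(-1, cand_size)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
cand_beams = torch.zeros(bsz, cand_size, dtype=torch.long)        # everything from beam 0
print(cand_indices.tolist())  # [[4, 4, 4, 4, 4, 4], [7, 7, 7, 7, 7, 7]]
```

Once the prefix is exhausted, decoding falls through to the usual `topk` over `2 x beam_size` candidates.
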
@@ -47,13 +47,13 @@ class Trainer(object):
         self.meters['train_nll_loss'] = AverageMeter()
         self.meters['valid_loss'] = AverageMeter()
         self.meters['valid_nll_loss'] = AverageMeter()
         self.meters['wps'] = TimeMeter()       # words per second
         self.meters['ups'] = TimeMeter()       # updates per second
         self.meters['wpb'] = AverageMeter()    # words per batch
         self.meters['bsz'] = AverageMeter()    # sentences per batch
         self.meters['gnorm'] = AverageMeter()  # gradient norm
         self.meters['clip'] = AverageMeter()   # % of updates clipped
         self.meters['oom'] = AverageMeter()    # out of memory
         self._max_bsz_seen = 0
         self._num_updates = 0
@@ -190,7 +190,7 @@ class Trainer(object):
             # clip grads
             if self.args.clip_norm > 0:
-                grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
+                grad_norm = utils.item(torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm))
             else:
                 grad_norm = math.sqrt(sum(p.grad.data.norm()**2 for p in self.model.parameters()))
......
@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+from collections import defaultdict
 import contextlib
 import logging
 import os
@@ -198,6 +199,36 @@ def make_variable(sample, volatile=False, cuda=False):
     return _make_variable(sample)

+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+
+def _get_full_incremental_state_key(module_instance, key):
+    module_name = module_instance.__class__.__name__
+
+    # assign a unique ID to each module instance, so that incremental state is
+    # not shared across module instances
+    if not hasattr(module_instance, '_fairseq_instance_id'):
+        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+        module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+    return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)
+
+
+def get_incremental_state(module, incremental_state, key):
+    """Helper for getting incremental state for an nn.Module."""
+    full_key = _get_full_incremental_state_key(module, key)
+    if incremental_state is None or full_key not in incremental_state:
+        return None
+    return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+    """Helper for setting incremental state for an nn.Module."""
+    if incremental_state is not None:
+        full_key = _get_full_incremental_state_key(module, key)
+        incremental_state[full_key] = value
+
+
 def load_align_dict(replace_unk):
     if replace_unk is None:
         align_dict = None
@@ -273,3 +304,10 @@ def convert_padding_direction(
     else:
         index = torch.remainder(range + num_pads, max_len)
     return src_tokens.gather(1, index)
+
+
+def item(tensor):
+    if hasattr(tensor, 'item'):
+        return tensor.item()
+    if hasattr(tensor, '__getitem__'):
+        return tensor[0]
+    return tensor
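
These helpers replace the per-module `get_incremental_state`/`set_incremental_state` methods: state now lives in a caller-owned dict, keyed by class name plus a per-instance counter so two instances of the same module class never collide, while `item()` gives one scalar-extraction path that works whether `tensor.item()` exists (newer PyTorch) or only element indexing does (older versions). A small sketch of the keying behavior using a stand-in class, not a name from the diff:

```python
from collections import defaultdict

INSTANCE_ID = defaultdict(lambda: 0)  # mirrors INCREMENTAL_STATE_INSTANCE_ID

def full_key(module, key):
    name = module.__class__.__name__
    if not hasattr(module, '_instance_id'):  # assigned once per instance
        INSTANCE_ID[name] += 1
        module._instance_id = INSTANCE_ID[name]
    return '{}.{}.{}'.format(name, module._instance_id, key)

class Conv:  # stand-in for any module that caches incremental state
    pass

state = {}
a, b = Conv(), Conv()
state[full_key(a, 'input_buffer')] = 'buffer of a'
state[full_key(b, 'input_buffer')] = 'buffer of b'
print(sorted(state))  # ['Conv.1.input_buffer', 'Conv.2.input_buffer']
```
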
@@ -18,8 +18,6 @@ def main(args):
     print(args)

     use_cuda = torch.cuda.is_available() and not args.cpu
-    if hasattr(torch, 'set_grad_enabled'):
-        torch.set_grad_enabled(False)

     # Load dataset
     if args.replace_unk is None:
@@ -92,7 +90,7 @@ def main(args):
     else:
         translations = translator.generate_batched_itr(
             t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
-            cuda=use_cuda, timer=gen_timer)
+            cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size)

     wps_meter = TimeMeter()
     for sample_id, src_tokens, target_tokens, hypos in translations:
         # Process input and ground truth
......
@@ -18,8 +18,6 @@ def main(args):
     print(args)

     use_cuda = torch.cuda.is_available() and not args.cpu
-    if hasattr(torch, 'set_grad_enabled'):
-        torch.set_grad_enabled(False)

     # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
......
@@ -39,6 +39,7 @@ def get_parser():
     parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
     return parser

+
 def main(args):
     print(args)
     os.makedirs(args.destdir, exist_ok=True)
......
@@ -177,6 +177,7 @@ def get_training_stats(trainer):
     stats['lr'] = trainer.get_lr()
     stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
     stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
+    stats['oom'] = trainer.get_meter('oom').avg
     return stats
......
@@ -16,7 +16,10 @@ import torch
 from fairseq import options

-import preprocess, train, generate, interactive
+import preprocess
+import train
+import generate
+import interactive

 class TestBinaries(unittest.TestCase):
@@ -80,6 +83,7 @@ class TestBinaries(unittest.TestCase):
                 '--save-dir', data_dir,
                 '--max-epoch', '1',
                 '--no-progress-bar',
+                '--distributed-world-size', '1',
             ],
         )
         train.main(train_args)
......
@@ -9,7 +9,7 @@
 import torch
 from torch.autograd import Variable

-from fairseq import data, dictionary
+from fairseq import data, dictionary, utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
@@ -96,24 +96,21 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
         self.args = args

-    def forward(self, prev_output_tokens, encoder_out):
-        if self._is_incremental_eval:
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
-        return self._forward(prev_output_tokens, encoder_out)
-
-    def _forward(self, prev_output_tokens, encoder_out):
         bbsz = prev_output_tokens.size(0)
         vocab = len(self.dictionary)
         src_len = encoder_out.size(1)
         tgt_len = prev_output_tokens.size(1)

         # determine number of steps
-        if self._is_incremental_eval:
+        if incremental_state is not None:
             # cache step number
-            step = self.get_incremental_state('step')
+            step = utils.get_incremental_state(self, incremental_state, 'step')
             if step is None:
                 step = 0
-            self.set_incremental_state('step', step + 1)
+            utils.set_incremental_state(self, incremental_state, 'step', step + 1)
             steps = [step]
         else:
             steps = list(range(tgt_len))
......
@@ -6,23 +6,23 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import torch
 from fairseq import options
 from distributed_train import main as distributed_main
 from multiprocessing_train import main as multiprocessing_main
 from singleprocess_train import main as singleprocess_main

 def main(args):
     if args.distributed_port > 0 \
             or args.distributed_init_method is not None:
         distributed_main(args)
-    elif torch.cuda.device_count() > 1:
+    elif args.distributed_world_size > 1:
         multiprocessing_main(args)
     else:
         singleprocess_main(args)

 if __name__ == '__main__':
     parser = options.get_training_parser()
     args = options.parse_args_and_arch(parser)
......