"vscode:/vscode.git/clone" did not exist on "ec55d7da3ba5c8bc59a63f2b5812dbcbf15fdef8"
Unverified Commit 388c520b authored by Myle Ott, committed by GitHub

0.4.0 -> 0.5.0

Changelog:
- 97b58b46: add Transformer model from Vaswani et al. (2017)
- b2374e52: faster Transformer inference with improved caching
- 2d27ae08: simulate large mini-batch training with delayed updates (`--update-freq`)
- 7ee1d284: add FP16 training support (`--fp16`)
- 2a84f46b: faster inference by removing completed sentences from the batch
- 663fd806: batched interactive generation
- 4c2ef2de: add language modeling / gated convolutional model from Dauphin et al. (2017)
- b59815bc: add Hierarchical Neural Story Generation model from Fan et al. (2018)
- ff68a9ef: add FairseqTask to modularize task definitions (e.g., translation, language modeling)
parents ec0031df 5383b5db
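The `--update-freq` entry in the changelog above is a training-loop change rather than a model change. As a rough, fairseq-independent sketch (hypothetical variable names), delayed updates accumulate gradients over N mini-batches before each optimizer step, which behaves like an N-times larger batch on a single GPU:

    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    update_freq = 4  # corresponds to --update-freq 4
    batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

    for i, (x, y) in enumerate(batches):
        loss = torch.nn.functional.cross_entropy(model(x), y)
        (loss / update_freq).backward()   # gradients accumulate across backward() calls
        if (i + 1) % update_freq == 0:    # step once every update_freq mini-batches
            optimizer.step()
            optimizer.zero_grad()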
@@ -5,7 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from torch.autograd import Variable
 import torch.nn as nn

 from fairseq import utils
@@ -29,7 +28,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
             positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
         else:
             positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
-        return super().forward(Variable(positions))
+        return super().forward(positions)

     def max_positions(self):
         """Maximum number of supported positions."""
...
@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import torch
 import torch.nn.functional as F

 from fairseq import utils
@@ -59,8 +60,8 @@ class LinearizedConvolution(ConvTBC):
                 input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
                 # append next input
                 input_buffer[:, -1, :] = input[:, -1, :]
-                input = utils.volatile_variable(input_buffer)
-        with utils.maybe_no_grad():
+                input = input_buffer
+        with torch.no_grad():
             output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)
...
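The two hunks above follow the PyTorch 0.4 API change in which torch.autograd.Variable was merged into torch.Tensor, so fairseq's utils.volatile_variable and utils.maybe_no_grad shims give way to plain tensors and torch.no_grad(). A minimal, standalone illustration of the new idiom (not fairseq code):

    import torch

    x = torch.randn(4, 8, requires_grad=True)   # a plain tensor; no Variable wrapper needed
    with torch.no_grad():                        # replaces volatile=True / utils.maybe_no_grad()
        y = torch.nn.functional.linear(x, torch.randn(2, 8))
    print(y.requires_grad)                       # False: no graph is recorded inside no_grad()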
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from fairseq import utils
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim
        self.scaling = self.head_dim**-0.5
        self._mask = None

        self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, mask_future_timesteps=False,
                key_padding_mask=None, incremental_state=None,
                need_weights=True, static_kv=False):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if 'prev_key' in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert kv_same and not qkv_same
                    key = value = None
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            if key is None:
                assert value is None
                # this will allow us to concat it with previous value and get
                # just the previous value
                k = v = q.new(0)
            else:
                k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q *= self.scaling

        if saved_state is not None:
            if 'prev_key' in saved_state:
                k = torch.cat((saved_state['prev_key'], k), dim=0)
            if 'prev_value' in saved_state:
                v = torch.cat((saved_state['prev_value'], v), dim=0)
            saved_state['prev_key'] = k
            saved_state['prev_value'] = v
            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(0)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        # only apply masking at training time (when incremental state is None)
        if mask_future_timesteps and incremental_state is None:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            ).type_as(attn_weights)  # FP16 support: cast to float and back
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.sum(dim=1) / self.num_heads
        else:
            attn_weights = None

        return attn, attn_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2*self.embed_dim)

    def _in_proj(self, input, start=None, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        if end is not None:
            weight = weight[:end, :]
            if bias is not None:
                bias = bias[:end]
        if start is not None:
            weight = weight[start:, :]
            if bias is not None:
                bias = bias[start:]
        return F.linear(input, weight, bias)

    def buffered_mask(self, tensor):
        dim = tensor.size(-1)
        if self._mask is None:
            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._mask.size(0) < dim:
            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
        return self._mask[:dim, :dim]

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        return utils.get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        utils.set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,
        )
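A minimal usage sketch of the module above (arbitrary sizes; inputs follow the Time x Batch x Channel convention documented in forward()):

    import torch

    mha = MultiheadAttention(embed_dim=16, num_heads=4, dropout=0.1)
    x = torch.randn(7, 2, 16)           # (tgt_len, bsz, embed_dim)
    out, attn_weights = mha(x, x, x)    # self-attention: query, key and value are the same tensor
    print(out.shape)                    # torch.Size([7, 2, 16])
    print(attn_weights.shape)           # torch.Size([2, 7, 7]), averaged over heads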
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
class ScalarBias(torch.autograd.Function):
    """
    Adds a vector of scalars, used in self-attention mechanism to allow
    the model to optionally attend to this vector instead of the past
    """

    @staticmethod
    def forward(ctx, input, dim, bias_init):
        size = list(input.size())
        size[dim] += 1
        output = input.new(*size).fill_(bias_init)
        output.narrow(dim, 1, size[dim] - 1).copy_(input)
        ctx.dim = dim
        return output

    @staticmethod
    def backward(ctx, grad):
        return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None


def scalar_bias(input, dim, bias_init=0):
    return ScalarBias.apply(input, dim, bias_init)
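A quick sketch of what scalar_bias does: it prepends one constant slot along the chosen dimension, and the backward pass drops the gradient of that slot:

    import torch

    x = torch.ones(2, 3, requires_grad=True)
    y = scalar_bias(x, dim=1)     # bias_init defaults to 0
    print(y.shape)                # torch.Size([2, 4]): one extra slot at index 0
    print(y[0])                   # tensor([0., 1., 1., 1.], ...)
    y.sum().backward()
    print(x.grad)                 # all ones: the bias slot receives no gradient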
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
from fairseq import utils
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor())

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen]."""
        # recompute/expand embeddings if needed
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        if max_pos > self.weights.size(0):
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            ).type_as(self.weights)
        self.weights = self.weights.type_as(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

        positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
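Since get_embedding is a static method with no fairseq dependencies, it can be inspected on its own; a small sketch:

    import torch

    emb = SinusoidalPositionalEmbedding.get_embedding(num_embeddings=6, embedding_dim=8, padding_idx=1)
    print(emb.shape)   # torch.Size([6, 8]): sin features in the first half, cos in the second
    print(emb[1])      # all zeros: the padding position is zeroed out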
@@ -5,8 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import torch.optim.lr_scheduler
 from . import FairseqLRScheduler, register_lr_scheduler
@@ -16,16 +14,22 @@ class FixedSchedule(FairseqLRScheduler):
     def __init__(self, args, optimizer):
         super().__init__(args, optimizer)
-        self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-            self.optimizer.optimizer, self.anneal)
+        self.lr = args.lr[0]
+        if args.warmup_updates > 0:
+            self.warmup_factor = 1. / args.warmup_updates
+        else:
+            self.warmup_factor = 1

     @staticmethod
     def add_args(parser):
         """Add arguments to the parser for this LR scheduler."""
         parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
                             help='force annealing at specified epoch')
+        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
+                            help='warmup the learning rate linearly for the first N updates')

-    def anneal(self, epoch):
+    def get_next_lr(self, epoch):
         lrs = self.args.lr
         if self.args.force_anneal is None or epoch < self.args.force_anneal:
             # use fixed LR schedule
@@ -33,10 +37,18 @@ class FixedSchedule(FairseqLRScheduler):
         else:
             # annneal based on lr_shrink
             next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
-        return next_lr / lrs[0]  # correct for scaling from LambdaLR
+        return next_lr

     def step(self, epoch, val_loss=None):
         """Update the learning rate at the end of the given epoch."""
         super().step(epoch, val_loss)
-        self.lr_scheduler.step(epoch)
+        self.lr = self.get_next_lr(epoch)
+        self.optimizer.set_lr(self.warmup_factor * self.lr)
+        return self.optimizer.get_lr()
+
+    def step_update(self, num_updates):
+        """Update the learning rate after each update."""
+        if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
+            self.warmup_factor = num_updates / float(self.args.warmup_updates)
+            self.optimizer.set_lr(self.warmup_factor * self.lr)
         return self.optimizer.get_lr()
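The effect of the new --warmup-updates option is that the learning rate is scaled linearly from lr/N up to lr over the first N updates and then left at the epoch-level value. A tiny standalone sketch of that behaviour (illustrative numbers only):

    lr, warmup_updates = 0.5, 4
    for num_updates in range(1, 7):
        warmup_factor = min(num_updates / float(warmup_updates), 1.0)
        print(num_updates, warmup_factor * lr)   # 0.125, 0.25, 0.375, 0.5, 0.5, 0.5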
@@ -13,10 +13,11 @@ from fairseq.criterions import CRITERION_REGISTRY
 from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
 from fairseq.optim import OPTIMIZER_REGISTRY
 from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
+from fairseq.tasks import TASK_REGISTRY


-def get_training_parser():
-    parser = get_parser('Trainer')
+def get_training_parser(default_task='translation'):
+    parser = get_parser('Trainer', default_task)
     add_dataset_args(parser, train=True)
     add_distributed_training_args(parser)
     add_model_args(parser)
@@ -25,13 +26,42 @@ def get_training_parser():
     return parser


-def get_generation_parser():
-    parser = get_parser('Generation')
+def get_generation_parser(interactive=False, default_task='translation'):
+    parser = get_parser('Generation', default_task)
     add_dataset_args(parser, gen=True)
     add_generation_args(parser)
+    if interactive:
+        add_interactive_args(parser)
     return parser


+def get_eval_lm_parser(default_task='language_modeling'):
+    parser = get_parser('Evaluate Language Model', default_task)
+    add_dataset_args(parser, gen=True)
+    add_eval_lm_args(parser)
+    return parser
+def eval_str_list(x, type=float):
+    if x is None:
+        return None
+    if isinstance(x, str):
+        x = eval(x)
+    try:
+        return list(map(type, x))
+    except TypeError:
+        return [type(x)]
+
+
+def eval_bool(x, default=False):
+    if x is None:
+        return default
+    try:
+        return bool(eval(x))
+    except TypeError:
+        return default
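Example behaviour of the two helpers above: string-valued options such as --lr '0.25,0.05' or --update-freq '[4, 8]' are turned into typed lists, and scalars fall back to singleton lists:

    print(eval_str_list('0.25,0.05', type=float))   # [0.25, 0.05]
    print(eval_str_list('[4, 8]', type=int))        # [4, 8]
    print(eval_str_list('16', type=int))            # [16]
    print(eval_bool('True'))                        # True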
 def parse_args_and_arch(parser, input_args=None):
     # The parser doesn't know about model/criterion/optimizer-specific args, so
     # we parse twice. First we parse the model/criterion/optimizer, then we
@@ -40,34 +70,44 @@ def parse_args_and_arch(parser, input_args=None):
     args, _ = parser.parse_known_args(input_args)

     # Add model-specific args to parser.
-    model_specific_group = parser.add_argument_group(
-        'Model-specific configuration',
-        # Only include attributes which are explicitly given as command-line
-        # arguments or which have default values.
-        argument_default=argparse.SUPPRESS,
-    )
-    ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
+    if hasattr(args, 'arch'):
+        model_specific_group = parser.add_argument_group(
+            'Model-specific configuration',
+            # Only include attributes which are explicitly given as command-line
+            # arguments or which have default values.
+            argument_default=argparse.SUPPRESS,
+        )
+        ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)

     # Add *-specific args to parser.
-    CRITERION_REGISTRY[args.criterion].add_args(parser)
-    OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
-    LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
+    if hasattr(args, 'criterion'):
+        CRITERION_REGISTRY[args.criterion].add_args(parser)
+    if hasattr(args, 'optimizer'):
+        OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
+    if hasattr(args, 'lr_scheduler'):
+        LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
+    if hasattr(args, 'task'):
+        TASK_REGISTRY[args.task].add_args(parser)

     # Parse a second time.
     args = parser.parse_args(input_args)

     # Post-process args.
-    args.lr = list(map(float, args.lr.split(',')))
-    if args.max_sentences_valid is None:
+    if hasattr(args, 'lr'):
+        args.lr = eval_str_list(args.lr, type=float)
+    if hasattr(args, 'update_freq'):
+        args.update_freq = eval_str_list(args.update_freq, type=int)
+    if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
         args.max_sentences_valid = args.max_sentences

     # Apply architecture configuration.
-    ARCH_CONFIG_REGISTRY[args.arch](args)
+    if hasattr(args, 'arch'):
+        ARCH_CONFIG_REGISTRY[args.arch](args)

     return args
-def get_parser(desc):
+def get_parser(desc, default_task='translation'):
     parser = argparse.ArgumentParser(
         description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
     parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
@@ -77,24 +117,21 @@ def get_parser(desc):
                         choices=['json', 'none', 'simple', 'tqdm'])
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
+    # Task definitions can be found under fairseq/tasks/
+    parser.add_argument(
+        '--task', metavar='TASK', default=default_task, choices=TASK_REGISTRY.keys(),
+        help='task: {} (default: {})'.format(', '.join(TASK_REGISTRY.keys()), default_task)
+    )
     return parser
 def add_dataset_args(parser, train=False, gen=False):
     group = parser.add_argument_group('Dataset and data loading')
-    group.add_argument('data', metavar='DIR',
-                       help='path to data directory')
-    group.add_argument('-s', '--source-lang', default=None, metavar='SRC',
-                       help='source language')
-    group.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
-                       help='target language')
-    group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
-                       help='max number of tokens in the source sequence')
-    group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
-                       help='max number of tokens in the target sequence')
     group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
-                       help='Ignore too long or too short lines in valid and test set')
-    group.add_argument('--max-tokens', default=6000, type=int, metavar='N',
+                       help='ignore too long or too short lines in valid and test set')
+    group.add_argument('--max-tokens', type=int, metavar='N',
                        help='maximum number of tokens in a batch')
     group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                        help='maximum number of sentences in a batch')
@@ -104,7 +141,7 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='data subset to use for training (train, valid, test)')
     group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                        help='comma separated list of data subsets to use for validation'
-                            ' (train, valid, valid1,test, test1)')
+                            ' (train, valid, valid1, test, test1)')
     group.add_argument('--max-sentences-valid', type=int, metavar='N',
                        help='maximum number of sentences in a validation batch'
                             ' (defaults to --max-sentences)')
@@ -148,6 +185,10 @@ def add_optimization_args(parser):
     group.add_argument('--sentence-avg', action='store_true',
                        help='normalize gradients by the number of sentences in a batch'
                             ' (default is to normalize by number of tokens)')
+    group.add_argument('--update-freq', default='1', metavar='N',
+                       help='update parameters every N_i batches, when in epoch i')
+    group.add_argument('--fp16', action='store_true',
+                       help='use FP16 during training')

     # Optimizer definitions can be found under fairseq/optim/
     group.add_argument('--optimizer', default='nag', metavar='OPT',
@@ -170,12 +211,6 @@ def add_optimization_args(parser):
     group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                        help='minimum learning rate')
-    group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
-                       help='If bigger than 0, use that number of mini-batches for each epoch,'
-                            ' where each sample is drawn randomly without replacement from the'
-                            ' dataset')
-    group.add_argument('--curriculum', default=0, type=int, metavar='N',
-                       help='sort batches by source length for first N epochs')
     return group
@@ -185,10 +220,14 @@ def add_checkpoint_args(parser):
                        help='path to save checkpoints')
     group.add_argument('--restore-file', default='checkpoint_last.pt',
                        help='filename in save-dir from which to load checkpoint')
-    group.add_argument('--save-interval', type=int, default=-1, metavar='N',
-                       help='save a checkpoint every N updates')
+    group.add_argument('--save-interval', type=int, default=1, metavar='N',
+                       help='save a checkpoint every N epochs')
+    group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
+                       help='save a checkpoint (and validate) every N updates')
+    group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
+                       help='keep last N checkpoints saved with --save-interval-updates')
     group.add_argument('--no-save', action='store_true',
-                       help='don\'t save models and checkpoints')
+                       help='don\'t save models or checkpoints')
     group.add_argument('--no-epoch-checkpoints', action='store_true',
                        help='only store last and best checkpoints')
     group.add_argument('--validate-interval', type=int, default=1, metavar='N',
@@ -196,10 +235,24 @@ def add_checkpoint_args(parser):
     return group
+def add_common_eval_args(group):
+    group.add_argument('--path', metavar='FILE',
+                       help='path(s) to model file(s), colon separated')
+    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
+                       help='remove BPE tokens before scoring')
+    group.add_argument('--cpu', action='store_true', help='generate on CPU')
+    group.add_argument('--quiet', action='store_true',
+                       help='only print final scores')
+
+
+def add_eval_lm_args(parser):
+    group = parser.add_argument_group('LM Evaluation')
+    add_common_eval_args(group)
 def add_generation_args(parser):
     group = parser.add_argument_group('Generation')
-    group.add_argument('--path', metavar='FILE', action='append',
-                       help='path(s) to model file(s)')
+    add_common_eval_args(group)
     group.add_argument('--beam', default=5, type=int, metavar='N',
                        help='beam size')
     group.add_argument('--nbest', default=1, type=int, metavar='N',
@@ -210,15 +263,14 @@ def add_generation_args(parser):
     group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                        help=('generate sequences of maximum length ax + b, '
                              'where x is the source length'))
-    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
-                       help='remove BPE tokens before scoring')
+    group.add_argument('--min-len', default=1, type=float, metavar='N',
+                       help=('minimum generation length'))
     group.add_argument('--no-early-stop', action='store_true',
                        help=('continue searching even after finalizing k=beam '
                              'hypotheses; this is more correct, but increases '
                              'generation time by 50%%'))
     group.add_argument('--unnormalized', action='store_true',
                        help='compare unnormalized hypothesis scores')
-    group.add_argument('--cpu', action='store_true', help='generate on CPU')
     group.add_argument('--no-beamable-mm', action='store_true',
                        help='don\'t use BeamableMM in attention layers')
     group.add_argument('--lenpen', default=1, type=float,
@@ -227,17 +279,25 @@
                        help='unknown word penalty: <0 produces more unks, >0 produces fewer')
     group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                        help='perform unknown replacement (optionally with alignment dictionary)')
-    group.add_argument('--quiet', action='store_true',
-                       help='only print final scores')
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
-                       help=('initialize generation by target prefix of given length'))
+                       help='initialize generation by target prefix of given length')
     group.add_argument('--sampling', action='store_true',
                        help='sample hypotheses instead of using beam search')
+    group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
+                       help='sample from top K likely next words instead of all words')
+    group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
+                       help='temperature for random sampling')
     return group
+def add_interactive_args(parser):
+    group = parser.add_argument_group('Interactive')
+    group.add_argument('--buffer-size', default=0, type=int, metavar='N',
+                       help='read this many sentences into a buffer before processing them')


 def add_model_args(parser):
     group = parser.add_argument_group('Model configuration')
...
@@ -117,8 +117,9 @@ class json_progress_bar(progress_bar):

     def print(self, stats):
         """Print end-of-epoch stats."""
+        self.stats = stats
         stats = self._format_stats(self.stats, epoch=self.epoch)
-        print("sweep_log: " + json.dumps(stats), flush=True)
+        print(json.dumps(stats), flush=True)

     def _format_stats(self, stats, epoch=None, update=None):
         postfix = OrderedDict()
...
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.

 import math
 import torch

 from fairseq import utils
@@ -13,11 +14,12 @@ from fairseq.models import FairseqIncrementalDecoder


 class SequenceGenerator(object):
-    def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
-                 stop_early=True, normalize_scores=True, len_penalty=1,
-                 unk_penalty=0, retain_dropout=False, sampling=False):
+    def __init__(
+        self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
+        normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
+        sampling=False, sampling_topk=-1, sampling_temperature=1,
+    ):
         """Generates translations of a given source sentence.

         Args:
             min/maxlen: The length of the generated output will be bounded by
             minlen and maxlen (not including the end-of-sentence marker).
@@ -27,13 +29,10 @@ class SequenceGenerator(object):
             normalize_scores: Normalize scores by the length of the output.
         """
         self.models = models
-        self.pad = models[0].dst_dict.pad()
-        self.unk = models[0].dst_dict.unk()
-        self.eos = models[0].dst_dict.eos()
-        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
-        assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
-        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
-        self.vocab_size = len(models[0].dst_dict)
+        self.pad = tgt_dict.pad()
+        self.unk = tgt_dict.unk()
+        self.eos = tgt_dict.eos()
+        self.vocab_size = len(tgt_dict)
         self.beam_size = beam_size
         self.minlen = minlen
         max_decoder_len = min(m.max_decoder_positions() for m in self.models)
@@ -45,16 +44,19 @@
         self.unk_penalty = unk_penalty
         self.retain_dropout = retain_dropout
         self.sampling = sampling
+        self.sampling_topk = sampling_topk
+        self.sampling_temperature = sampling_temperature
     def cuda(self):
         for model in self.models:
             model.cuda()
         return self

-    def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
-                             cuda=False, timer=None, prefix_size=0):
+    def generate_batched_itr(
+        self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
+        cuda=False, timer=None, prefix_size=0,
+    ):
         """Iterate over a batched dataset and yield individual translations.

         Args:
             maxlen_a/b: generate sequences of maximum length ax + b,
             where x is the source sentence length.
@@ -65,12 +67,14 @@
             maxlen_b = self.maxlen

         for sample in data_itr:
-            s = utils.make_variable(sample, volatile=True, cuda=cuda)
+            s = utils.move_to_cuda(sample) if cuda else sample
+            if 'net_input' not in s:
+                continue
             input = s['net_input']
             srclen = input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
-            with utils.maybe_no_grad():
+            with torch.no_grad():
                 hypos = self.generate(
                     input['src_tokens'],
                     input['src_lengths'],
@@ -81,14 +85,14 @@
             if timer is not None:
                 timer.stop(sum(len(h[0]['tokens']) for h in hypos))
             for i, id in enumerate(s['id'].data):
-                src = input['src_tokens'].data[i, :]
-                # remove padding from ref
+                # remove padding
+                src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]

     def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
-        with utils.maybe_no_grad():
+        with torch.no_grad():
             return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

     def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
@@ -112,7 +116,7 @@
             # compute the encoder output for each beam
             encoder_out = model.encoder(
                 src_tokens.repeat(1, beam_size).view(-1, srclen),
-                src_lengths.repeat(beam_size),
+                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
             )
             encoder_outs.append(encoder_out)
@@ -135,11 +139,12 @@
         cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

         # offset arrays for converting between different indexing schemes
-        bbsz_offsets = (torch.arange(0, bsz)*beam_size).unsqueeze(1).type_as(tokens)
+        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
         cand_offsets = torch.arange(0, cand_size).type_as(tokens)

         # helper function for allocating buffers on the fly
         buffers = {}

         def buffer(name, type_of=tokens):  # noqa
             if name not in buffers:
                 buffers[name] = type_of.new()
@@ -159,7 +164,7 @@
             # finalized one
             best_unfinalized_score = unfinalized_scores[sent].max()
             if self.normalize_scores:
-                best_unfinalized_score /= maxlen
+                best_unfinalized_score /= maxlen ** self.len_penalty
             if worst_finalized[sent]['score'] >= best_unfinalized_score:
                 return True
             return False
@@ -168,11 +173,9 @@
             """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.
            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.
            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
@@ -186,7 +189,7 @@
            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
-           tokens_clone = tokens_clone[:, 1:step+2]  # skip the first index, which is EOS
+           tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
@@ -198,19 +201,34 @@
            # normalize sentence-level scores
            if self.normalize_scores:
-               eos_scores /= (step+1)**self.len_penalty
+               eos_scores /= (step + 1) ** self.len_penalty
+
+           cum_unfin = []
+           prev = 0
+           for f in finished:
+               if f:
+                   prev += 1
+               else:
+                   cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
-               sent = idx // beam_size
-               sents_seen.add(sent)
+               unfin_idx = idx // beam_size
+               sent = unfin_idx + cum_unfin[unfin_idx]
+               sents_seen.add((sent, unfin_idx))

                def get_hypo():
-                   _, alignment = attn_clone[i].max(dim=0)
+                   # remove padding tokens from attn scores
+                   nonpad_idxs = src_tokens[sent].ne(self.pad)
+                   hypo_attn = attn_clone[i][nonpad_idxs]
+                   _, alignment = hypo_attn.max(dim=0)
                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
-                       'attention': attn_clone[i],  # src_len x tgt_len
+                       'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': alignment,
                        'positional_scores': pos_scores[i],
                    }
@@ -230,26 +248,30 @@
                        'idx': idx,
                    }

-           # return number of hypotheses finished this step
-           num_finished = 0
-           for sent in sents_seen:
+           newly_finished = []
+           for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfinalized_scores):
                    finished[sent] = True
-                   num_finished += 1
-           return num_finished
+                   newly_finished.append(unfin_idx)
+           return newly_finished
        reorder_state = None
+       batch_idxs = None
        for step in range(maxlen + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
-               for model in self.models:
+               if batch_idxs is not None:
+                   # update beam indices to take into account removed sentences
+                   corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
+                   reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
+               for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
-                       model.decoder.reorder_incremental_state(
-                           incremental_states[model], reorder_state)
+                       model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
+                       encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            probs, avg_attn_scores = self._decode(
-               tokens[:, :step+1], encoder_outs, incremental_states)
+               tokens[:, :step + 1], encoder_outs, incremental_states)

            if step == 0:
                # at the first step all hypotheses are equally likely, so use
                # only the first beam
@@ -258,13 +280,13 @@
                scores_buf = scores_buf.type_as(probs)
            elif not self.sampling:
                # make probs contain cumulative scores for each hypothesis
-               probs.add_(scores[:, step-1].view(-1, 1))
+               probs.add_(scores[:, step - 1].view(-1, 1))

            probs[:, self.pad] = -math.inf  # never select pad
            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
-           attn[:, :, step+1].copy_(avg_attn_scores)
+           attn[:, :, step + 1].copy_(avg_attn_scores)

            cand_scores = buffer('cand_scores', type_of=scores)
            cand_indices = buffer('cand_indices')
@@ -282,15 +304,29 @@
                cand_beams.resize_as_(cand_indices).fill_(0)
            elif self.sampling:
                assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
-               exp_probs = probs.exp_().view(-1, self.vocab_size)
-               if step == 0:
-                   # we exclude the first two vocab items, one of which is pad
-                   torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
-               else:
-                   torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
-               cand_indices.add_(2)
+               if self.sampling_topk > 0:
+                   values, indices = probs[:, 2:].topk(self.sampling_topk)
+                   exp_probs = values.div_(self.sampling_temperature).exp()
+                   if step == 0:
+                       torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
+                   else:
+                       torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
+                   torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
+                   torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
+                   cand_indices.add_(2)
+               else:
+                   exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)
+                   if step == 0:
+                       # we exclude the first two vocab items, one of which is pad
+                       torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
+                   else:
+                       torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
+                   cand_indices.add_(2)
                torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
                cand_scores.log_()
                cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
                cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
@@ -301,7 +337,7 @@
                # make scores cumulative
                cand_scores.add_(
                    torch.gather(
-                       scores[:, step-1].view(bsz, beam_size), dim=1,
+                       scores[:, step - 1].view(bsz, beam_size), dim=1,
                        index=cand_beams,
                    )
                )
@@ -323,18 +359,20 @@
                    descending=True,
                    out=(eos_scores, eos_bbsz_idx),
                )
-               num_remaining_sent -= finalize_hypos(
-                   step, eos_bbsz_idx, eos_scores)
+               num_remaining_sent -= len(finalize_hypos(
+                   step, eos_bbsz_idx, eos_scores))
                assert num_remaining_sent == 0
                break

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
-           cand_bbsz_idx = cand_beams.add_(bbsz_offsets)
+           cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            eos_mask = cand_indices.eq(self.eos)
+           finalized_sents = set()
            if step >= self.minlen:
                # only consider eos when it's among the top beam_size indices
                torch.masked_select(
@@ -348,20 +386,49 @@
                    mask=eos_mask[:, :beam_size],
                    out=eos_scores,
                )
-               num_remaining_sent -= finalize_hypos(
-                   step, eos_bbsz_idx, eos_scores, cand_scores)
+               finalized_sents = finalize_hypos(
+                   step, eos_bbsz_idx, eos_scores, cand_scores)
+               num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < maxlen

+           if len(finalized_sents) > 0:
+               new_bsz = bsz - len(finalized_sents)
+
+               # construct batch_idxs which holds indices of batches to keep for the next pass
+               batch_mask = torch.ones(bsz).type_as(cand_indices)
+               batch_mask[cand_indices.new(finalized_sents)] = 0
+               batch_idxs = batch_mask.nonzero().squeeze(-1)
+
+               eos_mask = eos_mask[batch_idxs]
+               cand_beams = cand_beams[batch_idxs]
+               bbsz_offsets.resize_(new_bsz, 1)
+               cand_bbsz_idx = cand_beams.add(bbsz_offsets)
+
+               cand_scores = cand_scores[batch_idxs]
+               cand_indices = cand_indices[batch_idxs]
+               if prefix_tokens is not None:
+                   prefix_tokens = prefix_tokens[batch_idxs]
+
+               scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+               scores_buf.resize_as_(scores)
+               tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
+               tokens_buf.resize_as_(tokens)
+               attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
+               attn_buf.resize_as_(attn)
+               bsz = new_bsz
+           else:
+               batch_idxs = None
            # set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos
            active_mask = buffer('active_mask')
            torch.add(
-               eos_mask.type_as(cand_offsets)*cand_size,
+               eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )
@@ -382,17 +449,18 @@
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
-               tokens[:, :step+1], dim=0, index=active_bbsz_idx,
-               out=tokens_buf[:, :step+1],
+               tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
+               out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
-               out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1],
+               out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
@@ -406,51 +474,37 @@
            # copy attention for active hypotheses
            torch.index_select(
-               attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
-               out=attn_buf[:, :, :step+2],
+               attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
+               out=attn_buf[:, :, :step + 2],
            )

            # swap buffers
-           old_tokens = tokens
-           tokens = tokens_buf
-           tokens_buf = old_tokens
-           old_scores = scores
-           scores = scores_buf
-           scores_buf = old_scores
-           old_attn = attn
-           attn = attn_buf
-           attn_buf = old_attn
+           tokens, tokens_buf = tokens_buf, tokens
+           scores, scores_buf = scores_buf, scores
+           attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
-       for sent in range(bsz):
+       for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)

        return finalized
    def _decode(self, tokens, encoder_outs, incremental_states):
-       # wrap in Variable
-       tokens = utils.volatile_variable(tokens)
+       if len(self.models) == 1:
+           return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
-           with utils.maybe_no_grad():
-               if incremental_states[model] is not None:
-                   decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
-               else:
-                   decoder_out = list(model.decoder(tokens, encoder_out))
-               decoder_out[0] = decoder_out[0][:, -1, :]
-               attn = decoder_out[1]
-           probs = model.get_normalized_probs(decoder_out, log_probs=False).data
+           probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
-               attn = attn[:, -1, :].data
                if avg_attn is None:
                    avg_attn = attn
                else:
@@ -459,5 +513,17 @@
        avg_probs.log_()
        if avg_attn is not None:
            avg_attn.div_(len(self.models))
        return avg_probs, avg_attn
+   def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
+       with torch.no_grad():
+           if incremental_states[model] is not None:
+               decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
+           else:
+               decoder_out = list(model.decoder(tokens, encoder_out))
+           decoder_out[0] = decoder_out[0][:, -1, :]
+           attn = decoder_out[1]
+       if attn is not None:
+           attn = attn[:, -1, :]
+       probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
+       return probs, attn
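The --sampling-topk / --sampling-temperature logic above restricts sampling to the K most likely words and sharpens or flattens the distribution before drawing. A standalone sketch of the same idea (not fairseq code; names are illustrative):

    import torch

    logprobs = torch.log_softmax(torch.randn(2, 10), dim=-1)   # (batch, vocab) log-probabilities
    values, indices = logprobs.topk(5, dim=-1)                  # keep only the 5 most likely words
    weights = (values / 0.8).exp()                              # temperature < 1 sharpens the distribution
    choice = torch.multinomial(weights, num_samples=1)          # sample within the top-k set
    next_tokens = indices.gather(1, choice)                     # map back to vocabulary indices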
@@ -5,16 +5,17 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import torch
+
 from fairseq import utils


 class SequenceScorer(object):
     """Scores the target for a given source sentence."""

-    def __init__(self, models):
+    def __init__(self, models, tgt_dict):
         self.models = models
-        self.pad = models[0].dst_dict.pad()
-        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
+        self.pad = tgt_dict.pad()

     def cuda(self):
         for model in self.models:
...@@ -24,21 +25,22 @@ class SequenceScorer(object): ...@@ -24,21 +25,22 @@ class SequenceScorer(object):
def score_batched_itr(self, data_itr, cuda=False, timer=None): def score_batched_itr(self, data_itr, cuda=False, timer=None):
"""Iterate over a batched dataset and yield scored translations.""" """Iterate over a batched dataset and yield scored translations."""
for sample in data_itr: for sample in data_itr:
s = utils.make_variable(sample, volatile=True, cuda=cuda) s = utils.move_to_cuda(sample) if cuda else sample
if timer is not None: if timer is not None:
timer.start() timer.start()
pos_scores, attn = self.score(s) pos_scores, attn = self.score(s)
if timer is not None:
timer.stop(s['ntokens'])
for i, id in enumerate(s['id'].data): for i, id in enumerate(s['id'].data):
src = s['net_input']['src_tokens'].data[i, :]
# remove padding from ref # remove padding from ref
ref = utils.strip_pad(s['target'].data[i, :], self.pad) src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
tgt_len = ref.numel() tgt_len = ref.numel()
pos_scores_i = pos_scores[i][:tgt_len] pos_scores_i = pos_scores[i][:tgt_len]
score_i = pos_scores_i.sum() / tgt_len score_i = pos_scores_i.sum() / tgt_len
attn_i = attn[i] if attn is not None:
_, alignment = attn_i.max(dim=0) attn_i = attn[i]
_, alignment = attn_i.max(dim=0)
else:
attn_i = alignment = None
hypos = [{ hypos = [{
'tokens': ref, 'tokens': ref,
'score': score_i, 'score': score_i,
...@@ -46,6 +48,8 @@ class SequenceScorer(object): ...@@ -46,6 +48,8 @@ class SequenceScorer(object):
'alignment': alignment, 'alignment': alignment,
'positional_scores': pos_scores_i, 'positional_scores': pos_scores_i,
}] }]
if timer is not None:
timer.stop(s['ntokens'])
# return results in the same format as SequenceGenerator # return results in the same format as SequenceGenerator
yield id, src, ref, hypos yield id, src, ref, hypos
...@@ -57,18 +61,12 @@ class SequenceScorer(object): ...@@ -57,18 +61,12 @@ class SequenceScorer(object):
avg_probs = None avg_probs = None
avg_attn = None avg_attn = None
for model in self.models: for model in self.models:
with utils.maybe_no_grad(): with torch.no_grad():
model.eval() model.eval()
encoder_out = model.encoder( decoder_out = model.forward(**net_input)
net_input['src_tokens'],
net_input['src_lengths'],
)
decoder_out = model.decoder(
net_input['prev_output_tokens'],
encoder_out,
)
attn = decoder_out[1] attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=False).data
probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data
if avg_probs is None: if avg_probs is None:
avg_probs = probs avg_probs = probs
else: else:
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_task import FairseqTask
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
def setup_task(args):
return TASK_REGISTRY[args.task].setup_task(args)
def register_task(name):
"""Decorator to register a new task."""
def register_task_cls(cls):
if name in TASK_REGISTRY:
raise ValueError('Cannot register duplicate task ({})'.format(name))
if not issubclass(cls, FairseqTask):
raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
if cls.__name__ in TASK_CLASS_NAMES:
raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
TASK_REGISTRY[name] = cls
TASK_CLASS_NAMES.add(cls.__name__)
return cls
return register_task_cls
# automatically import any Python files in the tasks/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.tasks.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import criterions, models
from fairseq.data import FairseqDataset
class FairseqTask(object):
"""
A Task defines the data format, stores shared state (e.g., dictionaries) and
provides helpers for building the model/criterion and calculating the loss.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
pass
def __init__(self, args):
self.args = args
self.datasets = {}
@classmethod
def setup_task(cls, args, **kwargs):
raise NotImplementedError
def load_dataset(self, split):
raise NotImplementedError
def dataset(self, split):
"""Return a dataset split."""
if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split)
if not isinstance(self.datasets[split], FairseqDataset):
raise TypeError('Datasets are expected to be of type FairseqDataset')
return self.datasets[split]
def build_model(self, args):
return models.build_model(args, self)
def build_criterion(self, args):
return criterions.build_criterion(args, self)
def get_loss(self, model, criterion, sample):
return criterion(model, sample)
@property
def source_dictionary(self):
raise NotImplementedError
@property
def target_dictionary(self):
raise NotImplementedError
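Taken together, the registry in fairseq/tasks/__init__.py and the FairseqTask base class above let new tasks be added declaratively: subclass FairseqTask, decorate it with @register_task, and the task becomes selectable via --task. A minimal sketch of a hypothetical task (the name and class below are illustrative, not shipped with fairseq):

from fairseq.tasks import FairseqTask, register_task

@register_task('dummy_copy')  # hypothetical task name
class DummyCopyTask(FairseqTask):

    @classmethod
    def setup_task(cls, args, **kwargs):
        # build dictionaries or other shared state here, then return the task
        return cls(args)

    def load_dataset(self, split):
        # populate self.datasets[split] with a FairseqDataset
        raise NotImplementedError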
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq.data import (
Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset,
)
from . import FairseqTask, register_task
@register_task('language_modeling')
class LanguageModelingTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('--sample-break-mode', metavar='VAL',
choices=['none', 'complete', 'eos'],
help='If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
'of sentence, but may include multiple sentences per sample. '
'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=1024, type=int, metavar='N',
help='max number of tokens per sample for LM dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
@classmethod
def setup_task(cls, args, **kwargs):
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split):
"""Load a dataset split."""
path = os.path.join(self.args.data, split)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = ds.tokens_list
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path)
tokens = ds.buffer
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
dataset = TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True, # return next tokens as targets
)
self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
@property
def target_dictionary(self):
return self.dictionary
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq import options
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
IndexedRawTextDataset,
)
from . import FairseqTask, register_task
@register_task('translation')
class TranslationTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left (default: True)')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left (default: False)')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
self.src_dict = src_dict
self.tgt_dict = tgt_dict
@classmethod
def setup_task(cls, args, **kwargs):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
if args.source_lang is None or args.target_lang is None:
raise Exception('Could not infer language pair, please provide it explicitly')
# load dictionaries
src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split):
"""Load a dataset split."""
def split_exists(src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
return True
return False
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path):
return IndexedInMemoryDataset(path)
return None
src_dataset = indexed_dataset(prefix + src, self.src_dict)
tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict)
self.datasets[split] = LanguagePairDataset(
src_dataset, src_dataset.sizes, self.src_dict,
tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
)
@property
def source_dictionary(self):
return self.src_dict
@property
def target_dictionary(self):
return self.tgt_dict
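load_dataset above probes both orderings of the language pair when looking for binarized data, so either {split}.{src}-{tgt}.* or {split}.{tgt}-{src}.* files are accepted. A small sketch of the two prefixes it tries, with a hypothetical data directory and language pair for illustration:

import os

def candidate_prefixes(data_dir, split, src, tgt):
    # mirrors the two orderings tried by TranslationTask.load_dataset
    return [
        os.path.join(data_dir, '{}.{}-{}.'.format(split, src, tgt)),
        os.path.join(data_dir, '{}.{}-{}.'.format(split, tgt, src)),
    ]

print(candidate_prefixes('data-bin/iwslt14.de-en', 'train', 'de', 'en'))
# ['data-bin/iwslt14.de-en/train.de-en.', 'data-bin/iwslt14.de-en/train.en-de.']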
...@@ -10,8 +10,6 @@ import re ...@@ -10,8 +10,6 @@ import re
import torch import torch
from fairseq import dictionary
SPACE_NORMALIZER = re.compile("\s+") SPACE_NORMALIZER = re.compile("\s+")
...@@ -24,13 +22,6 @@ def tokenize_line(line): ...@@ -24,13 +22,6 @@ def tokenize_line(line):
class Tokenizer: class Tokenizer:
@staticmethod
def build_dictionary(filename, tokenize=tokenize_line):
dict = dictionary.Dictionary()
Tokenizer.add_file_to_dictionary(filename, dict, tokenize)
dict.finalize()
return dict
@staticmethod @staticmethod
def add_file_to_dictionary(filename, dict, tokenize): def add_file_to_dictionary(filename, dict, tokenize):
with open(filename, 'r') as f: with open(filename, 'r') as f:
......
...@@ -6,11 +6,13 @@ ...@@ -6,11 +6,13 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
""" """
Train a network on multiple GPUs. Train a network across multiple GPUs.
""" """
from collections import OrderedDict from collections import defaultdict, OrderedDict
import math import contextlib
from itertools import chain
import torch import torch
from fairseq import distributed_utils, optim, utils from fairseq import distributed_utils, optim, utils
...@@ -19,14 +21,14 @@ from fairseq.optim import lr_scheduler ...@@ -19,14 +21,14 @@ from fairseq.optim import lr_scheduler
class Trainer(object): class Trainer(object):
"""Main class for multi-GPU training. """Main class for data parallel training.
Each GPU has a full copy of the model and is assigned to its own Python This class supports data parallel training, where multiple workers each
process. Gradients are accumulated with torch.distributed.all_reduce and all have a full model replica and gradients are accumulated synchronously via
model replicas are updated synchronously after each batch. torch.distributed.all_reduce.
""" """
def __init__(self, args, model, criterion): def __init__(self, args, task, model, criterion):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported') raise NotImplementedError('Training on CPU is not supported')
...@@ -34,12 +36,11 @@ class Trainer(object): ...@@ -34,12 +36,11 @@ class Trainer(object):
self.args = args self.args = args
# copy model and criterion to current device # copy model and criterion to current device
self.task = task
self.model = model.cuda() self.model = model.cuda()
self.criterion = criterion.cuda() self.criterion = criterion.cuda()
# initialize optimizer and LR scheduler self.optimizer = None
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
# initialize meters # initialize meters
self.meters = OrderedDict() self.meters = OrderedDict()
...@@ -54,25 +55,34 @@ class Trainer(object): ...@@ -54,25 +55,34 @@ class Trainer(object):
self.meters['gnorm'] = AverageMeter() # gradient norm self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory self.meters['oom'] = AverageMeter() # out of memory
self.meters['wall'] = TimeMeter() # wall time in seconds
self._max_bsz_seen = 0 self._buffered_stats = defaultdict(lambda: [])
self._flat_grads = None
self._num_updates = 0 self._num_updates = 0
self._optim_history = None
def _build_optimizer(self):
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
def save_checkpoint(self, filename, extra_state): def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file.""" """Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint if distributed_utils.is_master(self.args): # only save one checkpoint
utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer, extra_state['train_meters'] = self.meters
self.lr_scheduler, self._num_updates, self._optim_history, extra_state) utils.save_state(
filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename): def load_checkpoint(self, filename):
"""Load all training state from a checkpoint file.""" """Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = utils.load_model_state( extra_state, self._optim_history, last_optim_state = \
filename, self.model, cuda_device=torch.cuda.current_device()) utils.load_model_state(filename, self.model)
if last_optim_state is not None: if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed # rebuild optimizer after loading model, since params may have changed
self.optimizer = optim.build_optimizer(self.args, self.model.parameters()) self._build_optimizer()
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
# only reload optimizer and lr_scheduler if they match # only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1] last_optim = self._optim_history[-1]
...@@ -83,42 +93,89 @@ class Trainer(object): ...@@ -83,42 +93,89 @@ class Trainer(object):
self._num_updates = last_optim['num_updates'] self._num_updates = last_optim['num_updates']
if 'train_meters' in extra_state:
self.meters = extra_state['train_meters']
del extra_state['train_meters']
return extra_state return extra_state
def train_step(self, sample): def train_step(self, sample, update_params=True):
"""Do forward, backward and parameter update.""" """Do forward, backward and parameter update."""
sample = self._prepare_sample(sample, volatile=False) if self.optimizer is None:
# initialize optimizer and LR scheduler if hasn't been loaded from the checkpoint
# forward pass self._build_optimizer()
loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample)
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# forward and backward pass
sample = self._prepare_sample(sample)
loss, sample_size, logging_output, oom_fwd = self._forward(sample)
oom_bwd = self._backward(loss)
# buffer stats and logging outputs
self._buffered_stats['sample_sizes'].append(sample_size)
self._buffered_stats['logging_outputs'].append(logging_output)
self._buffered_stats['ooms_fwd'].append(oom_fwd)
self._buffered_stats['ooms_bwd'].append(oom_bwd)
# update parameters
if update_params:
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
lambda l: list(chain.from_iterable(l)),
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
))
)
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
# aggregate stats and logging outputs try:
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) # all-reduce and rescale gradients, then take an optimization step
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) grad_norm = self._all_reduce_and_rescale(grad_denom)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes) self._opt()
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
# update meters
# backward pass, all-reduce gradients and take an optimization step self.meters['wps'].update(ntokens)
grad_norm, ooms_bwd = self._backward_and_opt(loss, grad_denom) self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
# update meters self.meters['bsz'].update(nsentences)
self.meters['wps'].update(ntokens) if grad_norm is not None:
self.meters['ups'].update(1.) self.meters['gnorm'].update(grad_norm)
self.meters['wpb'].update(ntokens) self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['bsz'].update(nsentences) self.meters['oom'].update(ooms_fwd + ooms_bwd)
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.) # update loss meters for training
self.meters['oom'].update(ooms_fwd + ooms_bwd) if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# update loss meters for training # criterions can optionally log the NLL loss too
if 'loss' in agg_logging_output: if 'nll_loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom) self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
# criterions can optionally log the NLL loss too except OverflowError as e:
if 'nll_loss' in agg_logging_output: self.zero_grad()
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens) print('| WARNING: overflow detected, ' + str(e))
return agg_logging_output self.clear_buffered_stats()
return agg_logging_output
else:
return None # buffering updates
def _forward(self, sample, eval=False): def _forward(self, sample, eval=False):
# prepare model and optimizer # prepare model and optimizer
...@@ -126,8 +183,6 @@ class Trainer(object): ...@@ -126,8 +183,6 @@ class Trainer(object):
self.model.eval() self.model.eval()
else: else:
self.model.train() self.model.train()
self.optimizer.zero_grad()
loss = None loss = None
sample_size = 0 sample_size = 0
logging_output = { logging_output = {
...@@ -137,33 +192,20 @@ class Trainer(object): ...@@ -137,33 +192,20 @@ class Trainer(object):
oom = 0 oom = 0
if sample is not None: if sample is not None:
try: try:
with utils.maybe_no_grad(eval): with torch.no_grad() if eval else contextlib.ExitStack():
# calculate loss and sample size # calculate loss and sample size
loss, sample_size, logging_output_ = self.criterion(self.model, sample) loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
logging_output.update(logging_output_) logging_output.update(logging_output_)
except RuntimeError as e: except RuntimeError as e:
if not eval and 'out of memory' in str(e): if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch') print('| WARNING: ran out of memory, skipping batch')
oom = 1 oom = 1
loss = None loss = None
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
else: else:
raise e raise e
return loss, sample_size, logging_output, oom
# synchronize logging outputs for multi-GPU training def _backward(self, loss):
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms = zip(*list(
distributed_utils.all_gather_list((sample_size, logging_output, oom))))
ooms = sum(ooms)
else:
sample_sizes = [sample_size]
logging_outputs = [logging_output]
ooms = oom
return loss, sample_sizes, logging_outputs, ooms
def _backward_and_opt(self, loss, grad_denom):
oom = 0 oom = 0
if loss is not None: if loss is not None:
try: try:
...@@ -173,44 +215,81 @@ class Trainer(object): ...@@ -173,44 +215,81 @@ class Trainer(object):
if 'out of memory' in str(e): if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch') print('| WARNING: ran out of memory, skipping batch')
oom = 1 oom = 1
if hasattr(torch.cuda, 'empty_cache'): self.zero_grad()
torch.cuda.empty_cache()
self.optimizer.zero_grad()
else: else:
raise e raise e
return oom
# all-reduce grads and rescale by grad_denom def _all_reduce_and_rescale(self, grad_denom):
# flatten grads into a single buffer and all-reduce
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
if self.args.distributed_world_size > 1: if self.args.distributed_world_size > 1:
grads = [p.grad.data for p in self.model.parameters() if p.requires_grad] torch.distributed.all_reduce(flat_grads)
distributed_utils.all_reduce_and_rescale_tensors(grads, grad_denom)
else: # rescale and clip gradients
for p in self.model.parameters(): flat_grads.div_(grad_denom)
if p.requires_grad: grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)
p.grad.data.div_(grad_denom)
# copy grads back into model parameters
# clip grads self._set_flat_grads(flat_grads)
if self.args.clip_norm > 0:
grad_norm = utils.item(torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)) return grad_norm
else:
grad_norm = math.sqrt(sum(p.grad.data.norm()**2 for p in self.model.parameters())) def _get_grads(self):
grads = []
for name, p in self.model.named_parameters():
if not p.requires_grad:
continue
if p.grad is None:
raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
'Use the param in the forward pass or set requires_grad=False')
grads.append(p.grad.data)
return grads
def _get_flat_grads(self, out=None):
grads = self._get_grads()
if out is None:
grads_size = sum(g.numel() for g in grads)
out = grads[0].new(grads_size).zero_()
offset = 0
for g in grads:
numel = g.numel()
out[offset:offset+numel].copy_(g.view(-1))
offset += numel
return out[:offset]
def _set_flat_grads(self, new_grads):
grads = self._get_grads()
offset = 0
for g in grads:
numel = g.numel()
g.copy_(new_grads[offset:offset+numel].view_as(g))
offset += numel
def _opt(self):
# take an optimization step # take an optimization step
self.optimizer.step() self.optimizer.step()
self.zero_grad()
self._num_updates += 1 self._num_updates += 1
# update learning rate # update learning rate
self.lr_scheduler.step_update(self._num_updates) self.lr_scheduler.step_update(self._num_updates)
return grad_norm, oom
def valid_step(self, sample): def valid_step(self, sample):
"""Do forward pass in evaluation mode.""" """Do forward pass in evaluation mode."""
sample = self._prepare_sample(sample, volatile=True)
# forward pass # forward pass
loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample, eval=True) sample = self._prepare_sample(sample)
assert not ooms_fwd, 'Ran out of memory during validation' _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
assert not oom_fwd, 'Ran out of memory during validation'
# gather logging outputs from all GPUs
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
(sample_size, logging_output)
))
else:
sample_sizes = [sample_size]
logging_outputs = [logging_output]
# aggregate stats and logging outputs # aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
...@@ -226,10 +305,26 @@ class Trainer(object): ...@@ -226,10 +305,26 @@ class Trainer(object):
return agg_logging_output return agg_logging_output
def dummy_train_step(self, dummy_batch):
"""Dummy training step for warming caching allocator."""
self.train_step(dummy_batch, update_params=False)
self.zero_grad()
self.clear_buffered_stats()
def zero_grad(self):
self.optimizer.zero_grad()
def clear_buffered_stats(self):
self._buffered_stats.clear()
def lr_step(self, epoch, val_loss=None): def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss.""" """Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss) return self.lr_scheduler.step(epoch, val_loss)
def lr_step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.lr_scheduler.step_update(num_updates)
def get_lr(self): def get_lr(self):
"""Get the current learning rate.""" """Get the current learning rate."""
return self.optimizer.get_lr() return self.optimizer.get_lr()
...@@ -248,12 +343,7 @@ class Trainer(object): ...@@ -248,12 +343,7 @@ class Trainer(object):
"""Get the number of parameters updates.""" """Get the number of parameters updates."""
return self._num_updates return self._num_updates
def _prepare_sample(self, sample, volatile): def _prepare_sample(self, sample):
if sample is None or len(sample) == 0: if sample is None or len(sample) == 0:
return None return None
if hasattr(torch.cuda, 'empty_cache'): return utils.move_to_cuda(sample)
# clear the caching allocator if this is the largest sample we've seen
if sample['target'].size(0) > self._max_bsz_seen:
self._max_bsz_seen = sample['target'].size(0)
torch.cuda.empty_cache()
return utils.make_variable(sample, volatile=volatile, cuda=True)
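The new train_step signature above is what enables delayed updates (--update-freq): calling it with update_params=False only buffers gradients and logging outputs, and the all-reduce plus optimizer step happen on the final call of the group. A minimal sketch of the calling pattern, assuming an already-built trainer and an iterator of prepared samples (the loop itself is illustrative, not the actual train.py code):

update_freq = 4  # e.g. --update-freq 4 to simulate a 4x larger mini-batch
for i, sample in enumerate(train_samples):
    is_last = (i + 1) % update_freq == 0
    # the first N-1 calls only accumulate; the last call reduces and steps
    log_output = trainer.train_step(sample, update_params=is_last)
    if is_last and log_output is not None:
        print(log_output)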
...@@ -5,14 +5,13 @@ ...@@ -5,14 +5,13 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from collections import defaultdict from collections import defaultdict, OrderedDict
import contextlib
import logging import logging
import os import os
import re
import torch import torch
import traceback import traceback
from torch.autograd import Variable
from torch.serialization import default_restore_location from torch.serialization import default_restore_location
...@@ -25,6 +24,20 @@ def torch_persistent_save(*args, **kwargs): ...@@ -25,6 +24,20 @@ def torch_persistent_save(*args, **kwargs):
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
for k, v in state_dict.items():
cpu_dict[k] = convert_state_dict_type(v)
return cpu_dict
elif isinstance(state_dict, list):
return [convert_state_dict_type(v) for v in state_dict]
elif torch.is_tensor(state_dict):
return state_dict.type(ttype)
else:
return state_dict
def save_state(filename, args, model, criterion, optimizer, lr_scheduler, def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None): num_updates, optim_history=None, extra_state=None):
if optim_history is None: if optim_history is None:
...@@ -33,7 +46,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, ...@@ -33,7 +46,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
extra_state = {} extra_state = {}
state_dict = { state_dict = {
'args': args, 'args': args,
'model': model.state_dict(), 'model': convert_state_dict_type(model.state_dict()),
'optimizer_history': optim_history + [ 'optimizer_history': optim_history + [
{ {
'criterion_name': criterion.__class__.__name__, 'criterion_name': criterion.__class__.__name__,
...@@ -42,28 +55,22 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, ...@@ -42,28 +55,22 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
'num_updates': num_updates, 'num_updates': num_updates,
} }
], ],
'last_optimizer_state': optimizer.state_dict(), 'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
'extra_state': extra_state, 'extra_state': extra_state,
} }
torch_persistent_save(state_dict, filename) torch_persistent_save(state_dict, filename)
def load_model_state(filename, model, cuda_device=None): def load_model_state(filename, model):
if not os.path.exists(filename): if not os.path.exists(filename):
return None, [], None return None, [], None
if cuda_device is None: state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = torch.load(filename)
else:
state = torch.load(
filename,
map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
)
state = _upgrade_state_dict(state) state = _upgrade_state_dict(state)
state['model'] = model.upgrade_state_dict(state['model']) model.upgrade_state_dict(state['model'])
# load model parameters # load model parameters
try: try:
model.load_state_dict(state['model']) model.load_state_dict(state['model'], strict=True)
except Exception: except Exception:
raise Exception('Cannot load model parameters from checkpoint, ' raise Exception('Cannot load model parameters from checkpoint, '
'please ensure that the architectures match') 'please ensure that the architectures match')
...@@ -111,43 +118,44 @@ def _upgrade_state_dict(state): ...@@ -111,43 +118,44 @@ def _upgrade_state_dict(state):
# keep track of number of updates # keep track of number of updates
if 'num_updates' not in state['optimizer_history'][-1]: if 'num_updates' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['num_updates'] = 0 state['optimizer_history'][-1]['num_updates'] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
state['args'].max_source_positions = state['args'].max_positions
state['args'].max_target_positions = state['args'].max_positions
# use stateful training data iterator
if 'train_iterator' not in state['extra_state']:
state['extra_state']['train_iterator'] = {
'epoch': state['extra_state']['epoch'],
'iterations_in_epoch': 0,
}
return state return state
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
data_dir=None, model_arg_overrides=None):
"""Load an ensemble of models for inference. """Load an ensemble of models for inference.
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
model_arg_overrides allows you to pass a dictionary model_arg_overrides -- model_arg_overrides allows you to pass a dictionary model_arg_overrides --
{'arg_name': arg} -- to override model args that were used during model {'arg_name': arg} -- to override model args that were used during model
training training
""" """
from fairseq import data, models
# load model architectures and weights # load model architectures and weights
states = [] states = []
for filename in filenames: for filename in filenames:
if not os.path.exists(filename): if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename)) raise IOError('Model file not found: {}'.format(filename))
states.append( state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu')) state = _upgrade_state_dict(state)
) states.append(state)
args = states[0]['args'] args = states[0]['args']
if model_arg_overrides is not None: if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides) args = _override_model_args(args, model_arg_overrides)
if src_dict is None or dst_dict is None:
assert data_dir is not None
src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
# build ensemble # build ensemble
ensemble = [] ensemble = []
for state in states: for state in states:
model = models.build_model(args, src_dict, dst_dict) model = task.build_model(args)
model.load_state_dict(state['model']) model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
ensemble.append(model) ensemble.append(model)
return ensemble, args return ensemble, args
...@@ -159,46 +167,24 @@ def _override_model_args(args, model_arg_overrides): ...@@ -159,46 +167,24 @@ def _override_model_args(args, model_arg_overrides):
return args return args
def maybe_no_grad(condition=True): def move_to_cuda(sample):
if hasattr(torch, 'no_grad') and condition:
return torch.no_grad()
# no-op context manager
return contextlib.ExitStack()
def volatile_variable(*args, **kwargs):
if hasattr(torch, 'no_grad'):
# volatile has been deprecated, use the no_grad context manager instead
return Variable(*args, **kwargs)
else:
return Variable(*args, **kwargs, volatile=True)
def make_variable(sample, volatile=False, cuda=False):
"""Wrap input tensors in Variable class."""
if len(sample) == 0: if len(sample) == 0:
return {} return {}
def _make_variable(maybe_tensor): def _move_to_cuda(maybe_tensor):
if torch.is_tensor(maybe_tensor): if torch.is_tensor(maybe_tensor):
if cuda and torch.cuda.is_available(): return maybe_tensor.cuda()
maybe_tensor = maybe_tensor.cuda()
if volatile:
return volatile_variable(maybe_tensor)
else:
return Variable(maybe_tensor)
elif isinstance(maybe_tensor, dict): elif isinstance(maybe_tensor, dict):
return { return {
key: _make_variable(value) key: _move_to_cuda(value)
for key, value in maybe_tensor.items() for key, value in maybe_tensor.items()
} }
elif isinstance(maybe_tensor, list): elif isinstance(maybe_tensor, list):
return [_make_variable(x) for x in maybe_tensor] return [_move_to_cuda(x) for x in maybe_tensor]
else: else:
return maybe_tensor return maybe_tensor
return _make_variable(sample) return _move_to_cuda(sample)
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
...@@ -268,7 +254,7 @@ def parse_embedding(embed_path): ...@@ -268,7 +254,7 @@ def parse_embedding(embed_path):
""" """
embed_dict = {} embed_dict = {}
with open(embed_path) as f_embed: with open(embed_path) as f_embed:
_ = next(f_embed) # skip header next(f_embed) # skip header
for line in f_embed: for line in f_embed:
pieces = line.strip().split() pieces = line.strip().split()
embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]]) embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
...@@ -297,15 +283,15 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk): ...@@ -297,15 +283,15 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
return ' '.join(hypo_tokens) return ' '.join(hypo_tokens)
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe): def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
from fairseq import tokenizer from fairseq import tokenizer
hypo_str = dst_dict.string(hypo_tokens, remove_bpe) hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
if align_dict is not None: if align_dict is not None:
hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string()) hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
if align_dict is not None or remove_bpe is not None: if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE # Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method. # Note that the dictionary can be modified inside the method.
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True) hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
return hypo_tokens, hypo_str, alignment return hypo_tokens, hypo_str, alignment
...@@ -342,12 +328,7 @@ def buffered_arange(max): ...@@ -342,12 +328,7 @@ def buffered_arange(max):
return buffered_arange.buf[:max] return buffered_arange.buf[:max]
def convert_padding_direction( def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
src_tokens,
padding_idx,
right_to_left=False,
left_to_right=False,
):
assert right_to_left ^ left_to_right assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx) pad_mask = src_tokens.eq(padding_idx)
if not pad_mask.any(): if not pad_mask.any():
...@@ -375,3 +356,35 @@ def item(tensor): ...@@ -375,3 +356,35 @@ def item(tensor):
if hasattr(tensor, '__getitem__'): if hasattr(tensor, '__getitem__'):
return tensor[0] return tensor[0]
return tensor return tensor
def clip_grad_norm_(tensor, max_norm):
grad_norm = item(torch.norm(tensor))
if grad_norm > max_norm > 0:
clip_coef = max_norm / (grad_norm + 1e-6)
tensor.mul_(clip_coef)
return grad_norm
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = os.listdir(path)
entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = int(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
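Because checkpoint_paths returns paths sorted by their numeric suffix in descending order, pruning old checkpoints reduces to keeping a prefix of the list. A small usage sketch (the keep-last-3 policy and directory name are illustrative):

import os
from fairseq import utils

paths = utils.checkpoint_paths('checkpoints', pattern=r'checkpoint(\d+)\.pt')
for old in paths[3:]:  # keep the three most recent epoch checkpoints
    os.remove(old)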
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
import torch import torch
from fairseq import bleu, data, options, progress_bar, tokenizer, utils from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer from fairseq.sequence_scorer import SequenceScorer
...@@ -16,76 +16,67 @@ from fairseq.sequence_scorer import SequenceScorer ...@@ -16,76 +16,67 @@ from fairseq.sequence_scorer import SequenceScorer
def main(args): def main(args):
assert args.path is not None, '--path required for generation!' assert args.path is not None, '--path required for generation!'
print(args)
assert not args.sampling or args.nbest == args.beam, \ assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam' '--sampling requires --nbest to be equal to --beam'
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'
if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset # Load dataset splits
if args.replace_unk is None: task = tasks.setup_task(args)
dataset = data.load_dataset( task.load_dataset(args.gen_subset)
args.data, print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
[args.gen_subset],
args.source_lang,
args.target_lang,
)
else:
dataset = data.load_raw_text_dataset(
args.data,
[args.gen_subset],
args.source_lang,
args.target_lang,
)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args
args.source_lang, args.target_lang = dataset.src, dataset.dst
# Load ensemble # Set dictionaries
print('| loading model(s) from {}'.format(', '.join(args.path))) src_dict = task.source_dictionary
models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict) tgt_dict = task.target_dictionary
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict))) # Load ensemble
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict))) print('| loading model(s) from {}'.format(args.path))
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset]))) models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)
# Optimize ensemble for generation # Optimize ensemble for generation
for model in models: for model in models:
model.make_generation_fast_( model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
)
# Load alignment dictionary for unknown word replacement # Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk) align_dict = utils.load_align_dict(args.replace_unk)
# Load dataset (possibly sharded) # Load dataset (possibly sharded)
max_positions = min(model.max_encoder_positions() for model in models) itr = data.EpochBatchIterator(
itr = dataset.eval_dataloader( dataset=task.dataset(args.gen_subset),
args.gen_subset, max_tokens=args.max_tokens,
max_sentences=args.max_sentences, max_sentences=args.max_sentences,
max_positions=max_positions, max_positions=models[0].max_positions(),
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test, ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
) required_batch_size_multiple=8,
if args.num_shards > 1: num_shards=args.num_shards,
if args.shard_id < 0 or args.shard_id >= args.num_shards: shard_id=args.shard_id,
raise ValueError('--shard-id must be between 0 and num_shards') ).next_epoch_itr(shuffle=False)
itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
# Initialize generator # Initialize generator
gen_timer = StopwatchMeter() gen_timer = StopwatchMeter()
if args.score_reference: if args.score_reference:
translator = SequenceScorer(models) translator = SequenceScorer(models, task.target_dictionary)
else: else:
translator = SequenceGenerator( translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop), models, task.target_dictionary, beam_size=args.beam,
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen, stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
unk_penalty=args.unkpen, sampling=args.sampling) len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
)
if use_cuda: if use_cuda:
translator.cuda() translator.cuda()
# Generate and compute BLEU score # Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk()) scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
num_sentences = 0 num_sentences = 0
has_target = True has_target = True
with progress_bar.build_progress_bar(args, itr) as t: with progress_bar.build_progress_bar(args, itr) as t:
...@@ -94,21 +85,23 @@ def main(args): ...@@ -94,21 +85,23 @@ def main(args):
else: else:
translations = translator.generate_batched_itr( translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b, t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size) cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
)
wps_meter = TimeMeter() wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations: for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth # Process input and ground truth
has_target = target_tokens is not None has_target = target_tokens is not None
target_tokens = target_tokens.int().cpu() if has_target else None target_tokens = target_tokens.int().cpu() if has_target else None
# Either retrieve the original sentences or regenerate them from tokens. # Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None: if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id) src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id) target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else: else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe) src_str = src_dict.string(src_tokens, args.remove_bpe)
target_str = dataset.dst_dict.string(target_tokens, if has_target:
args.remove_bpe, target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
escape_unk=True) if has_target else ''
if not args.quiet: if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str)) print('S-{}\t{}'.format(sample_id, src_str))
...@@ -122,7 +115,7 @@ def main(args): ...@@ -122,7 +115,7 @@ def main(args):
src_str=src_str, src_str=src_str,
alignment=hypo['alignment'].int().cpu(), alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict, align_dict=align_dict,
dst_dict=dataset.dst_dict, tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe, remove_bpe=args.remove_bpe,
) )
...@@ -145,20 +138,20 @@ def main(args): ...@@ -145,20 +138,20 @@ def main(args):
if align_dict is not None or args.remove_bpe is not None: if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE # Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize( target_tokens = tokenizer.Tokenizer.tokenize(
target_str, dataset.dst_dict, add_if_not_exist=True) target_str, tgt_dict, add_if_not_exist=True)
scorer.add(target_tokens, hypo_tokens) scorer.add(target_tokens, hypo_tokens)
wps_meter.update(src_tokens.size(0)) wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)}) t.log({'wps': round(wps_meter.avg)})
num_sentences += 1 num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format( print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg)) num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
if has_target: if has_target:
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string())) print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_generation_parser() parser = options.get_generation_parser()
args = parser.parse_args() args = options.parse_args_and_arch(parser)
main(args) main(args)
...@@ -6,30 +6,79 @@ ...@@ -6,30 +6,79 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
from collections import namedtuple
import numpy as np
import sys import sys
import torch import torch
from torch.autograd import Variable from torch.autograd import Variable
from fairseq import options, tokenizer, utils from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos alignments')
def buffered_read(buffer_size):
buffer = []
for src_str in sys.stdin:
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
buffer = []
if len(buffer) > 0:
yield buffer
def make_batches(lines, args, src_dict, max_positions):
tokens = [
tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
for src_str in lines
]
lengths = np.array([t.numel() for t in tokens])
itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
srcs=[lines[i] for i in batch['id']],
tokens=batch['net_input']['src_tokens'],
lengths=batch['net_input']['src_lengths'],
), batch['id']
def main(args): def main(args):
print(args) if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.max_sentences is None:
args.max_sentences = 1
assert not args.sampling or args.nbest == args.beam, \ assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam' '--sampling requires --nbest to be equal to --beam'
assert not args.max_sentences, \ assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
'--max-sentences/--batch-size is not supported in interactive mode' '--max-sentences/--batch-size cannot be larger than --buffer-size'
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not args.cpu
# Setup task, e.g., translation
task = tasks.setup_task(args)
# Load ensemble # Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path))) print('| loading model(s) from {}'.format(args.path))
models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data) model_paths = args.path.split(':')
src_dict, dst_dict = models[0].src_dict, models[0].dst_dict models, model_args = utils.load_ensemble_for_inference(model_paths, task)
print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict))) # Set dictionaries
print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict))) src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation # Optimize ensemble for generation
for model in models: for model in models:
...@@ -39,9 +88,12 @@ def main(args): ...@@ -39,9 +88,12 @@ def main(args):
# Initialize generator # Initialize generator
translator = SequenceGenerator( translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop), models, tgt_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen, normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen) unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
minlen=args.min_len,
)
if use_cuda: if use_cuda:
translator.cuda() translator.cuda()
...@@ -49,19 +101,12 @@ def main(args): ...@@ -49,19 +101,12 @@ def main(args):
# (None if no unknown word replacement, empty if no path to align dictionary) # (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk) align_dict = utils.load_align_dict(args.replace_unk)
print('| Type the input sentence and press return:') def make_result(src_str, hypos):
for src_str in sys.stdin: result = Translation(
src_str = src_str.strip() src_str='O\t{}'.format(src_str),
src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long() hypos=[],
if use_cuda: alignments=[],
src_tokens = src_tokens.cuda()
src_lengths = src_tokens.new([src_tokens.numel()])
translations = translator.generate(
Variable(src_tokens.view(1, -1)),
Variable(src_lengths.view(-1)),
) )
hypos = translations[0]
print('O\t{}'.format(src_str))
# Process top predictions # Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]: for hypo in hypos[:min(len(hypos), args.nbest)]:
...@@ -70,14 +115,48 @@ def main(args): ...@@ -70,14 +115,48 @@ def main(args):
src_str=src_str, src_str=src_str,
alignment=hypo['alignment'].int().cpu(), alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict, align_dict=align_dict,
dst_dict=dst_dict, tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe, remove_bpe=args.remove_bpe,
) )
print('H\t{}\t{}'.format(hypo['score'], hypo_str)) result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))) result.alignments.append('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
return result
def process_batch(batch):
tokens = batch.tokens
lengths = batch.lengths
if use_cuda:
tokens = tokens.cuda()
lengths = lengths.cuda()
translations = translator.generate(
Variable(tokens),
Variable(lengths),
maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
)
return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
if args.buffer_size > 1:
print('| Sentence buffer size:', args.buffer_size)
print('| Type the input sentence and press return:')
for inputs in buffered_read(args.buffer_size):
indices = []
results = []
for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
indices.extend(batch_indices)
results += process_batch(batch)
for i in np.argsort(indices):
result = results[i]
print(result.src_str)
for hypo, align in zip(result.hypos, result.alignments):
print(hypo)
print(align)
if __name__ == '__main__': if __name__ == '__main__':
parser = options.get_generation_parser() parser = options.get_generation_parser(interactive=True)
args = parser.parse_args() args = options.parse_args_and_arch(parser)
main(args) main(args)
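The buffered_read helper added above groups stdin lines into chunks of --buffer-size so that interactive generation can run in batches. A minimal sketch of the same chunking over a plain list instead of stdin (the helper name below is illustrative, not part of fairseq):

def chunk(lines, buffer_size):
    buffer = []
    for line in lines:
        buffer.append(line.strip())
        if len(buffer) >= buffer_size:
            yield buffer
            buffer = []
    if buffer:
        yield buffer

print(list(chunk(['a', 'b', 'c', 'd', 'e'], 2)))  # [['a', 'b'], ['c', 'd'], ['e']]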
...@@ -13,7 +13,7 @@ import torch ...@@ -13,7 +13,7 @@ import torch
from fairseq import distributed_utils, options from fairseq import distributed_utils, options
from singleprocess_train import main as single_process_main from train import main as single_process_main
def main(args): def main(args):
......