Unverified Commit 388c520b authored by Myle Ott, committed by GitHub

0.4.0 -> 0.5.0

Changelog:
- 97b58b46: add Transformer model from Vaswani et al. (2017)
- b2374e52: faster Transformer inference with improved caching
- 2d27ae08: simulate large mini-batch training with delayed updates (`--update-freq`)
- 7ee1d284: add FP16 training support (`--fp16`)
- 2a84f46b: faster inference by removing completed sentences from the batch
- 663fd806: batched interactive generation
- 4c2ef2de: add language modeling / gated convolutional model from Dauphin et al. (2017)
- b59815bc: add Hierarchical Neural Story Generation model from Fan et al. (2018)
- ff68a9ef: add FairseqTask to modularize task definitions (e.g., translation, language modeling)
parents ec0031df 5383b5db
......@@ -5,7 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.autograd import Variable
import torch.nn as nn
from fairseq import utils
......@@ -29,7 +28,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
else:
positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
return super().forward(Variable(positions))
return super().forward(positions)
def max_positions(self):
"""Maximum number of supported positions."""
......
......@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn.functional as F
from fairseq import utils
......@@ -59,8 +60,8 @@ class LinearizedConvolution(ConvTBC):
input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
# append next input
input_buffer[:, -1, :] = input[:, -1, :]
input = utils.volatile_variable(input_buffer)
with utils.maybe_no_grad():
input = input_buffer
with torch.no_grad():
output = F.linear(input.view(bsz, -1), weight, self.bias)
return output.view(bsz, 1, -1)
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from fairseq import utils
class MultiheadAttention(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim
self.scaling = self.head_dim**-0.5
self._mask = None
self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
if bias:
self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.in_proj_bias is not None:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj.bias, 0.)
def forward(self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
kv_same = key.data_ptr() == value.data_ptr()
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if 'prev_key' in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert kv_same and not qkv_same
key = value = None
else:
saved_state = None
if qkv_same:
# self-attention
q, k, v = self.in_proj_qkv(query)
elif kv_same:
# encoder-decoder attention
q = self.in_proj_q(query)
if key is None:
assert value is None
# this will allow us to concat it with the previous value and
# end up with just the previous value
k = v = q.new(0)
else:
k, v = self.in_proj_kv(key)
else:
q = self.in_proj_q(query)
k = self.in_proj_k(key)
v = self.in_proj_v(value)
q *= self.scaling
if saved_state is not None:
if 'prev_key' in saved_state:
k = torch.cat((saved_state['prev_key'], k), dim=0)
if 'prev_value' in saved_state:
v = torch.cat((saved_state['prev_value'], v), dim=0)
saved_state['prev_key'] = k
saved_state['prev_value'] = v
self._set_input_buffer(incremental_state, saved_state)
src_len = k.size(0)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
# only apply masking at training time (when incremental state is None)
if mask_future_timesteps and incremental_state is None:
assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention'
attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.float().masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
).type_as(attn_weights) # FP16 support: cast to float and back
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn = torch.bmm(attn_weights, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
if need_weights:
# average attention weights over heads
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.sum(dim=1) / self.num_heads
else:
attn_weights = None
return attn, attn_weights
def in_proj_qkv(self, query):
return self._in_proj(query).chunk(3, dim=-1)
def in_proj_kv(self, key):
return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
def in_proj_q(self, query):
return self._in_proj(query, end=self.embed_dim)
def in_proj_k(self, key):
return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim)
def in_proj_v(self, value):
return self._in_proj(value, start=2*self.embed_dim)
def _in_proj(self, input, start=None, end=None):
weight = self.in_proj_weight
bias = self.in_proj_bias
if end is not None:
weight = weight[:end, :]
if bias is not None:
bias = bias[:end]
if start is not None:
weight = weight[start:, :]
if bias is not None:
bias = bias[start:]
return F.linear(input, weight, bias)
def buffered_mask(self, tensor):
dim = tensor.size(-1)
if self._mask is None:
self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._mask.size(0) < dim:
self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
return self._mask[:dim, :dim]
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder buffered internal state (for incremental generation)."""
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer[k] = input_buffer[k].index_select(1, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(
self,
incremental_state,
'attn_state',
) or {}
def _set_input_buffer(self, incremental_state, buffer):
utils.set_incremental_state(
self,
incremental_state,
'attn_state',
buffer,
)
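For orientation, a minimal usage sketch of the MultiheadAttention module defined above (tensor sizes are illustrative; passing the same tensor as query, key and value exercises the self-attention path):

import torch

# instantiate the module defined above; 512/8 are example sizes only
attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1)

# inputs follow the documented Time x Batch x Channel layout
x = torch.rand(10, 2, 512)  # tgt_len=10, bsz=2, embed_dim=512

# identical query/key/value triggers the in_proj_qkv branch;
# mask_future_timesteps adds the causal (upper-triangular) mask
out, weights = attn(x, x, x, mask_future_timesteps=True)

assert out.size() == (10, 2, 512)     # same shape as the query
assert weights.size() == (2, 10, 10)  # bsz x tgt_len x src_len, averaged over heads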
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
class ScalarBias(torch.autograd.Function):
"""
Adds a vector of scalars, used in the self-attention mechanism to allow
the model to optionally attend to this vector instead of the past
"""
@staticmethod
def forward(ctx, input, dim, bias_init):
size = list(input.size())
size[dim] += 1
output = input.new(*size).fill_(bias_init)
output.narrow(dim, 1, size[dim] - 1).copy_(input)
ctx.dim = dim
return output
@staticmethod
def backward(ctx, grad):
return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None
def scalar_bias(input, dim, bias_init=0):
return ScalarBias.apply(input, dim, bias_init)
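As a quick, illustrative check of what the helper above does (values are examples): it prepends a constant slot of value bias_init along the chosen dimension, giving the model something to attend to other than the past.

import torch

scores = torch.rand(2, 5)                       # bsz x src_len
biased = scalar_bias(scores, dim=1, bias_init=0)
assert biased.size() == (2, 6)                  # one extra "bias" column
assert bool((biased[:, 0] == 0).all())          # the new column holds bias_init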
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
from fairseq import utils
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored, but it is necessary to specify whether padding
is added on the left side (left_pad=True) or right side (left_pad=False).
"""
def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.left_pad = left_pad
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
)
self.register_buffer('_float_tensor', torch.FloatTensor())
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, incremental_state=None):
"""Input is expected to be of size [bsz x seqlen]."""
# recompute/expand embeddings if needed
bsz, seq_len = input.size()
max_pos = self.padding_idx + 1 + seq_len
if max_pos > self.weights.size(0):
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
).type_as(self.weights)
self.weights = self.weights.type_as(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)
def max_positions(self):
"""Maximum number of supported positions."""
return int(1e5) # an arbitrary large number
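A small usage sketch of the module above (shapes and the padding index are illustrative; like the module itself, it relies on fairseq's utils.make_positions):

import torch

pos_emb = SinusoidalPositionalEmbedding(embedding_dim=16, padding_idx=1, left_pad=False)
tokens = torch.tensor([[5, 6, 7, 1],
                       [5, 6, 1, 1]])   # bsz x seqlen, 1 marks padding
out = pos_emb(tokens)
assert out.size() == (2, 4, 16)         # one embedding per position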
......@@ -5,8 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim.lr_scheduler
from . import FairseqLRScheduler, register_lr_scheduler
......@@ -16,16 +14,22 @@ class FixedSchedule(FairseqLRScheduler):
def __init__(self, args, optimizer):
super().__init__(args, optimizer)
self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
self.optimizer.optimizer, self.anneal)
self.lr = args.lr[0]
if args.warmup_updates > 0:
self.warmup_factor = 1. / args.warmup_updates
else:
self.warmup_factor = 1
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
def anneal(self, epoch):
def get_next_lr(self, epoch):
lrs = self.args.lr
if self.args.force_anneal is None or epoch < self.args.force_anneal:
# use fixed LR schedule
......@@ -33,10 +37,18 @@ class FixedSchedule(FairseqLRScheduler):
else:
# anneal based on lr_shrink
next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
return next_lr / lrs[0] # correct for scaling from LambdaLR
return next_lr
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
self.lr_scheduler.step(epoch)
self.lr = self.get_next_lr(epoch)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
self.warmup_factor = num_updates / float(self.args.warmup_updates)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
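# A worked example of the warmup behaviour above (numbers are illustrative):
# with --lr 0.25 --warmup-updates 4, step_update() ramps the learning rate
# 0.0625 -> 0.125 -> 0.1875 -> 0.25 over the first four updates, after which
# it stays at lr until --force-anneal triggers the per-epoch lr_shrink decay.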
......@@ -13,10 +13,11 @@ from fairseq.criterions import CRITERION_REGISTRY
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY
from fairseq.tasks import TASK_REGISTRY
def get_training_parser():
parser = get_parser('Trainer')
def get_training_parser(default_task='translation'):
parser = get_parser('Trainer', default_task)
add_dataset_args(parser, train=True)
add_distributed_training_args(parser)
add_model_args(parser)
......@@ -25,13 +26,42 @@ def get_training_parser():
return parser
def get_generation_parser():
parser = get_parser('Generation')
def get_generation_parser(interactive=False, default_task='translation'):
parser = get_parser('Generation', default_task)
add_dataset_args(parser, gen=True)
add_generation_args(parser)
if interactive:
add_interactive_args(parser)
return parser
def get_eval_lm_parser(default_task='language_modeling'):
parser = get_parser('Evaluate Language Model', default_task)
add_dataset_args(parser, gen=True)
add_eval_lm_args(parser)
return parser
def eval_str_list(x, type=float):
if x is None:
return None
if isinstance(x, str):
x = eval(x)
try:
return list(map(type, x))
except TypeError:
return [type(x)]
def eval_bool(x, default=False):
if x is None:
return default
try:
return bool(eval(x))
except TypeError:
return default
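# Illustrative behaviour of the two helpers above (values are examples):
#   eval_str_list('[0.25, 0.1]', type=float)  -> [0.25, 0.1]
#   eval_str_list(0.25, type=float)           -> [0.25]
#   eval_bool('True')                         -> True
#   eval_bool(None, default=False)            -> False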
def parse_args_and_arch(parser, input_args=None):
# The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we
......@@ -40,34 +70,44 @@ def parse_args_and_arch(parser, input_args=None):
args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser.
model_specific_group = parser.add_argument_group(
'Model-specific configuration',
# Only include attributes which are explicitly given as command-line
# arguments or which have default values.
argument_default=argparse.SUPPRESS,
)
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
if hasattr(args, 'arch'):
model_specific_group = parser.add_argument_group(
'Model-specific configuration',
# Only include attributes which are explicitly given as command-line
# arguments or which have default values.
argument_default=argparse.SUPPRESS,
)
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
# Add *-specific args to parser.
CRITERION_REGISTRY[args.criterion].add_args(parser)
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
if hasattr(args, 'criterion'):
CRITERION_REGISTRY[args.criterion].add_args(parser)
if hasattr(args, 'optimizer'):
OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
if hasattr(args, 'lr_scheduler'):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
if hasattr(args, 'task'):
TASK_REGISTRY[args.task].add_args(parser)
# Parse a second time.
args = parser.parse_args(input_args)
# Post-process args.
args.lr = list(map(float, args.lr.split(',')))
if args.max_sentences_valid is None:
if hasattr(args, 'lr'):
args.lr = eval_str_list(args.lr, type=float)
if hasattr(args, 'update_freq'):
args.update_freq = eval_str_list(args.update_freq, type=int)
if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
args.max_sentences_valid = args.max_sentences
# Apply architecture configuration.
ARCH_CONFIG_REGISTRY[args.arch](args)
if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args)
return args
def get_parser(desc):
def get_parser(desc, default_task='translation'):
parser = argparse.ArgumentParser(
description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
......@@ -77,24 +117,21 @@ def get_parser(desc):
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
# Task definitions can be found under fairseq/tasks/
parser.add_argument(
'--task', metavar='TASK', default=default_task, choices=TASK_REGISTRY.keys(),
help='task: {} (default: {})'.format(', '.join(TASK_REGISTRY.keys()), default_task)
)
return parser
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
group.add_argument('data', metavar='DIR',
help='path to data directory')
group.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
group.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='Ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', default=6000, type=int, metavar='N',
help='ignore too long or too short lines in valid and test set')
group.add_argument('--max-tokens', type=int, metavar='N',
help='maximum number of tokens in a batch')
group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
help='maximum number of sentences in a batch')
......@@ -104,7 +141,7 @@ def add_dataset_args(parser, train=False, gen=False):
help='data subset to use for training (train, valid, test)')
group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
help='comma separated list of data subsets to use for validation'
' (train, valid, valid1,test, test1)')
' (train, valid, valid1, test, test1)')
group.add_argument('--max-sentences-valid', type=int, metavar='N',
help='maximum number of sentences in a validation batch'
' (defaults to --max-sentences)')
......@@ -148,6 +185,10 @@ def add_optimization_args(parser):
group.add_argument('--sentence-avg', action='store_true',
help='normalize gradients by the number of sentences in a batch'
' (default is to normalize by number of tokens)')
group.add_argument('--update-freq', default='1', metavar='N',
help='update parameters every N_i batches, when in epoch i')
group.add_argument('--fp16', action='store_true',
help='use FP16 during training')
# Optimizer definitions can be found under fairseq/optim/
group.add_argument('--optimizer', default='nag', metavar='OPT',
......@@ -170,12 +211,6 @@ def add_optimization_args(parser):
group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
help='minimum learning rate')
group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
help='If bigger than 0, use that number of mini-batches for each epoch,'
' where each sample is drawn randomly without replacement from the'
' dataset')
group.add_argument('--curriculum', default=0, type=int, metavar='N',
help='sort batches by source length for first N epochs')
return group
......@@ -185,10 +220,14 @@ def add_checkpoint_args(parser):
help='path to save checkpoints')
group.add_argument('--restore-file', default='checkpoint_last.pt',
help='filename in save-dir from which to load checkpoint')
group.add_argument('--save-interval', type=int, default=-1, metavar='N',
help='save a checkpoint every N updates')
group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs')
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
help='save a checkpoint (and validate) every N updates')
group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
help='keep last N checkpoints saved with --save-interval-updates')
group.add_argument('--no-save', action='store_true',
help='don\'t save models and checkpoints')
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
help='only store last and best checkpoints')
group.add_argument('--validate-interval', type=int, default=1, metavar='N',
......@@ -196,10 +235,24 @@ def add_checkpoint_args(parser):
return group
def add_common_eval_args(group):
group.add_argument('--path', metavar='FILE',
help='path(s) to model file(s), colon separated')
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
def add_eval_lm_args(parser):
group = parser.add_argument_group('LM Evaluation')
add_common_eval_args(group)
def add_generation_args(parser):
group = parser.add_argument_group('Generation')
group.add_argument('--path', metavar='FILE', action='append',
help='path(s) to model file(s)')
add_common_eval_args(group)
group.add_argument('--beam', default=5, type=int, metavar='N',
help='beam size')
group.add_argument('--nbest', default=1, type=int, metavar='N',
......@@ -210,15 +263,14 @@ def add_generation_args(parser):
group.add_argument('--max-len-b', default=200, type=int, metavar='N',
help=('generate sequences of maximum length ax + b, '
'where x is the source length'))
group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
help='remove BPE tokens before scoring')
group.add_argument('--min-len', default=1, type=float, metavar='N',
help=('minimum generation length'))
group.add_argument('--no-early-stop', action='store_true',
help=('continue searching even after finalizing k=beam '
'hypotheses; this is more correct, but increases '
'generation time by 50%%'))
group.add_argument('--unnormalized', action='store_true',
help='compare unnormalized hypothesis scores')
group.add_argument('--cpu', action='store_true', help='generate on CPU')
group.add_argument('--no-beamable-mm', action='store_true',
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
......@@ -227,17 +279,25 @@ def add_generation_args(parser):
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
help='perform unknown replacement (optionally with alignment dictionary)')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
group.add_argument('--score-reference', action='store_true',
help='just score the reference translation')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help=('initialize generation by target prefix of given length'))
help='initialize generation by target prefix of given length')
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
help='temperature for random sampling')
return group
def add_interactive_args(parser):
group = parser.add_argument_group('Interactive')
group.add_argument('--buffer-size', default=0, type=int, metavar='N',
help='read this many sentences into a buffer before processing them')
def add_model_args(parser):
group = parser.add_argument_group('Model configuration')
......
......@@ -117,8 +117,9 @@ class json_progress_bar(progress_bar):
def print(self, stats):
"""Print end-of-epoch stats."""
self.stats = stats
stats = self._format_stats(self.stats, epoch=self.epoch)
print("sweep_log: " + json.dumps(stats), flush=True)
print(json.dumps(stats), flush=True)
def _format_stats(self, stats, epoch=None, update=None):
postfix = OrderedDict()
......
......@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import math
import torch
from fairseq import utils
......@@ -13,11 +14,12 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
stop_early=True, normalize_scores=True, len_penalty=1,
unk_penalty=0, retain_dropout=False, sampling=False):
def __init__(
self, models, tgt_dict, beam_size=1, minlen=1, maxlen=None, stop_early=True,
normalize_scores=True, len_penalty=1, unk_penalty=0, retain_dropout=False,
sampling=False, sampling_topk=-1, sampling_temperature=1,
):
"""Generates translations of a given source sentence.
Args:
min/maxlen: The length of the generated output will be bounded by
minlen and maxlen (not including the end-of-sentence marker).
......@@ -27,13 +29,10 @@ class SequenceGenerator(object):
normalize_scores: Normalize scores by the length of the output.
"""
self.models = models
self.pad = models[0].dst_dict.pad()
self.unk = models[0].dst_dict.unk()
self.eos = models[0].dst_dict.eos()
assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
self.vocab_size = len(models[0].dst_dict)
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos()
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
self.minlen = minlen
max_decoder_len = min(m.max_decoder_positions() for m in self.models)
......@@ -45,16 +44,19 @@ class SequenceGenerator(object):
self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout
self.sampling = sampling
self.sampling_topk = sampling_topk
self.sampling_temperature = sampling_temperature
def cuda(self):
for model in self.models:
model.cuda()
return self
def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda=False, timer=None, prefix_size=0):
def generate_batched_itr(
self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda=False, timer=None, prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
where x is the source sentence length.
......@@ -65,12 +67,14 @@ class SequenceGenerator(object):
maxlen_b = self.maxlen
for sample in data_itr:
s = utils.make_variable(sample, volatile=True, cuda=cuda)
s = utils.move_to_cuda(sample) if cuda else sample
if 'net_input' not in s:
continue
input = s['net_input']
srclen = input['src_tokens'].size(1)
if timer is not None:
timer.start()
with utils.maybe_no_grad():
with torch.no_grad():
hypos = self.generate(
input['src_tokens'],
input['src_lengths'],
......@@ -81,14 +85,14 @@ class SequenceGenerator(object):
if timer is not None:
timer.stop(sum(len(h[0]['tokens']) for h in hypos))
for i, id in enumerate(s['id'].data):
src = input['src_tokens'].data[i, :]
# remove padding from ref
# remove padding
src = utils.strip_pad(input['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
yield id, src, ref, hypos[i]
def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
"""Generate a batch of translations."""
with utils.maybe_no_grad():
with torch.no_grad():
return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
......@@ -112,7 +116,7 @@ class SequenceGenerator(object):
# compute the encoder output for each beam
encoder_out = model.encoder(
src_tokens.repeat(1, beam_size).view(-1, srclen),
src_lengths.repeat(beam_size),
src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
)
encoder_outs.append(encoder_out)
......@@ -135,11 +139,12 @@ class SequenceGenerator(object):
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (torch.arange(0, bsz)*beam_size).unsqueeze(1).type_as(tokens)
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
# helper function for allocating buffers on the fly
buffers = {}
def buffer(name, type_of=tokens): # noqa
if name not in buffers:
buffers[name] = type_of.new()
......@@ -159,7 +164,7 @@ class SequenceGenerator(object):
# finalized one
best_unfinalized_score = unfinalized_scores[sent].max()
if self.normalize_scores:
best_unfinalized_score /= maxlen
best_unfinalized_score /= maxlen ** self.len_penalty
if worst_finalized[sent]['score'] >= best_unfinalized_score:
return True
return False
......@@ -168,11 +173,9 @@ class SequenceGenerator(object):
"""
Finalize the given hypotheses at this step, while keeping the total
number of finalized hypotheses per sentence <= beam_size.
Note: the input must be in the desired finalization order, so that
hypotheses that appear earlier in the input are preferred to those
that appear later.
Args:
step: current time step
bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
......@@ -186,7 +189,7 @@ class SequenceGenerator(object):
# clone relevant token and attention tensors
tokens_clone = tokens.index_select(0, bbsz_idx)
tokens_clone = tokens_clone[:, 1:step+2] # skip the first index, which is EOS
tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2]
......@@ -198,19 +201,34 @@ class SequenceGenerator(object):
# normalize sentence-level scores
if self.normalize_scores:
eos_scores /= (step+1)**self.len_penalty
eos_scores /= (step + 1) ** self.len_penalty
cum_unfin = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
sents_seen = set()
for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
sent = idx // beam_size
sents_seen.add(sent)
unfin_idx = idx // beam_size
sent = unfin_idx + cum_unfin[unfin_idx]
sents_seen.add((sent, unfin_idx))
def get_hypo():
_, alignment = attn_clone[i].max(dim=0)
# remove padding tokens from attn scores
nonpad_idxs = src_tokens[sent].ne(self.pad)
hypo_attn = attn_clone[i][nonpad_idxs]
_, alignment = hypo_attn.max(dim=0)
return {
'tokens': tokens_clone[i],
'score': score,
'attention': attn_clone[i], # src_len x tgt_len
'attention': hypo_attn, # src_len x tgt_len
'alignment': alignment,
'positional_scores': pos_scores[i],
}
......@@ -230,26 +248,30 @@ class SequenceGenerator(object):
'idx': idx,
}
# return number of hypotheses finished this step
num_finished = 0
for sent in sents_seen:
newly_finished = []
for sent, unfin_idx in sents_seen:
# check termination conditions for this sentence
if not finished[sent] and is_finished(sent, step, unfinalized_scores):
finished[sent] = True
num_finished += 1
return num_finished
newly_finished.append(unfin_idx)
return newly_finished
reorder_state = None
batch_idxs = None
for step in range(maxlen + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
for model in self.models:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
for i, model in enumerate(self.models):
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.reorder_incremental_state(
incremental_states[model], reorder_state)
model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)
probs, avg_attn_scores = self._decode(
tokens[:, :step+1], encoder_outs, incremental_states)
tokens[:, :step + 1], encoder_outs, incremental_states)
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
......@@ -258,13 +280,13 @@ class SequenceGenerator(object):
scores_buf = scores_buf.type_as(probs)
elif not self.sampling:
# make probs contain cumulative scores for each hypothesis
probs.add_(scores[:, step-1].view(-1, 1))
probs.add_(scores[:, step - 1].view(-1, 1))
probs[:, self.pad] = -math.inf # never select pad
probs[:, self.unk] -= self.unk_penalty # apply unk penalty
# Record attention scores
attn[:, :, step+1].copy_(avg_attn_scores)
attn[:, :, step + 1].copy_(avg_attn_scores)
cand_scores = buffer('cand_scores', type_of=scores)
cand_indices = buffer('cand_indices')
......@@ -282,15 +304,29 @@ class SequenceGenerator(object):
cand_beams.resize_as_(cand_indices).fill_(0)
elif self.sampling:
assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
exp_probs = probs.exp_().view(-1, self.vocab_size)
if step == 0:
# we exclude the first two vocab items, one of which is pad
torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
if self.sampling_topk > 0:
values, indices = probs[:, 2:].topk(self.sampling_topk)
exp_probs = values.div_(self.sampling_temperature).exp()
if step == 0:
torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
else:
torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
cand_indices.add_(2)
else:
torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)
if step == 0:
# we exclude the first two vocab items, one of which is pad
torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
else:
torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
cand_indices.add_(2)
torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
cand_scores.log_()
cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
......@@ -301,7 +337,7 @@ class SequenceGenerator(object):
# make scores cumulative
cand_scores.add_(
torch.gather(
scores[:, step-1].view(bsz, beam_size), dim=1,
scores[:, step - 1].view(bsz, beam_size), dim=1,
index=cand_beams,
)
)
......@@ -323,18 +359,20 @@ class SequenceGenerator(object):
descending=True,
out=(eos_scores, eos_bbsz_idx),
)
num_remaining_sent -= finalize_hypos(
step, eos_bbsz_idx, eos_scores)
num_remaining_sent -= len(finalize_hypos(
step, eos_bbsz_idx, eos_scores))
assert num_remaining_sent == 0
break
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add_(bbsz_offsets)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
eos_mask = cand_indices.eq(self.eos)
finalized_sents = set()
if step >= self.minlen:
# only consider eos when it's among the top beam_size indices
torch.masked_select(
......@@ -348,20 +386,49 @@ class SequenceGenerator(object):
mask=eos_mask[:, :beam_size],
out=eos_scores,
)
num_remaining_sent -= finalize_hypos(
finalized_sents = finalize_hypos(
step, eos_bbsz_idx, eos_scores, cand_scores)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
assert step < maxlen
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = torch.ones(bsz).type_as(cand_indices)
batch_mask[cand_indices.new(finalized_sents)] = 0
batch_idxs = batch_mask.nonzero().squeeze(-1)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
cand_scores = cand_scores[batch_idxs]
cand_indices = cand_indices[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
scores_buf.resize_as_(scores)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens_buf.resize_as_(tokens)
attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
attn_buf.resize_as_(attn)
bsz = new_bsz
else:
batch_idxs = None
# set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
active_mask = buffer('active_mask')
torch.add(
eos_mask.type_as(cand_offsets)*cand_size,
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[:eos_mask.size(1)],
out=active_mask,
)
......@@ -382,17 +449,18 @@ class SequenceGenerator(object):
cand_scores, dim=1, index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
torch.index_select(
tokens[:, :step+1], dim=0, index=active_bbsz_idx,
out=tokens_buf[:, :step+1],
tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
out=tokens_buf[:, :step + 1],
)
torch.gather(
cand_indices, dim=1, index=active_hypos,
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1],
out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
)
if step > 0:
torch.index_select(
......@@ -406,51 +474,37 @@ class SequenceGenerator(object):
# copy attention for active hypotheses
torch.index_select(
attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
out=attn_buf[:, :, :step+2],
attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
out=attn_buf[:, :, :step + 2],
)
# swap buffers
old_tokens = tokens
tokens = tokens_buf
tokens_buf = old_tokens
old_scores = scores
scores = scores_buf
scores_buf = old_scores
old_attn = attn
attn = attn_buf
attn_buf = old_attn
tokens, tokens_buf = tokens_buf, tokens
scores, scores_buf = scores_buf, scores
attn, attn_buf = attn_buf, attn
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(bsz):
for sent in range(len(finalized)):
finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
return finalized
def _decode(self, tokens, encoder_outs, incremental_states):
# wrap in Variable
tokens = utils.volatile_variable(tokens)
if len(self.models) == 1:
return self._decode_one(tokens, self.models[0], encoder_outs[0], incremental_states, log_probs=True)
avg_probs = None
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
with utils.maybe_no_grad():
if incremental_states[model] is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :]
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=False).data
probs, attn = self._decode_one(tokens, model, encoder_out, incremental_states, log_probs=False)
if avg_probs is None:
avg_probs = probs
else:
avg_probs.add_(probs)
if attn is not None:
attn = attn[:, -1, :].data
if avg_attn is None:
avg_attn = attn
else:
......@@ -459,5 +513,17 @@ class SequenceGenerator(object):
avg_probs.log_()
if avg_attn is not None:
avg_attn.div_(len(self.models))
return avg_probs, avg_attn
def _decode_one(self, tokens, model, encoder_out, incremental_states, log_probs):
with torch.no_grad():
if incremental_states[model] is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :]
attn = decoder_out[1]
if attn is not None:
attn = attn[:, -1, :]
probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
return probs, attn
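# Note on the updated constructor (argument values below are illustrative):
# SequenceGenerator now takes the target dictionary explicitly instead of
# reading model.dst_dict, and exposes the new sampling options, e.g.
#   SequenceGenerator(models, task.target_dictionary, beam_size=5,
#                     sampling=True, sampling_topk=10, sampling_temperature=0.8)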
......@@ -5,16 +5,17 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from fairseq import utils
class SequenceScorer(object):
"""Scores the target for a given source sentence."""
def __init__(self, models):
def __init__(self, models, tgt_dict):
self.models = models
self.pad = models[0].dst_dict.pad()
assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
self.pad = tgt_dict.pad()
def cuda(self):
for model in self.models:
......@@ -24,21 +25,22 @@ class SequenceScorer(object):
def score_batched_itr(self, data_itr, cuda=False, timer=None):
"""Iterate over a batched dataset and yield scored translations."""
for sample in data_itr:
s = utils.make_variable(sample, volatile=True, cuda=cuda)
s = utils.move_to_cuda(sample) if cuda else sample
if timer is not None:
timer.start()
pos_scores, attn = self.score(s)
if timer is not None:
timer.stop(s['ntokens'])
for i, id in enumerate(s['id'].data):
src = s['net_input']['src_tokens'].data[i, :]
# remove padding from ref
ref = utils.strip_pad(s['target'].data[i, :], self.pad)
src = utils.strip_pad(s['net_input']['src_tokens'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
tgt_len = ref.numel()
pos_scores_i = pos_scores[i][:tgt_len]
score_i = pos_scores_i.sum() / tgt_len
attn_i = attn[i]
_, alignment = attn_i.max(dim=0)
if attn is not None:
attn_i = attn[i]
_, alignment = attn_i.max(dim=0)
else:
attn_i = alignment = None
hypos = [{
'tokens': ref,
'score': score_i,
......@@ -46,6 +48,8 @@ class SequenceScorer(object):
'alignment': alignment,
'positional_scores': pos_scores_i,
}]
if timer is not None:
timer.stop(s['ntokens'])
# return results in the same format as SequenceGenerator
yield id, src, ref, hypos
......@@ -57,18 +61,12 @@ class SequenceScorer(object):
avg_probs = None
avg_attn = None
for model in self.models:
with utils.maybe_no_grad():
with torch.no_grad():
model.eval()
encoder_out = model.encoder(
net_input['src_tokens'],
net_input['src_lengths'],
)
decoder_out = model.decoder(
net_input['prev_output_tokens'],
encoder_out,
)
decoder_out = model.forward(**net_input)
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=False).data
probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data
if avg_probs is None:
avg_probs = probs
else:
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_task import FairseqTask
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
def setup_task(args):
return TASK_REGISTRY[args.task].setup_task(args)
def register_task(name):
"""Decorator to register a new task."""
def register_task_cls(cls):
if name in TASK_REGISTRY:
raise ValueError('Cannot register duplicate task ({})'.format(name))
if not issubclass(cls, FairseqTask):
raise ValueError('Task ({}: {}) must extend FairseqTask'.format(name, cls.__name__))
if cls.__name__ in TASK_CLASS_NAMES:
raise ValueError('Cannot register task with duplicate class name ({})'.format(cls.__name__))
TASK_REGISTRY[name] = cls
TASK_CLASS_NAMES.add(cls.__name__)
return cls
return register_task_cls
# automatically import any Python files in the tasks/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.tasks.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import criterions, models
from fairseq.data import FairseqDataset
class FairseqTask(object):
"""
A Task defines the data format, stores shared state (e.g., dictionaries) and
provides helpers for building the model/criterion and calculating the loss.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
pass
def __init__(self, args):
self.args = args
self.datasets = {}
@classmethod
def setup_task(cls, args, **kwargs):
raise NotImplementedError
def load_dataset(self, split):
raise NotImplementedError
def dataset(self, split):
"""Return a dataset split."""
if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split)
if not isinstance(self.datasets[split], FairseqDataset):
raise TypeError('Datasets are expected to be of type FairseqDataset')
return self.datasets[split]
def build_model(self, args):
return models.build_model(args, self)
def build_criterion(self, args):
return criterions.build_criterion(args, self)
def get_loss(self, model, criterion, sample):
return criterion(model, sample)
@property
def source_dictionary(self):
raise NotImplementedError
@property
def target_dictionary(self):
raise NotImplementedError
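As an illustration of the new task API (the task name and class below are hypothetical, not part of fairseq), a custom task combines register_task with a FairseqTask subclass:

from fairseq.tasks import FairseqTask, register_task


@register_task('toy_copy')  # hypothetical task name, for illustration only
class ToyCopyTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        # task-specific command-line options
        parser.add_argument('data', metavar='DIR', help='path to data directory')

    @classmethod
    def setup_task(cls, args, **kwargs):
        return cls(args)

    def load_dataset(self, split):
        # a real task would build a FairseqDataset and store it in
        # self.datasets[split]; omitted in this sketch
        raise NotImplementedError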
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq.data import (
Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset,
)
from . import FairseqTask, register_task
@register_task('language_modeling')
class LanguageModelingTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('--sample-break-mode', metavar='VAL',
choices=['none', 'complete', 'eos'],
help='If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
'of sentence, but may include multiple sentences per sample. '
'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=1024, type=int, metavar='N',
help='max number of tokens per sample for LM dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
@classmethod
def setup_task(cls, args, **kwargs):
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split):
"""Load a dataset split."""
path = os.path.join(self.args.data, split)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = ds.tokens_list
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path)
tokens = ds.buffer
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
dataset = TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True, # return next tokens as targets
)
self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
@property
def target_dictionary(self):
return self.dictionary
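# Presumably how a task gets used by the driver scripts (sketch only, not an
# exact transcript):
#   task = tasks.setup_task(args)      # resolves --task, e.g. language_modeling
#   task.load_dataset('valid')
#   dataset = task.dataset('valid')
#   vocab = task.target_dictionary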
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from fairseq import options
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
IndexedRawTextDataset,
)
from . import FairseqTask, register_task
@register_task('translation')
class TranslationTask(FairseqTask):
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left (default: True)')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left (default: False)')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
def __init__(self, args, src_dict, tgt_dict):
super().__init__(args)
self.src_dict = src_dict
self.tgt_dict = tgt_dict
@classmethod
def setup_task(cls, args, **kwargs):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
# find language pair automatically
if args.source_lang is None or args.target_lang is None:
args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
if args.source_lang is None or args.target_lang is None:
raise Exception('Could not infer language pair, please provide it explicitly')
# load dictionaries
src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split):
"""Load a dataset split."""
def split_exists(src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
return True
return False
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path):
return IndexedInMemoryDataset(path)
return None
src_dataset = indexed_dataset(prefix + src, self.src_dict)
tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict)
self.datasets[split] = LanguagePairDataset(
src_dataset, src_dataset.sizes, self.src_dict,
tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
)
@property
def source_dictionary(self):
return self.src_dict
@property
def target_dictionary(self):
return self.tgt_dict
......@@ -10,8 +10,6 @@ import re
import torch
from fairseq import dictionary
SPACE_NORMALIZER = re.compile("\s+")
......@@ -24,13 +22,6 @@ def tokenize_line(line):
class Tokenizer:
@staticmethod
def build_dictionary(filename, tokenize=tokenize_line):
dict = dictionary.Dictionary()
Tokenizer.add_file_to_dictionary(filename, dict, tokenize)
dict.finalize()
return dict
@staticmethod
def add_file_to_dictionary(filename, dict, tokenize):
with open(filename, 'r') as f:
......
......@@ -6,11 +6,13 @@
# can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
Train a network across multiple GPUs.
"""
from collections import OrderedDict
import math
from collections import defaultdict, OrderedDict
import contextlib
from itertools import chain
import torch
from fairseq import distributed_utils, optim, utils
......@@ -19,14 +21,14 @@ from fairseq.optim import lr_scheduler
class Trainer(object):
"""Main class for multi-GPU training.
"""Main class for data parallel training.
Each GPU has a full copy of the model and is assigned to its own Python
process. Gradients are accumulated with torch.distributed.all_reduce and all
model replicas are updated synchronously after each batch.
This class supports data parallel training, where multiple workers each
have a full model replica and gradients are accumulated synchronously via
torch.distributed.all_reduce.
"""
def __init__(self, args, model, criterion):
def __init__(self, args, task, model, criterion):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
......@@ -34,12 +36,11 @@ class Trainer(object):
self.args = args
# copy model and criterion to current device
self.task = task
self.model = model.cuda()
self.criterion = criterion.cuda()
# initialize optimizer and LR scheduler
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
self.optimizer = None
# initialize meters
self.meters = OrderedDict()
......@@ -54,25 +55,34 @@ class Trainer(object):
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
self.meters['wall'] = TimeMeter() # wall time in seconds
self._max_bsz_seen = 0
self._buffered_stats = defaultdict(lambda: [])
self._flat_grads = None
self._num_updates = 0
self._optim_history = None
def _build_optimizer(self):
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint
utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state)
extra_state['train_meters'] = self.meters
utils.save_state(
filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
)
def load_checkpoint(self, filename):
"""Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = utils.load_model_state(
filename, self.model, cuda_device=torch.cuda.current_device())
extra_state, self._optim_history, last_optim_state = \
utils.load_model_state(filename, self.model)
if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
self._build_optimizer()
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
......@@ -83,42 +93,89 @@ class Trainer(object):
self._num_updates = last_optim['num_updates']
if 'train_meters' in extra_state:
self.meters = extra_state['train_meters']
del extra_state['train_meters']
return extra_state
def train_step(self, sample):
def train_step(self, sample, update_params=True):
"""Do forward, backward and parameter update."""
sample = self._prepare_sample(sample, volatile=False)
# forward pass
loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample)
if self.optimizer is None:
# initialize the optimizer and LR scheduler if they haven't been loaded from the checkpoint
self._build_optimizer()
# Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# forward and backward pass
sample = self._prepare_sample(sample)
loss, sample_size, logging_output, oom_fwd = self._forward(sample)
oom_bwd = self._backward(loss)
# buffer stats and logging outputs
self._buffered_stats['sample_sizes'].append(sample_size)
self._buffered_stats['logging_outputs'].append(logging_output)
self._buffered_stats['ooms_fwd'].append(oom_fwd)
self._buffered_stats['ooms_bwd'].append(oom_bwd)
# update parameters
if update_params:
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
lambda l: list(chain.from_iterable(l)),
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
))
)
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
# backward pass, all-reduce gradients and take an optimization step
grad_norm, ooms_bwd = self._backward_and_opt(loss, grad_denom)
# update meters
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['oom'].update(ooms_fwd + ooms_bwd)
# update loss meters for training
if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
return agg_logging_output
try:
# all-reduce and rescale gradients, then take an optimization step
grad_norm = self._all_reduce_and_rescale(grad_denom)
self._opt()
# update meters
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
if grad_norm is not None:
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['oom'].update(ooms_fwd + ooms_bwd)
# update loss meters for training
if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
except OverflowError as e:
self.zero_grad()
print('| WARNING: overflow detected, ' + str(e))
self.clear_buffered_stats()
return agg_logging_output
else:
return None # buffering updates
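The `update_params` flag above is what makes delayed updates (`--update-freq`) work: gradients and logging outputs are buffered across sub-batches, and the optimizer only steps when `update_params=True`. A minimal sketch of the calling pattern, assuming an already constructed `trainer` and an iterable `samples` of raw batches (both names are hypothetical, not part of this diff):

    # accumulate over `update_freq` sub-batches before each optimizer step
    update_freq = 4
    for i, sample in enumerate(samples):
        take_step = (i + 1) % update_freq == 0
        log_output = trainer.train_step(sample, update_params=take_step)
        if log_output is not None:
            # an optimizer step was taken; log_output holds the aggregated stats
            print('loss:', log_output.get('loss'))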
def _forward(self, sample, eval=False):
# prepare model and optimizer
......@@ -126,8 +183,6 @@ class Trainer(object):
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
loss = None
sample_size = 0
logging_output = {
......@@ -137,33 +192,20 @@ class Trainer(object):
oom = 0
if sample is not None:
try:
with utils.maybe_no_grad(eval):
with torch.no_grad() if eval else contextlib.ExitStack():
# calculate loss and sample size
loss, sample_size, logging_output_ = self.criterion(self.model, sample)
loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
logging_output.update(logging_output_)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
loss = None
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
else:
raise e
return loss, sample_size, logging_output, oom
# synchronize logging outputs for multi-GPU training
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms = zip(*list(
distributed_utils.all_gather_list((sample_size, logging_output, oom))))
ooms = sum(ooms)
else:
sample_sizes = [sample_size]
logging_outputs = [logging_output]
ooms = oom
return loss, sample_sizes, logging_outputs, ooms
def _backward_and_opt(self, loss, grad_denom):
def _backward(self, loss):
oom = 0
if loss is not None:
try:
......@@ -173,44 +215,81 @@ class Trainer(object):
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
self.optimizer.zero_grad()
self.zero_grad()
else:
raise e
return oom
# all-reduce grads and rescale by grad_denom
def _all_reduce_and_rescale(self, grad_denom):
# flatten grads into a single buffer and all-reduce
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
if self.args.distributed_world_size > 1:
grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
distributed_utils.all_reduce_and_rescale_tensors(grads, grad_denom)
else:
for p in self.model.parameters():
if p.requires_grad:
p.grad.data.div_(grad_denom)
# clip grads
if self.args.clip_norm > 0:
grad_norm = utils.item(torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm))
else:
grad_norm = math.sqrt(sum(p.grad.data.norm()**2 for p in self.model.parameters()))
torch.distributed.all_reduce(flat_grads)
# rescale and clip gradients
flat_grads.div_(grad_denom)
grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)
# copy grads back into model parameters
self._set_flat_grads(flat_grads)
return grad_norm
def _get_grads(self):
grads = []
for name, p in self.model.named_parameters():
if not p.requires_grad:
continue
if p.grad is None:
raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
'Use the param in the forward pass or set requires_grad=False')
grads.append(p.grad.data)
return grads
def _get_flat_grads(self, out=None):
grads = self._get_grads()
if out is None:
grads_size = sum(g.numel() for g in grads)
out = grads[0].new(grads_size).zero_()
offset = 0
for g in grads:
numel = g.numel()
out[offset:offset+numel].copy_(g.view(-1))
offset += numel
return out[:offset]
def _set_flat_grads(self, new_grads):
grads = self._get_grads()
offset = 0
for g in grads:
numel = g.numel()
g.copy_(new_grads[offset:offset+numel].view_as(g))
offset += numel
def _opt(self):
# take an optimization step
self.optimizer.step()
self.zero_grad()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
return grad_norm, oom
def valid_step(self, sample):
"""Do forward pass in evaluation mode."""
sample = self._prepare_sample(sample, volatile=True)
# forward pass
loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample, eval=True)
assert not ooms_fwd, 'Ran out of memory during validation'
sample = self._prepare_sample(sample)
_loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
assert not oom_fwd, 'Ran out of memory during validation'
# gather logging outputs from all GPUs
if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
(sample_size, logging_output)
))
else:
sample_sizes = [sample_size]
logging_outputs = [logging_output]
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
......@@ -226,10 +305,26 @@ class Trainer(object):
return agg_logging_output
def dummy_train_step(self, dummy_batch):
"""Dummy training step for warming caching allocator."""
self.train_step(dummy_batch, update_params=False)
self.zero_grad()
self.clear_buffered_stats()
def zero_grad(self):
self.optimizer.zero_grad()
def clear_buffered_stats(self):
self._buffered_stats.clear()
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss)
def lr_step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.lr_scheduler.step_update(num_updates)
def get_lr(self):
"""Get the current learning rate."""
return self.optimizer.get_lr()
......@@ -248,12 +343,7 @@ class Trainer(object):
"""Get the number of parameters updates."""
return self._num_updates
def _prepare_sample(self, sample, volatile):
def _prepare_sample(self, sample):
if sample is None or len(sample) == 0:
return None
if hasattr(torch.cuda, 'empty_cache'):
# clear the caching allocator if this is the largest sample we've seen
if sample['target'].size(0) > self._max_bsz_seen:
self._max_bsz_seen = sample['target'].size(0)
torch.cuda.empty_cache()
return utils.make_variable(sample, volatile=volatile, cuda=True)
return utils.move_to_cuda(sample)
......@@ -5,14 +5,13 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import defaultdict
import contextlib
from collections import defaultdict, OrderedDict
import logging
import os
import re
import torch
import traceback
from torch.autograd import Variable
from torch.serialization import default_restore_location
......@@ -25,6 +24,20 @@ def torch_persistent_save(*args, **kwargs):
logging.error(traceback.format_exc())
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
if isinstance(state_dict, dict):
cpu_dict = OrderedDict()
for k, v in state_dict.items():
cpu_dict[k] = convert_state_dict_type(v)
return cpu_dict
elif isinstance(state_dict, list):
return [convert_state_dict_type(v) for v in state_dict]
elif torch.is_tensor(state_dict):
return state_dict.type(ttype)
else:
return state_dict
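convert_state_dict_type walks nested dicts and lists and casts every tensor it finds to the given type, which is what lets FP16-trained weights be stored as FP32 checkpoints. A small self-contained illustration, assuming this version of fairseq is installed (the tensors are made up):

    import torch
    from fairseq import utils

    sd = {'weight': torch.randn(2, 2).half(), 'history': [torch.zeros(3).half()]}
    converted = utils.convert_state_dict_type(sd, ttype=torch.FloatTensor)
    # every tensor in the nested structure is now a FloatTensor
    assert converted['weight'].dtype == torch.float32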
def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None):
if optim_history is None:
......@@ -33,7 +46,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
extra_state = {}
state_dict = {
'args': args,
'model': model.state_dict(),
'model': convert_state_dict_type(model.state_dict()),
'optimizer_history': optim_history + [
{
'criterion_name': criterion.__class__.__name__,
......@@ -42,28 +55,22 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
'num_updates': num_updates,
}
],
'last_optimizer_state': optimizer.state_dict(),
'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
'extra_state': extra_state,
}
torch_persistent_save(state_dict, filename)
def load_model_state(filename, model, cuda_device=None):
def load_model_state(filename, model):
if not os.path.exists(filename):
return None, [], None
if cuda_device is None:
state = torch.load(filename)
else:
state = torch.load(
filename,
map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
)
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
state['model'] = model.upgrade_state_dict(state['model'])
model.upgrade_state_dict(state['model'])
# load model parameters
try:
model.load_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
except Exception:
raise Exception('Cannot load model parameters from checkpoint, '
'please ensure that the architectures match')
......@@ -111,43 +118,44 @@ def _upgrade_state_dict(state):
# keep track of number of updates
if 'num_updates' not in state['optimizer_history'][-1]:
state['optimizer_history'][-1]['num_updates'] = 0
# old model checkpoints may not have separate source/target positions
if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
state['args'].max_source_positions = state['args'].max_positions
state['args'].max_target_positions = state['args'].max_positions
# use stateful training data iterator
if 'train_iterator' not in state['extra_state']:
state['extra_state']['train_iterator'] = {
'epoch': state['extra_state']['epoch'],
'iterations_in_epoch': 0,
}
return state
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
data_dir=None, model_arg_overrides=None):
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
"""Load an ensemble of models for inference.
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
model_arg_overrides allows you to pass a dictionary {'arg_name': arg} to
override model args that were used during model training
"""
from fairseq import data, models
# load model architectures and weights
states = []
for filename in filenames:
if not os.path.exists(filename):
raise IOError('Model file not found: {}'.format(filename))
states.append(
torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
)
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
states.append(state)
args = states[0]['args']
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
if src_dict is None or dst_dict is None:
assert data_dir is not None
src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
# build ensemble
ensemble = []
for state in states:
model = models.build_model(args, src_dict, dst_dict)
model.load_state_dict(state['model'])
model = task.build_model(args)
model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
ensemble.append(model)
return ensemble, args
......@@ -159,46 +167,24 @@ def _override_model_args(args, model_arg_overrides):
return args
def maybe_no_grad(condition=True):
if hasattr(torch, 'no_grad') and condition:
return torch.no_grad()
# no-op context manager
return contextlib.ExitStack()
def volatile_variable(*args, **kwargs):
if hasattr(torch, 'no_grad'):
# volatile has been deprecated, use the no_grad context manager instead
return Variable(*args, **kwargs)
else:
return Variable(*args, **kwargs, volatile=True)
def make_variable(sample, volatile=False, cuda=False):
"""Wrap input tensors in Variable class."""
def move_to_cuda(sample):
if len(sample) == 0:
return {}
def _make_variable(maybe_tensor):
def _move_to_cuda(maybe_tensor):
if torch.is_tensor(maybe_tensor):
if cuda and torch.cuda.is_available():
maybe_tensor = maybe_tensor.cuda()
if volatile:
return volatile_variable(maybe_tensor)
else:
return Variable(maybe_tensor)
return maybe_tensor.cuda()
elif isinstance(maybe_tensor, dict):
return {
key: _make_variable(value)
key: _move_to_cuda(value)
for key, value in maybe_tensor.items()
}
elif isinstance(maybe_tensor, list):
return [_make_variable(x) for x in maybe_tensor]
return [_move_to_cuda(x) for x in maybe_tensor]
else:
return maybe_tensor
return _make_variable(sample)
return _move_to_cuda(sample)
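move_to_cuda replaces the old make_variable helper now that Variables and tensors are merged; it simply walks a (possibly nested) sample and moves every tensor to the current GPU. A hedged usage sketch with a made-up sample:

    import torch
    from fairseq import utils

    sample = {
        'net_input': {'src_tokens': torch.LongTensor([[4, 5, 6, 2]])},
        'target': torch.LongTensor([[5, 6, 7, 2]]),
    }
    if torch.cuda.is_available():
        sample = utils.move_to_cuda(sample)  # all nested tensors now live on the GPU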
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
......@@ -268,7 +254,7 @@ def parse_embedding(embed_path):
"""
embed_dict = {}
with open(embed_path) as f_embed:
_ = next(f_embed) # skip header
next(f_embed) # skip header
for line in f_embed:
pieces = line.strip().split()
embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
......@@ -297,15 +283,15 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
return ' '.join(hypo_tokens)
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe):
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
from fairseq import tokenizer
hypo_str = dst_dict.string(hypo_tokens, remove_bpe)
hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
if align_dict is not None:
hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string())
hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method.
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True)
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
return hypo_tokens, hypo_str, alignment
......@@ -342,12 +328,7 @@ def buffered_arange(max):
return buffered_arange.buf[:max]
def convert_padding_direction(
src_tokens,
padding_idx,
right_to_left=False,
left_to_right=False,
):
def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx)
if not pad_mask.any():
......@@ -375,3 +356,35 @@ def item(tensor):
if hasattr(tensor, '__getitem__'):
return tensor[0]
return tensor
def clip_grad_norm_(tensor, max_norm):
grad_norm = item(torch.norm(tensor))
if grad_norm > max_norm > 0:
clip_coef = max_norm / (grad_norm + 1e-6)
tensor.mul_(clip_coef)
return grad_norm
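Unlike torch.nn.utils.clip_grad_norm, this helper operates on a single pre-flattened gradient buffer (the one built by _get_flat_grads) and rescales it in place, returning the pre-clipping norm. A quick sketch on a stand-in buffer:

    import torch
    from fairseq import utils

    flat_grads = torch.randn(10) * 100.            # stand-in for a flattened gradient buffer
    norm = utils.clip_grad_norm_(flat_grads, 1.0)  # returns the norm before clipping
    assert float(flat_grads.norm()) <= 1.0 + 1e-4  # buffer was rescaled in place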
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)
def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
"""Retrieves all checkpoints found in `path` directory.
Checkpoints are identified by matching filename to the specified pattern. If
the pattern contains groups, the result will be sorted by the first group in
descending order.
"""
pt_regexp = re.compile(pattern)
files = os.listdir(path)
entries = []
for i, f in enumerate(files):
m = pt_regexp.fullmatch(f)
if m is not None:
idx = int(m.group(1)) if len(m.groups()) > 0 else i
entries.append((idx, m.group(0)))
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
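checkpoint_paths is handy for pruning old checkpoints: with the default pattern it returns full paths sorted by checkpoint number, newest first. A hypothetical usage (the directory and retention count are made up):

    import os
    from fairseq import utils

    paths = utils.checkpoint_paths('/path/to/save_dir', pattern=r'checkpoint(\d+)\.pt')
    # e.g. ['/path/to/save_dir/checkpoint12.pt', '/path/to/save_dir/checkpoint11.pt', ...]
    for old in paths[5:]:   # keep the five most recent, remove the rest
        os.remove(old)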
......@@ -8,7 +8,7 @@
import torch
from fairseq import bleu, data, options, progress_bar, tokenizer, utils
from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
......@@ -16,76 +16,67 @@ from fairseq.sequence_scorer import SequenceScorer
def main(args):
assert args.path is not None, '--path required for generation!'
print(args)
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert args.replace_unk is None or args.raw_text, \
'--replace-unk requires a raw text dataset (--raw-text)'
if args.max_tokens is None and args.max_sentences is None:
args.max_tokens = 12000
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset
if args.replace_unk is None:
dataset = data.load_dataset(
args.data,
[args.gen_subset],
args.source_lang,
args.target_lang,
)
else:
dataset = data.load_raw_text_dataset(
args.data,
[args.gen_subset],
args.source_lang,
args.target_lang,
)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args
args.source_lang, args.target_lang = dataset.src, dataset.dst
# Load dataset splits
task = tasks.setup_task(args)
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
)
model.make_generation_fast_(beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
# Load dataset (possibly sharded)
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.eval_dataloader(
args.gen_subset,
itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
)
if args.num_shards > 1:
if args.shard_id < 0 or args.shard_id >= args.num_shards:
raise ValueError('--shard-id must be between 0 and num_shards')
itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
max_positions=models[0].max_positions(),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
num_shards=args.num_shards,
shard_id=args.shard_id,
).next_epoch_itr(shuffle=False)
# Initialize generator
gen_timer = StopwatchMeter()
if args.score_reference:
translator = SequenceScorer(models)
translator = SequenceScorer(models, task.target_dictionary)
else:
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen, sampling=args.sampling)
models, task.target_dictionary, beam_size=args.beam,
stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
len_penalty=args.lenpen, unk_penalty=args.unkpen,
sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
)
if use_cuda:
translator.cuda()
# Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
num_sentences = 0
has_target = True
with progress_bar.build_progress_bar(args, itr) as t:
......@@ -94,21 +85,23 @@ def main(args):
else:
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size)
cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
)
wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
has_target = target_tokens is not None
target_tokens = target_tokens.int().cpu() if has_target else None
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
target_str = dataset.dst_dict.string(target_tokens,
args.remove_bpe,
escape_unk=True) if has_target else ''
src_str = src_dict.string(src_tokens, args.remove_bpe)
if has_target:
target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str))
......@@ -122,7 +115,7 @@ def main(args):
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict,
dst_dict=dataset.dst_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
......@@ -145,20 +138,20 @@ def main(args):
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
target_str, dataset.dst_dict, add_if_not_exist=True)
target_str, tgt_dict, add_if_not_exist=True)
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)})
num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
if has_target:
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
if __name__ == '__main__':
parser = options.get_generation_parser()
args = parser.parse_args()
args = options.parse_args_and_arch(parser)
main(args)
......@@ -6,30 +6,79 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import namedtuple
import numpy as np
import sys
import torch
from torch.autograd import Variable
from fairseq import options, tokenizer, utils
from fairseq import data, options, tasks, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator
Batch = namedtuple('Batch', 'srcs tokens lengths')
Translation = namedtuple('Translation', 'src_str hypos alignments')
def buffered_read(buffer_size):
buffer = []
for src_str in sys.stdin:
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
buffer = []
if len(buffer) > 0:
yield buffer
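buffered_read groups stdin lines into chunks of at most buffer_size so that interactive generation can translate several sentences per batch. A tiny sketch of its behaviour with simulated input (the StringIO stand-in is for illustration only):

    import io
    import sys

    sys.stdin = io.StringIO('hello world\nhow are you\nbye\n')
    for chunk in buffered_read(buffer_size=2):
        print(chunk)   # ['hello world', 'how are you'], then ['bye']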
def make_batches(lines, args, src_dict, max_positions):
tokens = [
tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
for src_str in lines
]
lengths = np.array([t.numel() for t in tokens])
itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
srcs=[lines[i] for i in batch['id']],
tokens=batch['net_input']['src_tokens'],
lengths=batch['net_input']['src_lengths'],
), batch['id']
def main(args):
print(args)
if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.max_sentences is None:
args.max_sentences = 1
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert not args.max_sentences, \
'--max-sentences/--batch-size is not supported in interactive mode'
assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
'--max-sentences/--batch-size cannot be larger than --buffer-size'
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Setup task, e.g., translation
task = tasks.setup_task(args)
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(':')
models, model_args = utils.load_ensemble_for_inference(model_paths, task)
print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation
for model in models:
......@@ -39,9 +88,12 @@ def main(args):
# Initialize generator
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
models, tgt_dict, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen)
unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
minlen=args.min_len,
)
if use_cuda:
translator.cuda()
......@@ -49,19 +101,12 @@ def main(args):
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
print('| Type the input sentence and press return:')
for src_str in sys.stdin:
src_str = src_str.strip()
src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
if use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_tokens.new([src_tokens.numel()])
translations = translator.generate(
Variable(src_tokens.view(1, -1)),
Variable(src_lengths.view(-1)),
def make_result(src_str, hypos):
result = Translation(
src_str='O\t{}'.format(src_str),
hypos=[],
alignments=[],
)
hypos = translations[0]
print('O\t{}'.format(src_str))
# Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]:
......@@ -70,14 +115,48 @@ def main(args):
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict,
dst_dict=dst_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
)
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
result.alignments.append('A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment))))
return result
def process_batch(batch):
tokens = batch.tokens
lengths = batch.lengths
if use_cuda:
tokens = tokens.cuda()
lengths = lengths.cuda()
translations = translator.generate(
Variable(tokens),
Variable(lengths),
maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
)
return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
if args.buffer_size > 1:
print('| Sentence buffer size:', args.buffer_size)
print('| Type the input sentence and press return:')
for inputs in buffered_read(args.buffer_size):
indices = []
results = []
for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
indices.extend(batch_indices)
results += process_batch(batch)
for i in np.argsort(indices):
result = results[i]
print(result.src_str)
for hypo, align in zip(result.hypos, result.alignments):
print(hypo)
print(align)
if __name__ == '__main__':
parser = options.get_generation_parser()
args = parser.parse_args()
parser = options.get_generation_parser(interactive=True)
args = options.parse_args_and_arch(parser)
main(args)
......@@ -13,7 +13,7 @@ import torch
from fairseq import distributed_utils, options
from singleprocess_train import main as single_process_main
from train import main as single_process_main
def main(args):
......