Unverified commit d3795d6c, authored by Myle Ott and committed by GitHub

Merge internal changes (#136)

Changes:
- 7d19e36: Add `--sampling` flag to generate.py to sample instead of doing beam search
- c777340: Add `scripts/average_checkpoints.py` to average multiple checkpoints into a combined model
- 3ea882c: Add `--max-update` option to train.py to stop training after a given number of updates
- Small bug fixes for distributed training, the LSTM model, and the inverse square root LR scheduler
parent 48836525
......@@ -8,9 +8,11 @@
import math
import torch.nn.functional as F
from . import FairseqCriterion, register_criterion
from fairseq import utils
from . import FairseqCriterion, register_criterion
@register_criterion('cross_entropy')
class CrossEntropyCriterion(FairseqCriterion):
......@@ -28,7 +30,7 @@ class CrossEntropyCriterion(FairseqCriterion):
net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = sample['target'].view(-1)
target = model.get_targets(sample, net_output).view(-1)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
......@@ -47,6 +49,7 @@ class CrossEntropyCriterion(FairseqCriterion):
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'sample_size': sample_size,
}
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
......
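Note on the reporting above: `F.nll_loss` over log-softmax output is summed in nats, so dividing by `sample_size` and by `math.log(2)` converts the aggregate to bits per token (or per sentence with `--sentence-avg`). A minimal illustration of the unit conversion, with made-up numbers:

```python
import math

loss_sum = 693.15      # summed natural-log NLL over the batch (hypothetical)
sample_size = 100      # ntokens, or number of sentences with --sentence-avg

loss_bits = loss_sum / sample_size / math.log(2)
print(loss_bits)        # ~10.0 bits per token
print(2 ** loss_bits)   # ~1024, the corresponding perplexity
```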
......@@ -6,8 +6,6 @@
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn.functional as F
from fairseq import utils
......@@ -37,7 +35,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
"""
net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample['target'].unsqueeze(-1)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = model.get_targets(sample, net_output).view(-1, 1)
non_pad_mask = target.ne(self.padding_idx)
nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
......@@ -64,4 +63,5 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
'sample_size': sample_size,
}
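The hunk above computes the per-token NLL and the uniform smoothing term, but the combining step falls outside the diff context. A sketch of the standard label-smoothing mix these two terms feed into (the `eps` weighting is an assumption, not shown in the hunk):

```python
import torch

def label_smoothed_loss(lprobs, target, eps, padding_idx):
    """lprobs: (N, vocab) log-probabilities; target: (N,) gold indices."""
    target = target.view(-1, 1)
    non_pad_mask = target.ne(padding_idx)
    nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]    # as in the hunk
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]    # as in the hunk
    eps_i = eps / lprobs.size(-1)
    # mix the true-label loss with a uniform distribution over the vocabulary
    return (1.0 - eps) * nll_loss.sum() + eps_i * smooth_loss.sum()
```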
......@@ -198,7 +198,7 @@ class LanguagePairDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
res = { 'id': i, 'source': source }
res = {'id': i, 'source': source}
if self.dst:
res['target'] = self.dst[i].long() - 1
......@@ -283,9 +283,9 @@ def _valid_size(src_size, dst_size, max_positions):
max_src_positions, max_dst_positions = max_positions, max_positions
else:
max_src_positions, max_dst_positions = max_positions
if src_size < 2 or src_size > max_src_positions:
if src_size < 1 or src_size > max_src_positions:
return False
if dst_size is not None and (dst_size < 2 or dst_size > max_dst_positions):
if dst_size is not None and (dst_size < 1 or dst_size > max_dst_positions):
return False
return True
......
......@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import math
import os
import torch
......@@ -23,6 +24,9 @@ class Dictionary(object):
self.unk_index = self.add_symbol(unk)
self.nspecial = len(self.symbols)
def __eq__(self, other):
return self.indices == other.indices
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
......@@ -97,8 +101,8 @@ class Dictionary(object):
"""Helper to get index of unk symbol"""
return self.unk_index
@staticmethod
def load(f):
@classmethod
def load(cls, f):
"""Loads the dictionary from a text file with the format:
```
......@@ -111,14 +115,14 @@ class Dictionary(object):
if isinstance(f, str):
try:
with open(f, 'r', encoding='utf-8') as fd:
return Dictionary.load(fd)
return cls.load(fd)
except FileNotFoundError as fnfe:
raise fnfe
except Exception:
raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f))
d = Dictionary()
d = cls()
for line in f.readlines():
idx = line.rfind(' ')
word = line[:idx]
......@@ -131,6 +135,7 @@ class Dictionary(object):
def save(self, f, threshold=3, nwords=-1):
"""Stores dictionary into a text file"""
if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd, threshold, nwords)
cnt = 0
......
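Switching `load` from `@staticmethod` to `@classmethod` matters for subclasses: `cls.load(fd)` and `d = cls()` now return an instance of the calling subclass instead of always a base `Dictionary`. A standalone sketch (the classes here are stand-ins, not fairseq's):

```python
class BaseDict:
    def __init__(self):
        self.symbols = []

    @classmethod
    def load(cls, lines):
        d = cls()  # with @staticmethod this would have been hard-coded to BaseDict()
        for line in lines:
            word, _count = line.rsplit(' ', 1)
            d.symbols.append(word)
        return d


class SubwordDict(BaseDict):
    pass


d = SubwordDict.load(['hello 42', 'world 7'])
assert isinstance(d, SubwordDict)       # would fail with the old @staticmethod version
assert d.symbols == ['hello', 'world']
```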
......@@ -10,6 +10,12 @@ import pickle
import torch.distributed
from fairseq import utils
def is_master(args):
return args.distributed_rank == 0
def distributed_init(args):
if args.distributed_world_size == 1:
......@@ -27,7 +33,7 @@ def distributed_init(args):
world_size=args.distributed_world_size)
args.distributed_rank = torch.distributed.get_rank()
if args.distributed_rank != 0:
if not is_master(args):
suppress_output()
return args.distributed_rank
......@@ -104,7 +110,7 @@ def all_gather_list(data, max_size=4096):
world_size = torch.distributed.get_world_size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != all_gather_list._in_buffer.size():
all_gather_list._in_buffer = torch.ByteTensor(max_size)
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size)
for i in range(world_size)
......@@ -113,18 +119,21 @@ def all_gather_list(data, max_size=4096):
out_buffers = all_gather_list._out_buffers
enc = pickle.dumps(data)
if len(enc) >= max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(len(enc)))
in_buffer[0] = len(enc)
in_buffer[1:len(enc)+1] = torch.ByteTensor(list(enc))
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
torch.distributed.all_gather(out_buffers, in_buffer.cuda())
result = []
for i in range(world_size):
out_buffer = out_buffers[i]
size = out_buffer[0]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
result.append(
pickle.loads(bytes(out_buffer[1:size+1].tolist()))
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
return result
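The old header stored the payload length in a single byte, which silently breaks for pickled payloads over 255 bytes; the new two-byte header (base 255) handles sizes up to `255*256 - 1`. A pure-Python sketch of the encode/decode round trip, without the distributed or CUDA pieces:

```python
import pickle

MAX_SIZE = 4096
assert MAX_SIZE < 255 * 256  # the two-byte header can address at most ~65k

def encode(obj):
    payload = pickle.dumps(obj)
    size = len(payload)
    if size + 2 > MAX_SIZE:
        raise ValueError('encoded data exceeds max_size: {}'.format(size + 2))
    buf = bytearray(MAX_SIZE)
    buf[0] = size // 255          # high digit, base 255
    buf[1] = size % 255           # low digit
    buf[2:size + 2] = payload
    return buf

def decode(buf):
    size = 255 * buf[0] + buf[1]
    return pickle.loads(bytes(buf[2:size + 2]))

stats = {'rank': 3, 'ntokens': 1024}
assert decode(encode(stats)) == stats
```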
......@@ -107,10 +107,12 @@ class IndexedRawTextDataset(IndexedDataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
def __init__(self, path, dictionary):
def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
self.tokens_list = []
self.lines = []
self.sizes = []
self.append_eos = append_eos
self.reverse_order = reverse_order
self.read_data(path, dictionary)
self.size = len(self.tokens_list)
......@@ -118,8 +120,10 @@ class IndexedRawTextDataset(IndexedDataset):
with open(path, 'r') as f:
for line in f:
self.lines.append(line.strip('\n'))
# +1 for Lua compatibility
tokens = Tokenizer.tokenize(line, dictionary, add_if_not_exist=False) + 1
tokens = Tokenizer.tokenize(
line, dictionary, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order,
) + 1 # +1 for Lua compatibility
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
......
......@@ -21,12 +21,11 @@ class FairseqDecoder(nn.Module):
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
vocab = net_output.size(-1)
net_output1 = net_output.view(-1, vocab)
logits = net_output[0]
if log_probs:
return F.log_softmax(net_output1, dim=1).view_as(net_output)
return F.log_softmax(logits, dim=-1)
else:
return F.softmax(net_output1, dim=1).view_as(net_output)
return F.softmax(logits, dim=-1)
def max_positions(self):
"""Maximum input length supported by the decoder."""
......
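With this change the decoder's raw output tuple `(logits, attn)` is passed around as-is and normalization happens over the last (vocabulary) dimension, so callers no longer need to reshape or slice before normalizing. A small sketch of the convention:

```python
import torch
import torch.nn.functional as F

def get_normalized_probs(net_output, log_probs):
    logits = net_output[0]  # decoder now returns (logits, attn)
    return F.log_softmax(logits, dim=-1) if log_probs else F.softmax(logits, dim=-1)

logits = torch.randn(2, 5, 100)                        # (bsz, tgt_len, vocab)
probs = get_normalized_probs((logits, None), log_probs=False)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 5))
```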
......@@ -41,13 +41,17 @@ class FairseqModel(nn.Module):
def forward(self, src_tokens, src_lengths, prev_output_tokens):
encoder_out = self.encoder(src_tokens, src_lengths)
decoder_out, _ = self.decoder(prev_output_tokens, encoder_out)
decoder_out = self.decoder(prev_output_tokens, encoder_out)
return decoder_out
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
return self.decoder.get_normalized_probs(net_output, log_probs)
def get_targets(self, sample, net_output):
"""Get targets from either the sample or the net's output."""
return sample['target']
def max_encoder_positions(self):
"""Maximum input length supported by the encoder."""
return self.encoder.max_positions()
......
......@@ -48,6 +48,11 @@ class FConvModel(FairseqModel):
@classmethod
def build_model(cls, args, src_dict, dst_dict):
"""Build a new model instance."""
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.max_positions
args.max_target_positions = args.max_positions
if not hasattr(args, 'share_input_output_embed'):
args.share_input_output_embed = False
encoder = FConvEncoder(
src_dict,
embed_dim=args.encoder_embed_dim,
......
......@@ -65,7 +65,7 @@ class LSTMModel(FairseqModel):
embed_dim=args.decoder_embed_dim,
out_embed_dim=args.decoder_out_embed_dim,
num_layers=args.decoder_layers,
attention=bool(args.decoder_attention),
attention=bool(eval(args.decoder_attention)),
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
)
......@@ -178,7 +178,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim)
for layer in range(num_layers)
])
self.attention = AttentionLayer(encoder_embed_dim, embed_dim)
self.attention = AttentionLayer(encoder_embed_dim, embed_dim) if attention else None
if embed_dim != out_embed_dim:
self.additional_fc = Linear(embed_dim, out_embed_dim)
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
......@@ -229,7 +229,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
prev_cells[i] = cell
# apply attention using the last layer's hidden state
if self.attention is not None:
out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
else:
out = hidden
out = F.dropout(out, p=self.dropout_out, training=self.training)
# input feeding
......
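The `eval()` here is needed because argparse hands `--decoder-attention` over as a string, and any non-empty string is truthy, so the old `bool(args.decoder_attention)` could never actually disable attention. A quick illustration of the difference, using `ast.literal_eval` as a safer stand-in for `eval`:

```python
from ast import literal_eval

flag = 'False'                    # what argparse stores for --decoder-attention False
print(bool(flag))                 # True  -- the old behaviour: attention always on
print(bool(literal_eval(flag)))   # False -- what bool(eval(...)) now computes
```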
......@@ -8,8 +8,6 @@
import torch
from torch.nn.modules.utils import _single
from fairseq import utils
class ConvTBC(torch.nn.Module):
"""1D convolution over an input of shape (time x batch x channel)
......
......@@ -5,10 +5,11 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from torch.autograd import Variable
import torch.nn as nn
from fairseq import utils
class LearnedPositionalEmbedding(nn.Embedding):
"""This module learns positional embeddings up to a fixed maximum size.
......@@ -25,27 +26,11 @@ class LearnedPositionalEmbedding(nn.Embedding):
"""Input is expected to be of size [bsz x seqlen]."""
if incremental_state is not None:
# positions is the same for every token when decoding a single step
positions = Variable(
input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
else:
positions = Variable(self.make_positions(input.data))
return super().forward(positions)
positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
return super().forward(Variable(positions))
def max_positions(self):
"""Maximum number of supported positions."""
return self.num_embeddings - self.padding_idx - 1
def make_positions(self, input):
"""Replace non-padding symbols with their position numbers."""
if not hasattr(self, 'range_buf'):
self.range_buf = input.new()
seqlen = input.size(1)
if self.range_buf.numel() < seqlen:
# offset positions by the padding index
torch.arange(self.padding_idx + 1, self.padding_idx + 1 + seqlen,
out=self.range_buf)
mask = input.ne(self.padding_idx)
positions = self.range_buf[:seqlen].expand_as(input)
if self.left_pad:
positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
return input.clone().masked_scatter_(mask, positions[mask])
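`make_positions` moves off the module and onto `fairseq.utils` so other embedding types can share it. A self-contained sketch of what it computes, mirroring the removed method above: non-pad tokens get consecutive positions starting at `padding_idx + 1`, pad tokens keep `padding_idx`, and left-padded rows are shifted so the first real token still starts the count:

```python
import torch

def make_positions(tokens, padding_idx, left_pad):
    mask = tokens.ne(padding_idx)
    positions = torch.arange(padding_idx + 1, padding_idx + 1 + tokens.size(1)).expand_as(tokens)
    if left_pad:
        # shift right by the amount of left padding in each row
        positions = positions - mask.size(1) + mask.long().sum(dim=1, keepdim=True)
    return tokens.clone().masked_scatter_(mask, positions[mask])

tokens = torch.tensor([[1, 1, 5, 6, 7]])    # padding_idx=1, left-padded batch
print(make_positions(tokens, padding_idx=1, left_pad=True))   # tensor([[1, 1, 2, 3, 4]])
```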
......@@ -16,6 +16,7 @@ OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params):
params = filter(lambda p: p.requires_grad, params)
return OPTIMIZER_REGISTRY[args.optimizer](args, params)
......
......@@ -23,6 +23,8 @@ class FairseqAdam(FairseqOptimizer):
"""Add optimizer-specific arguments to the parser."""
parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
help='betas for Adam optimizer')
parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
help='epsilon for Adam optimizer')
@property
def optimizer_config(self):
......@@ -35,6 +37,7 @@ class FairseqAdam(FairseqOptimizer):
return {
'lr': self.args.lr[0],
'betas': eval(self.args.adam_betas),
'eps': self.args.adam_eps,
'weight_decay': self.args.weight_decay,
}
......
......@@ -69,7 +69,7 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates < self.args.warmup_updates:
self.lr += self.lr_step
self.lr = self.args.warmup_init_lr + num_updates*self.lr_step
else:
self.lr = self.decay_factor * num_updates**-0.5
self.optimizer.set_lr(self.lr)
......
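The fix makes the warmup phase a pure function of `num_updates`; the previous `self.lr += self.lr_step` accumulated state and could drift, for example after resuming training from a checkpoint. A sketch of the whole schedule with illustrative defaults (the constants are assumptions, not taken from the diff):

```python
def inverse_sqrt_lr(num_updates, warmup_updates=4000, warmup_init_lr=1e-7, peak_lr=5e-4):
    lr_step = (peak_lr - warmup_init_lr) / warmup_updates    # linear warmup slope
    decay_factor = peak_lr * warmup_updates ** 0.5           # so both branches meet at warmup_updates
    if num_updates < warmup_updates:
        return warmup_init_lr + num_updates * lr_step        # stateless warmup, as in the fix
    return decay_factor * num_updates ** -0.5

assert abs(inverse_sqrt_lr(4000) - 5e-4) < 1e-12             # continuous at the boundary
```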
......@@ -32,11 +32,12 @@ def get_generation_parser():
return parser
def parse_args_and_arch(parser, _args=None):
def parse_args_and_arch(parser, input_args=None):
# The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we
# parse a second time after adding the *-specific arguments.
args, _ = parser.parse_known_args(_args)
# If input_args is given, we will parse those args instead of sys.argv.
args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser.
model_specific_group = parser.add_argument_group(
......@@ -53,7 +54,7 @@ def parse_args_and_arch(parser, _args=None):
LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)
# Parse a second time.
args = parser.parse_args(_args)
args = parser.parse_args(input_args)
# Post-process args.
args.lr = list(map(float, args.lr.split(',')))
......@@ -140,6 +141,8 @@ def add_optimization_args(parser):
group = parser.add_argument_group('Optimization')
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
help='force stop training at specified epoch')
group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
help='force stop training at specified update')
group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
help='clip threshold of gradients')
group.add_argument('--sentence-avg', action='store_true',
......@@ -188,12 +191,14 @@ def add_checkpoint_args(parser):
help='don\'t save models and checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
help='only store last and best checkpoints')
group.add_argument('--validate-interval', type=int, default=1, metavar='N',
help='validate every N epochs')
return group
def add_generation_args(parser):
group = parser.add_argument_group('Generation')
group.add_argument('--path', metavar='FILE', required=True, action='append',
group.add_argument('--path', metavar='FILE', action='append',
help='path(s) to model file(s)')
group.add_argument('--beam', default=5, type=int, metavar='N',
help='beam size')
......@@ -228,6 +233,8 @@ def add_generation_args(parser):
help='just score the reference translation')
group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
help=('initialize generation by target prefix of given length'))
group.add_argument('--sampling', action='store_true',
help='sample hypotheses instead of using beam search')
return group
......
......@@ -107,9 +107,9 @@ class json_progress_bar(progress_bar):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
update = self.epoch + float(i / size) if self.epoch is not None else None
update = self.epoch - 1 + float(i / size) if self.epoch is not None else None
stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
print('sweep_log: ' + json.dumps(stats), flush=True)
print(json.dumps(stats), flush=True)
def log(self, stats):
"""Log intermediate stats according to log_interval."""
......
......@@ -15,7 +15,7 @@ from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
stop_early=True, normalize_scores=True, len_penalty=1,
unk_penalty=0, retain_dropout=False):
unk_penalty=0, retain_dropout=False, sampling=False):
"""Generates translations of a given source sentence.
Args:
......@@ -36,7 +36,7 @@ class SequenceGenerator(object):
self.vocab_size = len(models[0].dst_dict)
self.beam_size = beam_size
self.minlen = minlen
max_decoder_len = min([m.max_decoder_positions() for m in self.models])
max_decoder_len = min(m.max_decoder_positions() for m in self.models)
max_decoder_len -= 1 # we define maxlen not including the EOS marker
self.maxlen = max_decoder_len if maxlen is None else min(maxlen, max_decoder_len)
self.stop_early = stop_early
......@@ -44,6 +44,7 @@ class SequenceGenerator(object):
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.retain_dropout = retain_dropout
self.sampling = sampling
def cuda(self):
for model in self.models:
......@@ -78,7 +79,7 @@ class SequenceGenerator(object):
prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
)
if timer is not None:
timer.stop(sum([len(h[0]['tokens']) for h in hypos]))
timer.stop(sum(len(h[0]['tokens']) for h in hypos))
for i, id in enumerate(s['id'].data):
src = input['src_tokens'].data[i, :]
# remove padding from ref
......@@ -255,9 +256,10 @@ class SequenceGenerator(object):
probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
scores = scores.type_as(probs)
scores_buf = scores_buf.type_as(probs)
else:
elif not self.sampling:
# make probs contain cumulative scores for each hypothesis
probs.add_(scores[:, step-1].view(-1, 1))
probs[:, self.pad] = -math.inf # never select pad
probs[:, self.unk] -= self.unk_penalty # apply unk penalty
......@@ -278,6 +280,31 @@ class SequenceGenerator(object):
).expand(-1, cand_size)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
cand_beams.resize_as_(cand_indices).fill_(0)
elif self.sampling:
assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
exp_probs = probs.exp_().view(-1, self.vocab_size)
if step == 0:
# we exclude the first two vocab items, one of which is pad
torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
cand_indices.add_(2)
else:
torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
cand_indices.add_(2)
torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
cand_scores.log_()
cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
if step == 0:
cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
else:
cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
# make scores cumulative
cand_scores.add_(
torch.gather(
scores[:, step-1].view(bsz, beam_size), dim=1,
index=cand_beams,
)
)
else:
# take the best 2 x beam_size predictions. We'll choose the first
# beam_size of these which don't predict eos to continue with.
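A simplified, standalone sketch of the sampling branch added above: log-probs are exponentiated, the first two vocabulary slots (which include pad) are excluded, tokens are drawn with `torch.multinomial`, and their log-probabilities become the hypothesis scores. Shapes and bookkeeping are reduced to the essentials; in the real code the sampled scores are additionally made cumulative across steps, exactly as beam-search scores are.

```python
import torch

def sample_candidates(lprobs, num_samples):
    """lprobs: (nbeams, vocab) log-probabilities for one decoding step."""
    probs = lprobs.exp()
    # never sample the first two symbols (the diff asserts pad == 1 for this reason)
    indices = torch.multinomial(probs[:, 2:], num_samples, replacement=True) + 2
    scores = probs.gather(1, indices).log()   # per-token log-prob of each sample
    return indices, scores

lprobs = torch.log_softmax(torch.randn(4, 50), dim=-1)
tokens, scores = sample_candidates(lprobs, num_samples=2)
assert tokens.min() >= 2 and tokens.shape == scores.shape == (4, 2)
```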
......@@ -411,8 +438,13 @@ class SequenceGenerator(object):
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
with utils.maybe_no_grad():
decoder_out, attn = model.decoder(tokens, encoder_out, incremental_states[model])
probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
if incremental_states[model] is not None:
decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
else:
decoder_out = list(model.decoder(tokens, encoder_out))
decoder_out[0] = decoder_out[0][:, -1, :]
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=False).data
if avg_probs is None:
avg_probs = probs
else:
......
......@@ -63,10 +63,11 @@ class SequenceScorer(object):
net_input['src_tokens'],
net_input['src_lengths'],
)
decoder_out, attn = model.decoder(
decoder_out = model.decoder(
net_input['prev_output_tokens'],
encoder_out,
)
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=False).data
if avg_probs is None:
avg_probs = probs
......
......@@ -58,10 +58,14 @@ class Tokenizer:
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
@staticmethod
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True, consumer=None):
def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = tokenize(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = dict.add_symbol(word)
......@@ -70,5 +74,6 @@ class Tokenizer:
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = dict.eos_index
return ids
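The new `append_eos` and `reverse_order` keywords let callers binarize data without a trailing EOS and/or with reversed source sentences. A standalone sketch of the behaviour (the dictionary here is a plain dict stand-in, not fairseq's):

```python
import torch

def tokenize(line, word_to_idx, eos_index, append_eos=True, reverse_order=False):
    words = line.split()
    if reverse_order:
        words = list(reversed(words))
    ids = torch.IntTensor(len(words) + 1 if append_eos else len(words))
    for i, word in enumerate(words):
        ids[i] = word_to_idx[word]
    if append_eos:
        ids[len(words)] = eos_index
    return ids

vocab = {'hello': 4, 'world': 5}
print(tokenize('hello world', vocab, eos_index=2))                # tensor([4, 5, 2], dtype=torch.int32)
print(tokenize('hello world', vocab, eos_index=2,
               append_eos=False, reverse_order=True))             # tensor([5, 4], dtype=torch.int32)
```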