Commit c83efd21 authored by Myle Ott, committed by GitHub

Merge pull request #33 from facebookresearch/oss-merge-internal

Changes:
Add support for NCCL v2
Add support for additional optimizers (Adagrad, Adam, and SGD, in addition to NAG)
SequenceGenerator now returns the attention matrix for each hypothesis
Miscellaneous bug fixes (e.g., fixes #32) and cleanup
parents af86c1ac 104cead1
@@ -102,3 +102,6 @@ ENV/
 # mypy
 .mypy_cache/
+
+# Generated files
+fairseq/temporal_convolution_tbc
@@ -7,3 +7,5 @@
 #
 from .multiprocessing_pdb import pdb
+
+__all__ = ['pdb']
@@ -97,7 +97,7 @@ class Scorer(object):
     def result_string(self, order=4):
         assert order <= 4, "BLEU scores for order > 4 aren't supported"
         fmt = 'BLEU{} = {:2.2f}, {:2.1f}'
-        for i in range(1, order):
+        for _ in range(1, order):
             fmt += '/{:2.1f}'
         fmt += ' (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})'
         bleup = [p * 100 for p in self.precision()[:order]]
...
@@ -104,11 +104,11 @@ void bleu_zero_init(bleu_stat* stat) {
 void bleu_one_init(bleu_stat* stat) {
   bleu_zero_init(stat);
-  stat->count1 = 1;
+  stat->count1 = 0;
   stat->count2 = 1;
   stat->count3 = 1;
   stat->count4 = 1;
-  stat->match1 = 1;
+  stat->match1 = 0;
   stat->match2 = 1;
   stat->match3 = 1;
   stat->match4 = 1;
...
@@ -18,14 +18,29 @@ class CrossEntropyCriterion(FairseqCriterion):
         super().__init__()
         self.padding_idx = padding_idx
 
-    def prepare(self, samples):
-        self.denom = sum(s['ntokens'] if s else 0 for s in samples)
-
-    def forward(self, net_output, sample):
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(**sample['net_input'])
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        return loss / self.denom
-
-    def aggregate(self, losses):
-        return sum(losses) / math.log(2)
+        sample_size = sample['ntokens']
+        logging_output = {
+            'loss': loss.data[0],
+            'sample_size': sample_size,
+        }
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+        return {
+            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+        }
@@ -11,21 +11,25 @@ from torch.nn.modules.loss import _Loss
 class FairseqCriterion(_Loss):
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
 
-    def prepare(self, samples):
-        """Prepare criterion for DataParallel training."""
-        raise NotImplementedError
-
-    def forward(self, net_output, sample):
-        """Compute the loss for the given sample and network output."""
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         raise NotImplementedError
 
-    def aggregate(self, losses):
-        """Aggregate losses from DataParallel training.
-
-        Takes a list of losses as input (as returned by forward) and
-        aggregates them into the total loss for the mini-batch.
-        """
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError
+
+    @staticmethod
+    def grad_denom(sample_sizes):
+        """Compute the gradient denominator for a set of sample sizes."""
+        return sum(sample_sizes)
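
The contract above (forward(model, sample) returning a loss, a sample size, and a logging output, plus static aggregation helpers) can be exercised end to end. Below is a minimal, self-contained sketch assuming the PyTorch 0.2-era API this code base targets; ToyModel and the sample dict are hypothetical stand-ins, while CrossEntropyCriterion is the criterion from this change.

import torch
from torch.autograd import Variable

from fairseq import criterions


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in model; pretends its input is already a feature matrix."""

    def __init__(self, vocab_size=10, dim=8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, vocab_size)

    def forward(self, src_tokens):
        return self.proj(src_tokens)


model = ToyModel()
criterion = criterions.CrossEntropyCriterion(padding_idx=1)
sample = {
    'net_input': {'src_tokens': Variable(torch.randn(4, 8))},
    'target': Variable(torch.LongTensor([2, 3, 4, 5])),
    'ntokens': 4,
}

# forward() now runs the model itself and returns a 3-tuple
loss, sample_size, logging_output = criterion(model, sample)
loss.backward()

# per-replica logging outputs are merged by the criterion class
stats = criterions.CrossEntropyCriterion.aggregate_logging_outputs([logging_output])
print(stats['loss'])  # loss per token, in base-2 log units
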
@@ -49,14 +49,29 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.padding_idx = padding_idx
         self.weights = weights
 
-    def prepare(self, samples):
-        self.denom = sum(s['ntokens'] if s else 0 for s in samples)
-
-    def forward(self, net_output, sample):
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
+        net_output = model(**sample['net_input'])
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        return loss / self.denom
-
-    def aggregate(self, losses):
-        return sum(losses) / math.log(2)
+        sample_size = sample['ntokens']
+        logging_output = {
+            'loss': loss.data[0],
+            'sample_size': sample_size,
+        }
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+        return {
+            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+        }
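
The fused LabelSmoothedCrossEntropy.apply op itself is not part of this diff; the following is only a rough sketch of the quantity it computes (a mix of the gold-token NLL and the mean NLL over the vocabulary), useful for reading the criterion above. The helper name and toy inputs are hypothetical.

import torch
import torch.nn.functional as F


def label_smoothed_nll(lprobs, target, eps, padding_idx):
    # lprobs: (N, V) log-probabilities; target: (N,) gold indices
    nll = -lprobs.gather(1, target.unsqueeze(1)).squeeze(1)   # NLL of the gold token
    smooth = -lprobs.mean(dim=1)                              # mean NLL over the vocabulary
    loss = (1.0 - eps) * nll + eps * smooth
    loss = loss.masked_fill(target.eq(padding_idx), 0.0)      # ignore padding positions
    return loss.sum()


lprobs = F.log_softmax(torch.randn(3, 5), dim=-1)
target = torch.LongTensor([2, 4, 1])
print(label_smoothed_nll(lprobs, target, eps=0.1, padding_idx=1))
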
@@ -18,30 +18,32 @@ from fairseq.indexed_dataset import IndexedDataset, IndexedInMemoryDataset
 def load_with_check(path, load_splits, src=None, dst=None):
-    """Loads the train, valid, and test sets from the specified folder
-    and check that training files exist."""
+    """Loads specified data splits (e.g., test, train or valid) from the
+    specified folder and check that files exist."""
 
     def find_language_pair(files):
-        for filename in files:
-            parts = filename.split('.')
-            if parts[0] == 'train' and parts[-1] == 'idx':
-                return parts[1].split('-')
+        for split in load_splits:
+            for filename in files:
+                parts = filename.split('.')
+                if parts[0] == split and parts[-1] == 'idx':
+                    return parts[1].split('-')
 
-    def train_file_exists(src, dst):
-        filename = 'train.{0}-{1}.{0}.idx'.format(src, dst)
+    def split_exists(split, src, dst):
+        filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
         return os.path.exists(os.path.join(path, filename))
 
     if src is None and dst is None:
         # find language pair automatically
         src, dst = find_language_pair(os.listdir(path))
-    elif train_file_exists(src, dst):
-        # check for src-dst langcode
-        pass
-    elif train_file_exists(dst, src):
-        # check for dst-src langcode
+
+    if not split_exists(load_splits[0], src, dst):
+        # try reversing src and dst
         src, dst = dst, src
-    else:
-        raise ValueError('training file not found for {}-{}'.format(src, dst))
+
+    for split in load_splits:
+        if not split_exists(load_splits[0], src, dst):
+            raise ValueError('Data split not found: {}-{} ({})'.format(
+                src, dst, split))
 
     dataset = load(path, load_splits, src, dst)
     return dataset
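
For reference, the naming convention that split_exists() now checks per requested split is '{split}.{src}-{dst}.{src}.idx'. A small sketch, with a hypothetical data directory and language pair:

import os


def split_exists(path, split, src, dst):
    # mirrors the helper above: e.g. valid.de-en.de.idx under the data directory
    filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
    return os.path.exists(os.path.join(path, filename))


# hypothetical paths/languages:
print(split_exists('data-bin/iwslt14', 'valid', 'de', 'en'))
# -> looks for data-bin/iwslt14/valid.de-en.de.idx
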
@@ -326,7 +328,7 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0, max_p
         batches = result
     else:
-        for i in range(epoch - 1):
+        for _ in range(epoch - 1):
             np.random.shuffle(batches)
 
     return batches
...
@@ -38,13 +38,31 @@ class Dictionary(object):
             return self.indices[sym]
         return self.unk_index
 
-    def string(self, tensor):
+    def string(self, tensor, bpe_symbol=None, escape_unk=False):
+        """Helper for converting a tensor of token indices to a string.
+
+        Can optionally remove BPE symbols or escape <unk> words.
+        """
         if torch.is_tensor(tensor) and tensor.dim() == 2:
-            sentences = [self.string(line) for line in tensor]
-            return '\n'.join(sentences)
+            return '\n'.join(self.to_string(t) for t in tensor)
+
+        def token_string(i):
+            if i == self.unk():
+                return self.unk_string(escape_unk)
+            else:
+                return self[i]
 
-        eos = self.eos()
-        return ' '.join([self[i] for i in tensor if i != eos])
+        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
+        if bpe_symbol is not None:
+            sent = sent.replace(bpe_symbol, '')
+        return sent
+
+    def unk_string(self, escape=False):
+        """Return unknown string, optionally escaped as: <<unk>>"""
+        if escape:
+            return '<{}>'.format(self.unk_word)
+        else:
+            return self.unk_word
 
     def add_symbol(self, word, n=1):
         """Adds a word to the dictionary"""
...
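
A short usage sketch of the extended Dictionary.string() signature. The symbols added below are toy examples; bpe_symbol='@@ ' matches the default that --remove-bpe uses elsewhere in this change.

import torch

from fairseq import dictionary

d = dictionary.Dictionary()
for sym in ['ein@@', 'fach']:
    d.add_symbol(sym)

# 'unseen' was never added, so it maps to the unknown index
ids = torch.IntTensor([d.index('ein@@'), d.index('fach'), d.index('unseen'), d.eos()])
print(d.string(ids))                    # 'ein@@ fach <unk>'
print(d.string(ids, bpe_symbol='@@ '))  # 'einfach <unk>'
print(d.string(ids, escape_unk=True))   # 'ein@@ fach <<unk>>'
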
@@ -15,12 +15,18 @@ from fairseq.modules import BeamableMM, LinearizedConvolution
 class FConvModel(nn.Module):
-    def __init__(self, encoder, decoder, padding_idx=1):
+    def __init__(self, encoder, decoder):
         super(FConvModel, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
+        self.src_dict = encoder.dictionary
+        self.dst_dict = decoder.dictionary
+        assert self.src_dict.pad() == self.dst_dict.pad()
+        assert self.src_dict.eos() == self.dst_dict.eos()
+        assert self.src_dict.unk() == self.dst_dict.unk()
         self.encoder.num_attention_layers = sum([layer is not None for layer in decoder.attention])
-        self.padding_idx = padding_idx
 
         self._is_generation_fast = False
 
     def forward(self, src_tokens, src_positions, input_tokens, input_positions):
@@ -67,11 +73,15 @@ class FConvModel(nn.Module):
 class Encoder(nn.Module):
     """Convolutional encoder"""
-    def __init__(self, num_embeddings, embed_dim=512, max_positions=1024,
-                 convolutions=((512, 3),) * 20, dropout=0.1, padding_idx=1):
+    def __init__(self, dictionary, embed_dim=512, max_positions=1024,
+                 convolutions=((512, 3),) * 20, dropout=0.1):
         super(Encoder, self).__init__()
+        self.dictionary = dictionary
         self.dropout = dropout
         self.num_attention_layers = None
 
+        num_embeddings = len(dictionary)
+        padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
         self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
@@ -160,10 +170,11 @@ class AttentionLayer(nn.Module):
 class Decoder(nn.Module):
     """Convolutional decoder"""
-    def __init__(self, num_embeddings, embed_dim=512, out_embed_dim=256,
+    def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
-                 attention=True, dropout=0.1, padding_idx=1):
+                 attention=True, dropout=0.1):
         super(Decoder, self).__init__()
+        self.dictionary = dictionary
         self.dropout = dropout
 
         in_channels = convolutions[0][0]
@@ -171,8 +182,11 @@ class Decoder(nn.Module):
             # expand True into [True, True, ...] and do the same with False
             attention = [attention] * len(convolutions)
 
+        num_embeddings = len(dictionary)
+        padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
         self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
         self.projections = nn.ModuleList()
         self.convolutions = nn.ModuleList()
@@ -284,9 +298,8 @@ class Decoder(nn.Module):
             'already performing incremental inference'
         self._is_inference_incremental = True
 
-        # save original forward and convolution layers
+        # save original forward
         self._orig_forward = self.forward
-        self._orig_conv = self.convolutions
 
         # switch to incremental forward
         self.forward = self._incremental_forward
@@ -295,9 +308,8 @@ class Decoder(nn.Module):
         self.start_fresh_sequence(beam_size)
 
     def _stop_incremental_inference(self):
-        # restore original forward and convolution layers
+        # restore original forward
         self.forward = self._orig_forward
-        self.convolutions = self._orig_conv
 
         self._is_inference_incremental = False
@@ -505,24 +517,21 @@ def parse_arch(args):
     return args
 
-def build_model(args, dataset):
-    padding_idx = dataset.dst_dict.pad()
+def build_model(args, src_dict, dst_dict):
     encoder = Encoder(
-        len(dataset.src_dict),
+        src_dict,
         embed_dim=args.encoder_embed_dim,
         convolutions=eval(args.encoder_layers),
         dropout=args.dropout,
-        padding_idx=padding_idx,
         max_positions=args.max_positions,
     )
     decoder = Decoder(
-        len(dataset.dst_dict),
+        dst_dict,
        embed_dim=args.decoder_embed_dim,
         convolutions=eval(args.decoder_layers),
         out_embed_dim=args.decoder_out_embed_dim,
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
-        padding_idx=padding_idx,
         max_positions=args.max_positions,
     )
-    return FConvModel(encoder, decoder, padding_idx)
+    return FConvModel(encoder, decoder)
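
A construction sketch under the new signatures: Encoder and Decoder now receive the Dictionary itself and derive the vocabulary size and padding index from it. The tiny dictionaries and layer sizes below are hypothetical, and the import path assumes the fconv module lives at fairseq.models.fconv as in this tree.

from fairseq import dictionary
from fairseq.models import fconv

src_dict, dst_dict = dictionary.Dictionary(), dictionary.Dictionary()
for sym in ['hallo', 'welt']:
    src_dict.add_symbol(sym)
for sym in ['hello', 'world']:
    dst_dict.add_symbol(sym)

encoder = fconv.Encoder(src_dict, embed_dim=64, max_positions=32,
                        convolutions=((64, 3),) * 2)
decoder = fconv.Decoder(dst_dict, embed_dim=64, out_embed_dim=32,
                        max_positions=32, convolutions=((64, 3),) * 2)
model = fconv.FConvModel(encoder, decoder)   # the padding_idx argument is gone
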
@@ -15,7 +15,6 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 
 from fairseq import nccl, utils
-from fairseq.criterions import FairseqCriterion
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
@@ -32,7 +31,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     (prefixed with `_async_`), which run on each process in parallel.
     """
 
-    def __init__(self, args, model, device_ids=None,
+    OPTIMIZERS = ['adagrad', 'adam', 'nag', 'sgd']
+
+    def __init__(self, args, model, criterion, device_ids=None,
                  multiprocessing_method='spawn'):
         if device_ids is None:
             device_ids = tuple(range(torch.cuda.device_count()))
@@ -42,40 +43,57 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             raise NotImplementedError('Training on CPU is not supported')
 
         model = model.share_memory()
         nccl_uid = nccl.get_unique_id()
+        self.criterion = criterion
 
         Future.gen_list([
             self.call_async(rank, '_async_init', args=args, model=model,
-                            nccl_uid=nccl_uid)
+                            criterion=criterion, nccl_uid=nccl_uid)
             for rank in range(self.num_replicas)
         ])
 
         self._grads_initialized = False
 
-    def _async_init(self, rank, device_id, args, model, nccl_uid):
+    def _async_init(self, rank, device_id, args, model, criterion, nccl_uid):
         """Initialize child processes."""
         self.args = args
 
-        # set torch.seed in this process
-        torch.manual_seed(args.seed)
-
         # set CUDA device
         torch.cuda.set_device(device_id)
 
         # initialize NCCL
         nccl.initialize(self.num_replicas, nccl_uid, device_id)
 
-        # copy model to current device
+        # copy model and criterion to current device
         self.model = model.cuda()
+        self.criterion = criterion.cuda()
 
         # initialize optimizer
-        self.optimizer = NAG(self.model.parameters(), lr=self.args.lr,
-                             momentum=self.args.momentum,
-                             weight_decay=self.args.weight_decay)
+        self.optimizer = self._build_optimizer()
         self.flat_grads = None
+        self.loss = None
 
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
 
+    def _build_optimizer(self):
+        if self.args.optimizer == 'adagrad':
+            return torch.optim.Adagrad(self.model.parameters(), lr=self.args.lr,
+                                       weight_decay=self.args.weight_decay)
+        elif self.args.optimizer == 'adam':
+            return torch.optim.Adam(self.model.parameters(), lr=self.args.lr,
+                                    betas=eval(self.args.adam_betas),
+                                    weight_decay=self.args.weight_decay)
+        elif self.args.optimizer == 'nag':
+            return NAG(self.model.parameters(), lr=self.args.lr,
+                       momentum=self.args.momentum,
+                       weight_decay=self.args.weight_decay)
+        elif self.args.optimizer == 'sgd':
+            return torch.optim.SGD(self.model.parameters(), lr=self.args.lr,
+                                   momentum=self.args.momentum,
+                                   weight_decay=self.args.weight_decay)
+        else:
+            raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
+
     def _build_lr_scheduler(self):
         if self.args.force_anneal > 0:
             def anneal(e):
@@ -98,14 +116,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _async_get_model(self, rank, device_id):
         return self.model
 
-    def save_checkpoint(self, args, epoch, batch_offset, val_loss=None):
+    def save_checkpoint(self, filename, extra_state):
         """Save a checkpoint for the current model."""
-        self.call_async(0, '_async_save_checkpoint', args=args, epoch=epoch,
-                        batch_offset=batch_offset, val_loss=val_loss).gen()
+        self.call_async(0, '_async_save_checkpoint', filename=filename, extra_state=extra_state).gen()
 
-    def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
-        utils.save_checkpoint(args, epoch, batch_offset, self.model,
-                              self.optimizer, self.lr_scheduler, val_loss)
+    def _async_save_checkpoint(self, rank, device_id, filename, extra_state):
+        utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
+                         self.lr_scheduler, self._optim_history, extra_state)
 
     def load_checkpoint(self, filename):
         """Load a checkpoint into the model replicas in each process."""
@@ -113,17 +130,26 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.call_async(rank, '_async_load_checkpoint', filename=filename)
             for rank in range(self.num_replicas)
         ])
-        epoch, batch_offset = results[0]
-        return epoch, batch_offset
+        extra_state = results[0]
+        return extra_state
 
     def _async_load_checkpoint(self, rank, device_id, filename):
-        return utils.load_checkpoint(filename, self.model, self.optimizer,
-                                     self.lr_scheduler, cuda_device=device_id)
+        extra_state, self._optim_history = utils.load_state(
+            filename, self.model, self.criterion, self.optimizer,
+            self.lr_scheduler, cuda_device=device_id)
+        return extra_state
 
-    def train_step(self, samples, criterion):
-        """Do forward, backward and gradient step in parallel."""
-        assert isinstance(criterion, FairseqCriterion)
+    def set_seed(self, seed):
+        Future.gen_list([
+            self.call_async(rank, '_async_set_seed', seed=seed)
+            for rank in range(self.num_replicas)
+        ])
+
+    def _async_set_seed(self, rank, device_id, seed):
+        torch.manual_seed(seed)
 
+    def train_step(self, samples):
+        """Do forward, backward and gradient step in parallel."""
         # PyTorch initializes gradient buffers lazily, so the first
         # train step needs to send non-empty samples to all replicas
         replace_empty_samples = False
@@ -133,33 +159,45 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
 
         # scatter sample across GPUs
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
-        criterion.prepare(samples)
 
-        # forward pass, backward pass and gradient step
-        losses = [
-            self.call_async(rank, '_async_train_step', criterion=criterion)
+        # forward pass
+        sample_sizes, logging_outputs = Future.gen_tuple_list([
+            self.call_async(rank, '_async_forward')
             for rank in range(self.num_replicas)
-        ]
+        ])
 
-        # aggregate losses and gradient norms
-        losses, grad_norms = Future.gen_tuple_list(losses)
-        loss = criterion.aggregate(losses)
+        # backward pass, all-reduce gradients and take an optimization step
+        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+        grad_norms = Future.gen_list([
+            self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
+            for rank in range(self.num_replicas)
+        ])
 
-        return loss, grad_norms[0]
+        # aggregate logging output
+        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+        logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
 
-    def _async_train_step(self, rank, device_id, criterion):
-        self.model.train()
+        return logging_output
 
-        # zero grads even if net_input is None, since we will all-reduce them
+    def _async_forward(self, rank, device_id, eval=False):
+        if eval:
+            self.model.eval()
+        else:
+            self.model.train()
         self.optimizer.zero_grad()
 
-        # calculate loss and grads
-        loss = 0
-        if self._sample is not None:
-            net_output = self.model(**self._sample['net_input'])
-            loss_ = criterion(net_output, self._sample)
-            loss_.backward()
-            loss = loss_.data[0]
+        if self._sample is None:
+            return 0, {}
+
+        # calculate loss and sample size
+        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+
+        return sample_size, logging_output
+
+    def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        if self.loss is not None:
+            # backward pass
+            self.loss.backward()
 
         # flatten grads into a contiguous block of memory
         if self.flat_grads is None:
@@ -168,13 +206,20 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
 
         # all-reduce grads
         nccl.all_reduce(self.flat_grads)
 
+        # normalize grads
+        if grad_denom != 0:
+            self.flat_grads.div_(grad_denom)
+
         # clip grads
         grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
 
         # take an optimization step
         self.optimizer.step()
 
-        return loss, grad_norm
+        # reset loss
+        self.loss = None
+
+        return grad_norm
 
     def _flatten_grads_(self, model):
         num_params = sum(p.data.numel() for p in model.parameters())
@@ -196,30 +241,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             flat_grads.div_(coef)
         return norm
 
-    def valid_step(self, samples, criterion):
+    def valid_step(self, samples):
         """Do forward pass in parallel."""
         # scatter sample across GPUs
         self._scatter_samples(samples, volatile=True)
-        criterion.prepare(samples)
 
         # forward pass
-        losses = [
-            self.call_async(rank, '_async_valid_step', criterion=criterion)
+        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+            self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
-        ]
-
-        # aggregate losses
-        loss = criterion.aggregate(Future.gen_list(losses))
-
-        return loss
-
-    def _async_valid_step(self, rank, device_id, criterion):
-        if self._sample is None:
-            return 0
-
-        self.model.eval()
-        net_output = self.model(**self._sample['net_input'])
-        loss = criterion(net_output, self._sample)
-        return loss.data[0]
+        ])
+
+        # aggregate logging output
+        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+        return logging_output
 
     def get_lr(self):
         """Get the current learning rate."""
...
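
A single-process sketch of the gradient normalization introduced above: replicas backprop unnormalized losses, the all-reduced (summed) gradients are divided once by grad_denom, the sum of per-replica sample sizes. The tensors below are toy stand-ins for the flattened gradients.

import torch

per_replica_grads = [torch.ones(4) * 2.0, torch.ones(4) * 4.0]  # pretend all-reduce inputs
sample_sizes = [10, 20]                                          # ntokens on each replica

flat_grads = sum(per_replica_grads)   # what nccl.all_reduce leaves on every replica
grad_denom = sum(sample_sizes)        # FairseqCriterion.grad_denom default
if grad_denom != 0:
    flat_grads.div_(grad_denom)

print(flat_grads)  # gradient of the mean-per-token loss across all replicas
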
@@ -12,9 +12,10 @@ GPU separately.
 """
 
 import ctypes
-import warnings
+from ctypes.util import find_library
 
 lib = None
+nccl_2_0 = None
 _uid = None
 _rank = None
 _num_devices = None
@@ -22,48 +23,25 @@ _comm = None
 
 __all__ = ['all_reduce', 'initialize', 'get_unique_id']
 
-def _libnccl():
-    global lib
-    if lib is None:
-        lib = ctypes.cdll.LoadLibrary(None)
-        if hasattr(lib, 'ncclCommDestroy'):
-            lib.ncclCommDestroy.restype = None
-            lib.ncclGetErrorString.restype = ctypes.c_char_p
-        else:
-            lib = None
-    return lib
-
-def is_available(tensors):
-    devices = set()
-    for tensor in tensors:
-        if not tensor.is_contiguous():
-            return False
-        if not tensor.is_cuda:
-            return False
-        device = tensor.get_device()
-        if device in devices:
-            return False
-        devices.add(device)
-
-    if _libnccl() is None:
-        warnings.warn('NCCL library not found. Check your LD_LIBRARY_PATH')
-        return False
-
-    return True
-
-_communicators = {}
-
 # ncclDataType_t
-ncclChar = 0
-ncclInt = 1
-ncclHalf = 2
-ncclFloat = 3
-ncclDouble = 4
-ncclInt64 = 5
-ncclUint64 = 6
+nccl_types = {
+    'torch.cuda.ByteTensor': 0,
+    'torch.cuda.CharTensor': 0,
+    'torch.cuda.IntTensor': 1,
+    'torch.cuda.HalfTensor': 2,
+    'torch.cuda.FloatTensor': 3,
+    'torch.cuda.DoubleTensor': 4,
+    'torch.cuda.LongTensor': 5,
+}
+nccl_types_2_0 = {
+    'torch.cuda.ByteTensor': 0,
+    'torch.cuda.CharTensor': 0,
+    'torch.cuda.IntTensor': 2,
+    'torch.cuda.HalfTensor': 6,
+    'torch.cuda.FloatTensor': 7,
+    'torch.cuda.DoubleTensor': 8,
+    'torch.cuda.LongTensor': 4,
+}
 
 # ncclRedOp_t
 SUM = 0
@@ -71,21 +49,57 @@ PROD = 1
 MAX = 2
 MIN = 3
 
-nccl_types = {
-    'torch.cuda.ByteTensor': ncclChar,
-    'torch.cuda.CharTensor': ncclChar,
-    'torch.cuda.IntTensor': ncclInt,
-    'torch.cuda.HalfTensor': ncclHalf,
-    'torch.cuda.FloatTensor': ncclFloat,
-    'torch.cuda.DoubleTensor': ncclDouble,
-    'torch.cuda.LongTensor': ncclInt64,
+status_codes_2_0 = {
+    0: "Success",
+    1: "Unhandled Cuda Error",
+    2: "System Error",
+    3: "Internal Error",
+    4: "Invalid Argument Error",
+    5: "Invalid Usage Error",
+}
+
+status_codes = {
+    0: "Success",
+    1: "Unhandled Cuda Error",
+    2: "System Error",
+    3: "Internal Error",
+    4: "Invalid Device Pointer",
+    5: "Invalid Rank",
+    6: "Unsupported Device Count",
+    7: "Device Not Found",
+    8: "Invalid Device Index",
+    9: "Lib Wrapper Not Set",
+    10: "Cuda Malloc Failed",
+    11: "Rank Mismatch",
+    12: "Invalid Argument",
+    13: "Invalid Type",
+    14: "Invalid Operation",
 }
 
+def _libnccl():
+    global nccl_2_0
+    global lib
+    global status_codes
+    global nccl_types
+    if lib is None:
+        lib = ctypes.pydll.LoadLibrary(find_library('nccl'))
+        if hasattr(lib, 'ncclCommDestroy'):
+            lib.ncclCommDestroy.restype = None
+        else:
+            lib = None
+        if hasattr(lib, 'ncclGroupStart'):
+            nccl_2_0 = True
+            status_codes = status_codes_2_0
+            nccl_types = nccl_types_2_0
+    return lib
+
 class NcclError(RuntimeError):
 
     def __init__(self, status):
         self.status = status
-        msg = '{0} ({1})'.format(lib.ncclGetErrorString(status), status)
+        msg = '{0} ({1})'.format(status_codes.get(status), status)
         super(NcclError, self).__init__(msg)
@@ -134,10 +148,12 @@ def initialize(num_devices, uid, rank):
 def communicator():
     global _comm
+    if _libnccl() is None:
+        raise RuntimeError('Unable to load NCCL library')
     if _uid is None:
         raise RuntimeError('NCCL not initialized')
     if _comm is None:
-        comm = ctypes.c_void_p()
+        comm = NcclComm()
         check_error(lib.ncclCommInitRank(
             ctypes.byref(comm),
             ctypes.c_int(_num_devices),
...
@@ -9,6 +9,7 @@
 import argparse
 
 from fairseq import models
+from fairseq.multiprocessing_trainer import MultiprocessingTrainer
 
 
 def get_parser(desc):
@@ -41,6 +42,9 @@ def add_dataset_args(parser):
 
 def add_optimization_args(parser):
     group = parser.add_argument_group('Optimization')
+    group.add_argument('--optimizer', default='nag', metavar='OPT',
+                       choices=MultiprocessingTrainer.OPTIMIZERS,
+                       help='optimizer ({})'.format(', '.join(MultiprocessingTrainer.OPTIMIZERS)))
     group.add_argument('--lr', '--learning-rate', default=0.25, type=float, metavar='LR',
                        help='initial learning rate')
     group.add_argument('--min-lr', metavar='LR', default=1e-5, type=float,
@@ -53,6 +57,8 @@ def add_optimization_args(parser):
                        help='learning rate shrink factor for annealing, lr_new = (lr * lrshrink)')
     group.add_argument('--momentum', default=0.99, type=float, metavar='M',
                        help='momentum factor')
+    group.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
+                       help='betas for Adam optimizer')
     group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                        help='clip threshold of gradients')
     group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
@@ -85,13 +91,13 @@ def add_generation_args(parser):
                        help='beam size')
     group.add_argument('--nbest', default=1, type=int, metavar='N',
                        help='number of hypotheses to output')
-    group.add_argument('--max-len-a', default=0, type=int, metavar='N',
-                       help=('generate sequence of maximum length ax + b, '
+    group.add_argument('--max-len-a', default=0, type=float, metavar='N',
+                       help=('generate sequences of maximum length ax + b, '
                              'where x is the source length'))
     group.add_argument('--max-len-b', default=200, type=int, metavar='N',
-                       help=('generate sequence of maximum length ax + b, '
+                       help=('generate sequences of maximum length ax + b, '
                              'where x is the source length'))
-    group.add_argument('--remove-bpe', action='store_true',
+    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
                        help='remove BPE tokens before scoring')
     group.add_argument('--no-early-stop', action='store_true',
                        help=('continue searching even after finalizing k=beam '
...
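
A small sketch of the new optimizer-related options. The parser below is a simplified stand-in for get_parser()/add_optimization_args(); the mapping to a torch optimizer mirrors MultiprocessingTrainer._build_optimizer() above.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', default='nag', choices=['adagrad', 'adam', 'nag', 'sgd'])
parser.add_argument('--lr', default=0.25, type=float)
parser.add_argument('--momentum', default=0.99, type=float)
parser.add_argument('--adam-betas', default='(0.9, 0.999)')
parser.add_argument('--weight-decay', default=0.0, type=float)
args = parser.parse_args(['--optimizer', 'adam', '--lr', '0.0005', '--adam-betas', '(0.9, 0.98)'])

# _build_optimizer() turns these into a torch optimizer, e.g. for 'adam':
#   torch.optim.Adam(model.parameters(), lr=args.lr, betas=eval(args.adam_betas),
#                    weight_decay=args.weight_decay)
print(args.optimizer, eval(args.adam_betas))
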
@@ -16,7 +16,7 @@ from fairseq import utils
 class SequenceGenerator(object):
-    def __init__(self, models, dst_dict, beam_size=1, minlen=1, maxlen=200,
+    def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
                  stop_early=True, normalize_scores=True, len_penalty=1):
         """Generates translations of a given source sentence.
@@ -29,13 +29,14 @@ class SequenceGenerator(object):
             normalize_scores: Normalize scores by the length of the output.
         """
         self.models = models
-        self.dict = dst_dict
-        self.pad = dst_dict.pad()
-        self.eos = dst_dict.eos()
-        self.vocab_size = len(dst_dict)
+        self.pad = models[0].dst_dict.pad()
+        self.eos = models[0].dst_dict.eos()
+        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
+        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
+        self.vocab_size = len(models[0].dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = min(maxlen, *(m.decoder.max_positions() - self.pad - 2 for m in self.models))
+        self.maxlen = min(maxlen, *[m.decoder.max_positions() - self.pad - 2 for m in self.models])
         self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
         self.decoder_context = models[0].decoder.context_size()
         self.stop_early = stop_early
@@ -48,7 +49,7 @@ class SequenceGenerator(object):
             self.positions = self.positions.cuda()
         return self
 
-    def generate_batched_itr(self, data_itr, maxlen_a=0, maxlen_b=200,
+    def generate_batched_itr(self, data_itr, maxlen_a=0.0, maxlen_b=200,
                              cuda_device=None, timer=None):
         """Iterate over a batched dataset and yield individual translations.
@@ -69,7 +70,7 @@ class SequenceGenerator(object):
             if timer is not None:
                 timer.start()
             hypos = self.generate(input['src_tokens'], input['src_positions'],
-                                  maxlen=(maxlen_a*srclen + maxlen_b))
+                                  maxlen=int(maxlen_a*srclen + maxlen_b))
             if timer is not None:
                 timer.stop(s['ntokens'])
             for i, id in enumerate(s['id']):
@@ -91,7 +92,7 @@ class SequenceGenerator(object):
         # the max beam size is the dictionary size - 1, since we never select pad
         beam_size = beam_size if beam_size is not None else self.beam_size
-        beam_size = min(beam_size, len(self.dict) - 1)
+        beam_size = min(beam_size, self.vocab_size - 1)
 
         encoder_outs = []
         for model in self.models:
@@ -108,8 +109,8 @@ class SequenceGenerator(object):
         tokens = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(self.pad)
         tokens_buf = tokens.clone()
         tokens[:, 0] = self.eos
-        align = src_tokens.data.new(bsz * beam_size, maxlen + 2).fill_(-1)
-        align_buf = align.clone()
+        attn = scores.new(bsz * beam_size, src_tokens.size(1), maxlen + 2)
+        attn_buf = attn.clone()
 
         # list of completed sentences
         finalized = [[] for i in range(bsz)]
@@ -126,7 +127,7 @@ class SequenceGenerator(object):
         # helper function for allocating buffers on the fly
         buffers = {}
-        def buffer(name, type_of=tokens):
+        def buffer(name, type_of=tokens):  # noqa
             if name not in buffers:
                 buffers[name] = type_of.new()
             return buffers[name]
@@ -177,10 +178,12 @@ class SequenceGenerator(object):
             def get_hypo():
                 hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                 hypo[step] = self.eos
-                alignment = align[idx, 1:step+2].clone()
+                attention = attn[idx, :, 1:step+2].clone()
+                _, alignment = attention.max(dim=0)
                 return {
                     'tokens': hypo,
                     'score': score,
+                    'attention': attention,
                     'alignment': alignment,
                 }
@@ -224,9 +227,8 @@ class SequenceGenerator(object):
             probs.add_(scores.view(-1, 1))
             probs[:, self.pad] = -math.inf  # never select pad
 
-            # record alignment to source tokens, based on attention
-            _ignore_scores = buffer('_ignore_scores', type_of=scores)
-            avg_attn_scores.topk(1, out=(_ignore_scores, align[:, step+1].unsqueeze(1)))
+            # Record attention scores
+            attn[:, :, step+1].copy_(avg_attn_scores)
 
             # take the best 2 x beam_size predictions. We'll choose the first
             # beam_size of these which don't predict eos to continue with.
@@ -290,17 +292,17 @@ class SequenceGenerator(object):
             cand_indices.gather(1, active_hypos,
                                 out=tokens_buf.view(bsz, beam_size, -1)[:, :, step+1])
 
-            # copy attention/alignment for active hypotheses
-            torch.index_select(align[:, :step+2], dim=0, index=active_bbsz_idx,
-                               out=align_buf[:, :step+2])
+            # copy attention for active hypotheses
+            torch.index_select(attn[:, :, :step+2], dim=0, index=active_bbsz_idx,
+                               out=attn_buf[:, :, :step+2])
 
             # swap buffers
             old_tokens = tokens
             tokens = tokens_buf
             tokens_buf = old_tokens
-            old_align = align
-            align = align_buf
-            align_buf = old_align
+            old_attn = attn
+            attn = attn_buf
+            attn_buf = old_attn
 
             # reorder incremental state in decoder
             reorder_state = active_bbsz_idx
...
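
A sketch of consuming the attention matrix that SequenceGenerator now attaches to each hypothesis: 'attention' has shape (src_len, tgt_len), and the hard alignment is its argmax over source positions, exactly as get_hypo() computes it above. The hypo dict below is a hypothetical stand-in.

import torch

src_len, tgt_len = 5, 7
hypo = {
    'tokens': torch.LongTensor(tgt_len).fill_(3),   # stand-in output tokens
    'score': -1.23,
    'attention': torch.rand(src_len, tgt_len),       # stand-in attention weights
}

_, alignment = hypo['attention'].max(dim=0)          # best source position per target token
print(' '.join(map(str, alignment.view(-1).tolist())))  # same data generate.py prints on its 'A' lines
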
@@ -6,7 +6,9 @@
 # can be found in the PATENTS file in the same directory.
 #
 
+from collections import Counter
 import re
+
 import torch
 
 from fairseq import dictionary
@@ -32,46 +34,41 @@ class Tokenizer:
     @staticmethod
     def add_file_to_dictionary(filename, dict, tokenize):
         with open(filename, 'r') as f:
-            for line in f.readlines():
+            for line in f:
                 for word in tokenize(line):
                     dict.add_symbol(word)
                 dict.add_symbol(dict.eos_word)
 
     @staticmethod
     def binarize(filename, dict, consumer, tokenize=tokenize_line):
-        nseq, ntok, nunk = 0, 0, 0
-        replaced = {}
-        with open(filename, 'r') as f:
-            for line in f.readlines():
-                words = tokenize(line)
-                nwords = len(words)
-                ids = torch.IntTensor(nwords + 1)
-                nseq = nseq + 1
-                for i in range(0, len(words)):
-                    word = words[i]
-                    idx = dict.index(word)
-                    if idx == dict.unk_index and word != dict.unk_word:
-                        nunk = nunk + 1
-                        if word in replaced:
-                            replaced[word] = replaced[word] + 1
-                        else:
-                            replaced[word] = 1
-                    ids[i] = idx
-                ids[nwords] = dict.eos_index
+        nseq, ntok = 0, 0
+        replaced = Counter()
+
+        def replaced_consumer(word, idx):
+            if idx == dict.unk_index and word != dict.unk_word:
+                replaced.update([word])
+
+        with open(filename, 'r') as f:
+            for line in f:
+                ids = Tokenizer.tokenize(line, dict, tokenize, add_if_not_exist=False, consumer=replaced_consumer)
+                nseq += 1
                 consumer(ids)
-                ntok = ntok + len(ids)
-        return {'nseq': nseq, 'nunk': nunk, 'ntok': ntok, 'replaced': len(replaced)}
+                ntok += len(ids)
+        return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': len(replaced)}
 
     @staticmethod
-    def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True):
+    def tokenize(line, dict, tokenize=tokenize_line, add_if_not_exist=True, consumer=None):
         words = tokenize(line)
         nwords = len(words)
         ids = torch.IntTensor(nwords + 1)
-        for i in range(0, len(words)):
+        for i, word in enumerate(words):
             if add_if_not_exist:
-                ids[i] = dict.add_symbol(words[i])
+                idx = dict.add_symbol(word)
             else:
-                ids[i] = dict.index(words[i])
+                idx = dict.index(word)
+            if consumer is not None:
+                consumer(word, idx)
+            ids[i] = idx
         ids[nwords] = dict.eos_index
         return ids
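
A usage sketch of the refactored tokenizer path: Tokenizer.tokenize() can now report each (word, index) pair to a consumer, which binarize() uses to count <unk> replacements. The toy dictionary and sentence below are hypothetical.

from collections import Counter

from fairseq import dictionary, tokenizer

d = dictionary.Dictionary()
for w in ['the', 'cat']:
    d.add_symbol(w)

replaced = Counter()

def replaced_consumer(word, idx):
    # same callback shape that binarize() builds internally
    if idx == d.unk_index and word != d.unk_word:
        replaced.update([word])

ids = tokenizer.Tokenizer.tokenize('the cat sat', d, add_if_not_exist=False,
                                   consumer=replaced_consumer)
print(ids.tolist())    # 'sat' maps to the unk index; last entry is eos
print(dict(replaced))  # {'sat': 1}
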
@@ -14,7 +14,7 @@ import traceback
 from torch.autograd import Variable
 from torch.serialization import default_restore_location
 
-from fairseq import criterions, data, models
+from fairseq import criterions, models
 
 
 def parse_args_and_arch(parser):
@@ -24,13 +24,13 @@ def parse_args_and_arch(parser):
     return args
 
-def build_model(args, dataset):
+def build_model(args, src_dict, dst_dict):
     assert hasattr(models, args.model), 'Missing model type'
-    return getattr(models, args.model).build_model(args, dataset)
+    return getattr(models, args.model).build_model(args, src_dict, dst_dict)
 
-def build_criterion(args, dataset):
-    padding_idx = dataset.dst_dict.pad()
+def build_criterion(args, src_dict, dst_dict):
+    padding_idx = dst_dict.pad()
     if args.label_smoothing > 0:
         return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
     else:
@@ -41,40 +41,34 @@ def torch_persistent_save(*args, **kwargs):
     for i in range(3):
         try:
             return torch.save(*args, **kwargs)
-        except:
+        except Exception:
            if i == 2:
                 logging.error(traceback.format_exc())
 
-def save_checkpoint(args, epoch, batch_offset, model, optimizer, lr_scheduler, val_loss=None):
+def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_history=None, extra_state=None):
+    if optim_history is None:
+        optim_history = []
+    if extra_state is None:
+        extra_state = {}
     state_dict = {
         'args': args,
-        'epoch': epoch,
-        'batch_offset': batch_offset,
         'model': model.state_dict(),
-        'optimizer': optimizer.state_dict(),
-        'best_loss': lr_scheduler.best,
-        'val_loss': val_loss,
+        'optimizer_history': optim_history + [
+            {
+                'criterion_name': criterion.__class__.__name__,
+                'optimizer': optimizer.state_dict(),
+                'best_loss': lr_scheduler.best,
+            }
+        ],
+        'extra_state': extra_state,
     }
-
-    if batch_offset == 0:
-        if not args.no_epoch_checkpoints:
-            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
-            torch_persistent_save(state_dict, epoch_filename)
-
-        assert val_loss is not None
-        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
-            save_checkpoint.best = val_loss
-            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
-            torch_persistent_save(state_dict, best_filename)
-
-    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
-    torch_persistent_save(state_dict, last_filename)
+    torch_persistent_save(state_dict, filename)
 
-def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
+def load_state(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
     if not os.path.exists(filename):
-        return 1, 0
+        return None, []
     if cuda_device is None:
         state = torch.load(filename)
     else:
@@ -82,19 +76,48 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
             filename,
             map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
         )
+    state = _upgrade_state_dict(state)
+
+    # load model parameters
     model.load_state_dict(state['model'])
-    optimizer.load_state_dict(state['optimizer'])
-    lr_scheduler.best = state['best_loss']
-    epoch = state['epoch'] + 1
-    batch_offset = state['batch_offset']
 
-    gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
-    print('| loaded checkpoint {} (epoch {}){}'.format(filename, epoch, gpu_str))
-    return epoch, batch_offset
+    # only load optimizer and lr_scheduler if they match with the checkpoint
+    optim_history = state['optimizer_history']
+    last_optim = optim_history[-1]
+    if last_optim['criterion_name'] == criterion.__class__.__name__:
+        optimizer.load_state_dict(last_optim['optimizer'])
+        lr_scheduler.best = last_optim['best_loss']
+
+    return state['extra_state'], optim_history
+
+def _upgrade_state_dict(state):
+    """Helper for upgrading old model checkpoints."""
+    # add optimizer_history
+    if 'optimizer_history' not in state:
+        state['optimizer_history'] = [
+            {
+                'criterion_name': criterions.CrossEntropyCriterion.__name__,
+                'optimizer': state['optimizer'],
+                'best_loss': state['best_loss'],
+            },
+        ]
+        del state['optimizer']
+        del state['best_loss']
+    # move extra_state into sub-dictionary
+    if 'epoch' in state and 'extra_state' not in state:
+        state['extra_state'] = {
+            'epoch': state['epoch'],
+            'batch_offset': state['batch_offset'],
+            'val_loss': state['val_loss'],
+        }
+        del state['epoch']
+        del state['batch_offset']
+        del state['val_loss']
+    return state
 
-def load_ensemble_for_inference(filenames, data_path, split):
+def load_ensemble_for_inference(filenames, src_dict, dst_dict):
     # load model architectures and weights
     states = []
     for filename in filenames:
@@ -103,19 +126,15 @@ def load_ensemble_for_inference(filenames, data_path, split):
         states.append(
             torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
         )
-
-    # load dataset
     args = states[0]['args']
-    dataset = data.load(data_path, [split], args.source_lang, args.target_lang)
 
-    # build models
+    # build ensemble
     ensemble = []
     for state in states:
-        model = build_model(args, dataset)
+        model = build_model(args, src_dict, dst_dict)
         model.load_state_dict(state['model'])
         ensemble.append(model)
+    return ensemble
 
-    return ensemble, dataset
 
 def prepare_sample(sample, volatile=False, cuda_device=None):
...
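
A sketch of the new inference loading flow (mirrored by generate.py below): the dataset and its dictionaries are loaded first, then passed to load_ensemble_for_inference(). The paths are hypothetical and assume binarized data and a trained checkpoint already exist.

from fairseq import data, utils

dataset = data.load_with_check('data-bin/example', ['test'], 'de', 'en')
models = utils.load_ensemble_for_inference(['checkpoints/checkpoint_best.pt'],
                                           dataset.src_dict, dataset.dst_dict)
print(len(models), 'model(s) loaded')
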
+#!/usr/bin/env python3
 # Copyright (c) 2017-present, Facebook, Inc.
 # All rights reserved.
 #
@@ -10,7 +11,7 @@ import sys
 import torch
 from torch.autograd import Variable
 
-from fairseq import bleu, options, utils, tokenizer
+from fairseq import bleu, data, options, tokenizer, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.progress_bar import progress_bar
 from fairseq.sequence_generator import SequenceGenerator
@@ -36,9 +37,15 @@ def main():
         progress_bar.enabled = False
     use_cuda = torch.cuda.is_available() and not args.cpu
 
-    # Load model and dataset
+    # Load dataset
+    dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
+    if args.source_lang is None or args.target_lang is None:
+        # record inferred languages in args
+        args.source_lang, args.target_lang = dataset.src, dataset.dst
+
+    # Load ensemble
     print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models, dataset = utils.load_ensemble_for_inference(args.path, args.data, args.gen_subset)
+    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
 
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
@@ -49,19 +56,23 @@
     # ignore too long sentences
     args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))
 
-    # Optimize model for generation
+    # Optimize ensemble for generation
     for model in models:
         model.make_generation_fast_(not args.no_beamable_mm)
 
     # Initialize generator
-    translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,
-                                   stop_early=(not args.no_early_stop),
-                                   normalize_scores=(not args.unnormalized),
-                                   len_penalty=args.lenpen)
+    translator = SequenceGenerator(
+        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
+        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
+    )
+    if use_cuda:
+        translator.cuda()
 
+    # Load alignment dictionary for unknown word replacement
     align_dict = {}
     if args.unk_replace_dict != '':
-        assert args.interactive, "Unkown words replacing requires access to original source and is only" \
-                                 "supported in interactive mode"
+        assert args.interactive, \
+            'Unknown word replacement requires access to original source and is only supported in interactive mode'
         with open(args.unk_replace_dict, 'r') as f:
             for line in f:
                 l = line.split()
@@ -80,27 +91,22 @@ def main():
                 hypo_tokens[i] = src_token
         return ' '.join(hypo_tokens)
 
-    if use_cuda:
-        translator.cuda()
-
-    bpe_symbol = '@@ ' if args.remove_bpe else None
     def display_hypotheses(id, src, orig, ref, hypos):
         if args.quiet:
             return
         id_str = '' if id is None else '-{}'.format(id)
-        src_str = to_sentence(dataset.src_dict, src, bpe_symbol)
+        src_str = dataset.src_dict.string(src, args.remove_bpe)
         print('S{}\t{}'.format(id_str, src_str))
         if orig is not None:
             print('O{}\t{}'.format(id_str, orig.strip()))
         if ref is not None:
-            print('T{}\t{}'.format(id_str, to_sentence(dataset.dst_dict, ref, bpe_symbol, ref_unk=True)))
+            print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, args.remove_bpe, escape_unk=True)))
         for hypo in hypos:
-            hypo_str = to_sentence(dataset.dst_dict, hypo['tokens'], bpe_symbol)
+            hypo_str = dataset.dst_dict.string(hypo['tokens'], args.remove_bpe)
             align_str = ' '.join(map(str, hypo['alignment']))
             if args.unk_replace_dict != '':
-                hypo_str = replace_unk(hypo_str, align_str, orig, unk_symbol(dataset.dst_dict))
-            print('H{}\t{}\t{}'.format(
-                id_str, hypo['score'], hypo_str))
+                hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
+            print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
             print('A{}\t{}'.format(id_str, align_str))
 
     if args.interactive:
@@ -116,12 +122,12 @@ def main():
             display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
 
     else:
-        def maybe_remove_bpe(tokens):
+        def maybe_remove_bpe(tokens, escape_unk=False):
             """Helper for removing BPE symbols from a hypothesis."""
-            if not args.remove_bpe:
+            if args.remove_bpe is None:
                 return tokens
             assert (tokens == dataset.dst_dict.pad()).sum() == 0
-            hypo_minus_bpe = to_sentence(dataset.dst_dict, tokens, bpe_symbol)
+            hypo_minus_bpe = dataset.dst_dict.string(tokens, args.remove_bpe, escape_unk)
            return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)
 
         # Generate and compute BLEU score
@@ -139,7 +145,7 @@ def main():
         for id, src, ref, hypos in translations:
             ref = ref.int().cpu()
             top_hypo = hypos[0]['tokens'].int().cpu()
-            scorer.add(maybe_remove_bpe(ref), maybe_remove_bpe(top_hypo))
+            scorer.add(maybe_remove_bpe(ref, escape_unk=True), maybe_remove_bpe(top_hypo))
             display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
 
             wps_meter.update(src.size(0))
@@ -151,25 +157,5 @@ def main():
     print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
 
 
-def to_token(dict, i, runk):
-    return runk if i == dict.unk() else dict[i]
-
-def unk_symbol(dict, ref_unk=False):
-    return '<{}>'.format(dict.unk_word) if ref_unk else dict.unk_word
-
-def to_sentence(dict, tokens, bpe_symbol=None, ref_unk=False):
-    if torch.is_tensor(tokens) and tokens.dim() == 2:
-        sentences = [to_sentence(dict, token) for token in tokens]
-        return '\n'.join(sentences)
-    eos = dict.eos()
-    runk = unk_symbol(dict, ref_unk=ref_unk)
-    sent = ' '.join([to_token(dict, i, runk) for i in tokens if i != eos])
-    if bpe_symbol is not None:
-        sent = sent.replace(bpe_symbol, '')
-    return sent
-
 if __name__ == '__main__':
     main()
+#!/usr/bin/env python3
 # Copyright (c) 2017-present, Facebook, Inc.
 # All rights reserved.
 #
...
+#!/usr/bin/env python3
 # Copyright (c) 2017-present, Facebook, Inc.
 # All rights reserved.
 #
...
+#!/usr/bin/env python3
 # Copyright (c) 2017-present, Facebook, Inc.
 # All rights reserved.
 #
...