Unverified Commit 94dae690 authored by Myle Ott's avatar Myle Ott Committed by GitHub
Browse files

Merge pull request #77 from facebookresearch/oss-merge-internal

parents d74f200a 0a836276
......@@ -18,3 +18,6 @@ class FairseqDecoder(nn.Module):
def max_positions(self):
    """Maximum input length supported by the decoder.

    Abstract in the base class: concrete decoders must override this.
    """
    raise NotImplementedError
def upgrade_state_dict(self, state_dict):
    """Hook for migrating old checkpoint state dicts to current code.

    The base implementation performs no migration and returns
    *state_dict* unchanged; subclasses may override to rewrite legacy
    entries before the dict is passed to ``load_state_dict``.
    """
    return state_dict
......@@ -18,3 +18,6 @@ class FairseqEncoder(nn.Module):
def max_positions(self):
    """Maximum input length supported by the encoder.

    Abstract in the base class: concrete encoders must override this.
    """
    raise NotImplementedError
def upgrade_state_dict(self, state_dict):
    """Hook for migrating old checkpoint state dicts to current code.

    The base implementation performs no migration and returns
    *state_dict* unchanged; subclasses may override to rewrite legacy
    entries before the dict is passed to ``load_state_dict``.
    """
    return state_dict
......@@ -43,6 +43,11 @@ class FairseqModel(nn.Module):
"""Maximum output length supported by the decoder."""
return self.decoder.max_positions()
def upgrade_state_dict(self, state_dict):
    """Upgrade an old checkpoint state dict for current code.

    Gives the encoder and then the decoder a chance to rewrite their
    own entries, and returns the (possibly modified) dict.
    """
    return self.decoder.upgrade_state_dict(
        self.encoder.upgrade_state_dict(state_dict))
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
if self._is_generation_fast:
......
......@@ -58,7 +58,7 @@ class FConvEncoder(FairseqEncoder):
self.projections = nn.ModuleList()
self.convolutions = nn.ModuleList()
for (out_channels, kernel_size) in convolutions:
pad = (kernel_size - 1) // 2
pad = (kernel_size - 1) / 2
self.projections.append(Linear(in_channels, out_channels)
if in_channels != out_channels else None)
self.convolutions.append(
......@@ -154,6 +154,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
max_positions=1024, convolutions=((512, 3),) * 20,
attention=True, dropout=0.1):
super().__init__()
self.register_buffer('version', torch.Tensor([2]))
self.dictionary = dictionary
self.dropout = dropout
......@@ -265,6 +266,16 @@ class FConvDecoder(FairseqIncrementalDecoder):
"""Maximum output length supported by the decoder."""
return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
def upgrade_state_dict(self, state_dict):
    """Upgrade checkpoints produced before the weight-norm dim fix.

    Pre-version-2 decoders applied weight norm over a different
    dimension (dim=0) than current code, so their checkpoints store
    weight-norm parameters with the old shape.  When loading such a
    checkpoint, reconfigure this module's convolutions back to dim=0
    so the old parameters load cleanly.
    """
    # checkpoints that predate the 'decoder.version' buffer default to 1
    if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
        # old models use incorrect weight norm dimension
        for i, conv in enumerate(self.convolutions):
            # reconfigure weight norm
            nn.utils.remove_weight_norm(conv)
            self.convolutions[i] = nn.utils.weight_norm(conv, dim=0)
        # NOTE(review): the module now matches the old (dim=0) layout, so
        # the checkpoint stays at version 1 rather than being bumped to 2;
        # this looks intentional, but confirm it is not meant to be [2].
        state_dict['decoder.version'] = torch.Tensor([1])
    return state_dict
def _split_encoder_out(self, encoder_out):
"""Split and transpose encoder outputs.
......@@ -307,7 +318,7 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
m.weight.data.normal_(mean=0, std=std)
m.bias.data.zero_()
return nn.utils.weight_norm(m)
return nn.utils.weight_norm(m, dim=2)
def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
......
......@@ -59,7 +59,7 @@ class ConvTBCFunction(Function):
kernel_size = weight_size[0]
output = input.new(
input_size[0] - kernel_size + 1 + pad * 2,
input_size[0] - kernel_size + 1 + int(pad * 2),
input_size[1],
weight_size[2])
......
......@@ -11,10 +11,11 @@ Train a network on multiple GPUs using multiprocessing.
"""
from itertools import cycle, islice
import math
import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from fairseq import nccl, utils
from fairseq import meters, nccl, utils
from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
from fairseq.nag import NAG
......@@ -67,39 +68,61 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.model = model.cuda()
self.criterion = criterion.cuda()
# initialize optimizer
# initialize optimizer and LR scheduler
self.args.lr = list(map(float, self.args.lr.split(',')))
self.optimizer = self._build_optimizer()
self.loss = None
# initialize LR scheduler
self.lr_scheduler = self._build_lr_scheduler()
self.loss = None
self._max_bsz_seen = 0
def _build_optimizer(self):
    """Instantiate the optimizer selected by ``--optimizer``.

    Also populates ``self._override_optim_state``: when resuming
    training from a checkpoint, we load the old optimizer state that
    includes things like learning rate, momentum factor, etc., and this
    dictionary is used to override values stored in the checkpoint with
    the values specified on the command line.

    Returns:
        an optimizer over ``self.model.parameters()``.

    Raises:
        ValueError: if ``self.args.optimizer`` is not recognized.

    Note: the scraped diff left the pre-merge ``return torch.optim.X(
    ..., lr=self.args.lr, ...)`` lines interleaved with the post-merge
    code as unreachable duplicates; this is the merged version only.
    """
    self._override_optim_state = {}
    if self.args.optimizer == 'adagrad':
        self._override_optim_state = {
            'lr': self.args.lr[0],
            'weight_decay': self.args.weight_decay,
        }
        return torch.optim.Adagrad(self.model.parameters(), **self._override_optim_state)
    elif self.args.optimizer == 'adam':
        self._override_optim_state = {
            'lr': self.args.lr[0],
            # eval() of a command-line string such as '(0.9, 0.999)';
            # input comes from argparse (trusted operator), not the network
            'betas': eval(self.args.adam_betas),
            'weight_decay': self.args.weight_decay,
        }
        return torch.optim.Adam(self.model.parameters(), **self._override_optim_state)
    elif self.args.optimizer == 'nag':
        self._override_optim_state = {
            'lr': self.args.lr[0],
            'momentum': self.args.momentum,
            'weight_decay': self.args.weight_decay,
        }
        return NAG(self.model.parameters(), **self._override_optim_state)
    elif self.args.optimizer == 'sgd':
        self._override_optim_state = {
            'lr': self.args.lr[0],
            'momentum': self.args.momentum,
            'weight_decay': self.args.weight_decay,
        }
        return torch.optim.SGD(self.model.parameters(), **self._override_optim_state)
    else:
        raise ValueError('Unknown optimizer: {}'.format(self.args.optimizer))
def _build_lr_scheduler(self):
if self.args.force_anneal > 0:
if len(self.args.lr) > 1 or self.args.force_anneal > 0:
lrs = self.args.lr
def anneal(e):
if e < self.args.force_anneal:
return 1
# use fixed LR schedule
next_lr = lrs[min(e, len(lrs) - 1)]
else:
return self.args.lrshrink ** (e + 1 - self.args.force_anneal)
next_lr = lrs[-1] * self.args.lrshrink ** (e + 1 - self.args.force_anneal)
return next_lr / lrs[0] # correct for scaling from LambdaLR
lr_scheduler = LambdaLR(self.optimizer, anneal)
lr_scheduler.best = None
else:
......@@ -134,9 +157,24 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
return extra_state
def _async_load_checkpoint(self, rank, device_id, filename):
    """Load a checkpoint on one worker (async event-loop callback).

    Restores model parameters first, then rebuilds the optimizer and LR
    scheduler (the parameter set may have changed after loading).  The
    checkpointed optimizer state is restored only when the checkpoint's
    last optimizer-history entry matches the current criterion; in all
    cases, command-line values in ``self._override_optim_state`` take
    precedence over checkpointed hyperparameters.

    Returns the checkpoint's ``extra_state`` (or None if no checkpoint
    existed).

    Note: the scraped diff left the pre-merge ``utils.load_state(...)``
    call interleaved with the post-merge code; this is the merged
    version only.
    """
    extra_state, self._optim_history, last_optim_state = utils.load_model_state(
        filename, self.model, cuda_device=device_id)

    if last_optim_state is not None:
        # rebuild optimizer after loading model, since params may have changed
        self.optimizer = self._build_optimizer()
        self.lr_scheduler = self._build_lr_scheduler()

        # only load optimizer and lr_scheduler if they match the checkpoint
        last_optim = self._optim_history[-1]
        if last_optim['criterion_name'] == self.criterion.__class__.__name__:
            self.optimizer.load_state_dict(last_optim_state)
            self.lr_scheduler.best = last_optim['best_loss']

        # override learning rate, momentum, etc. with latest values
        for group in self.optimizer.param_groups:
            group.update(self._override_optim_state)

    return extra_state
def set_seed(self, seed):
......@@ -161,14 +199,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
# forward pass
sample_sizes, logging_outputs = Future.gen_tuple_list([
sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
self.call_async(rank, '_async_forward')
for rank in range(self.num_replicas)
])
# backward pass, all-reduce gradients and take an optimization step
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
grad_norms = Future.gen_list([
grad_norms, ooms_bwd = Future.gen_tuple_list([
self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
for rank in range(self.num_replicas)
])
......@@ -176,6 +214,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
logging_output['gnorm'] = grad_norms[0] # log the gradient norm
logging_output['oom'] = sum(ooms_fwd) + sum(ooms_bwd)
return logging_output
......@@ -186,34 +225,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.model.train()
self.optimizer.zero_grad()
if self._sample is None:
return 0, {}
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
sample_size, logging_output, oom = 0, {}, False
if self._sample is not None:
try:
# calculate loss and sample size
self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
oom = True
self.loss = None
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
else:
raise e
return sample_size, logging_output
return sample_size, logging_output, oom
def _async_backward_and_opt(self, rank, device_id, grad_denom):
oom = False
if self.loss is not None:
# backward pass
self.loss.backward()
# get model parameters as a flattened (contiguous) tensor
flat_grads = self._flat_model_grads()
# all-reduce grads
nccl.all_reduce(flat_grads)
try:
# backward pass
self.loss.backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
oom = True
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
self.optimizer.zero_grad()
else:
raise e
# normalize grads
if grad_denom != 0:
flat_grads.div_(grad_denom)
# all-reduce grads and rescale by grad_denom
self._all_reduce_and_rescale_grads(grad_denom)
# clip grads
grad_norm = self._clip_grads_(flat_grads, self.args.clip_norm)
# copy reduced grads back
self._set_model_grads_(flat_grads)
grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
# take an optimization step
self.optimizer.step()
......@@ -221,41 +270,49 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
# reset loss
self.loss = None
return grad_norm
def _model_grads(self):
return [p.grad for p in self.model.parameters() if p.requires_grad]
def _flat_model_grads(self):
grads = self._model_grads()
if not hasattr(self, '_flat_grads'):
num_params = sum(g.data.numel() for g in grads)
self._flat_grads = grads[0].data.new(num_params)
offset = 0
for grad in grads:
grad = grad.data.view(-1)
numel = grad.numel()
self._flat_grads[offset:offset+numel].copy_(grad)
offset += numel
return self._flat_grads
def _set_model_grads_(self, flat_grads):
grads = self._model_grads()
offset = 0
for grad in grads:
grad = grad.data.view(-1)
numel = grad.numel()
grad.copy_(flat_grads[offset:offset+numel])
offset += numel
assert offset == flat_grads.numel()
def _clip_grads_(self, flat_grads, clipv):
"""nn.utils.clip_grad_norm for flattened (contiguous) tensors."""
norm = flat_grads.norm()
if clipv > 0 and norm > clipv:
coef = max(norm, 1e-6) / clipv
flat_grads.div_(coef)
return norm
return grad_norm, oom
def _all_reduce_and_rescale_grads(self, grad_denom, buffer_size=10485760):
    """All-reduce gradients across workers and divide by *grad_denom*.

    Gradients are packed into a flat staging tensor of at most
    *buffer_size* bytes per NCCL call to amortize communication cost;
    any single gradient larger than the buffer is reduced directly in
    place.
    """
    grad_tensors = [p.grad.data for p in self.model.parameters() if p.requires_grad]
    # flat staging area, sized in elements of the first grad's dtype
    staging = grad_tensors[0].new(
        math.ceil(buffer_size / grad_tensors[0].element_size())).zero_()
    pending = []

    def flush_pending():
        # pack the pending grads into the staging tensor
        end = 0
        for t in pending:
            n = t.numel()
            staging[end:end + n].copy_(t.view(-1))
            end += n
        # reduce across workers and rescale once for the whole chunk
        nccl.all_reduce(staging[:end])
        staging.div_(grad_denom)
        # unpack the reduced values back into the original grads
        end = 0
        for t in pending:
            n = t.numel()
            t.view(-1).copy_(staging[end:end + n])
            end += n

    used_bytes = 0
    for t in grad_tensors:
        nbytes = t.numel() * t.element_size()
        if nbytes > buffer_size:
            # grad is bigger than the buffer: all-reduce and rescale directly
            nccl.all_reduce(t)
            t.div_(grad_denom)
        elif used_bytes + nbytes > buffer_size:
            # buffer is full: flush it, then start over with this grad
            flush_pending()
            pending = [t]
            used_bytes = nbytes
        else:
            # add grad to the pending chunk
            pending.append(t)
            used_bytes += nbytes
    if len(pending) > 0:
        flush_pending()
def valid_step(self, samples):
"""Do forward pass in parallel."""
......@@ -263,10 +320,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self._scatter_samples(samples, volatile=True)
# forward pass
_sample_sizes, logging_outputs = Future.gen_tuple_list([
_sample_sizes, logging_outputs, ooms_fwd = Future.gen_tuple_list([
self.call_async(rank, '_async_forward', eval=True)
for rank in range(self.num_replicas)
])
assert sum(ooms_fwd) == 0
# aggregate logging output
logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
......@@ -314,4 +372,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
if sample is None:
self._sample = None
else:
if hasattr(torch.cuda, 'empty_cache'):
# clear the caching allocator if this is the largest sample we've seen
if sample['target'].size(0) > self._max_bsz_seen:
self._max_bsz_seen = sample['target'].size(0)
torch.cuda.empty_cache()
self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
......@@ -49,8 +49,8 @@ def add_optimization_args(parser):
group.add_argument('--optimizer', default='nag', metavar='OPT',
choices=MultiprocessingTrainer.OPTIMIZERS,
help='optimizer ({})'.format(', '.join(MultiprocessingTrainer.OPTIMIZERS)))
group.add_argument('--lr', '--learning-rate', default=0.25, type=float, metavar='LR',
help='initial learning rate')
group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR1,LR2,...,LRn',
help='learning rate for the first n epochs with all epochs >n using LRn')
group.add_argument('--min-lr', metavar='LR', default=1e-5, type=float,
help='minimum learning rate')
group.add_argument('--force-anneal', '--fa', default=0, type=int, metavar='N',
......
......@@ -83,9 +83,9 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler, optim_
torch_persistent_save(state_dict, filename)
def load_model_state(filename, model, cuda_device=None):
    """Load model parameters (only) from a checkpoint file.

    Args:
        filename: path to the checkpoint; if it does not exist,
            ``(None, [], None)`` is returned instead of raising.
        model: module whose parameters are restored; its
            ``upgrade_state_dict`` hook is applied to the checkpointed
            weights first so old checkpoints keep loading.
        cuda_device: when given, tensors are mapped onto that GPU.

    Returns:
        ``(extra_state, optimizer_history, last_optimizer_state)`` from
        the checkpoint; the caller decides whether to restore the
        optimizer state.

    Note: the scraped diff left the pre-merge ``load_state`` signature
    and its optimizer-restoring body interleaved here; this is the
    merged version only.  Also narrows the original bare ``except:`` so
    KeyboardInterrupt/SystemExit still propagate.
    """
    if not os.path.exists(filename):
        return None, [], None
    if cuda_device is None:
        state = torch.load(filename)
    else:
        state = torch.load(
            filename,
            map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
        )
    state = _upgrade_state_dict(state)
    state['model'] = model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'])
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']
def _upgrade_state_dict(state):
......@@ -164,6 +162,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
ensemble = []
for state in states:
model = build_model(args, src_dict, dst_dict)
state['model'] = model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'])
ensemble.append(model)
return ensemble, args
......
......@@ -53,18 +53,18 @@ def main():
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
args.num_gpus = torch.cuda.device_count()
print(args)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in splits:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
num_gpus = torch.cuda.device_count()
print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
num_gpus, args.max_tokens, args.max_sentences))
args.num_gpus, args.max_tokens, args.max_sentences))
# Build model and criterion
model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
......@@ -102,11 +102,11 @@ def main():
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch:
# train for one epoch
train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)
train(args, epoch, batch_offset, trainer, dataset, max_positions_train)
# evaluate on validate set
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset)
if k == 0:
if not args.no_save:
# save checkpoint
......@@ -130,7 +130,7 @@ def get_perplexity(loss):
return float('inf')
def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
def train(args, epoch, batch_offset, trainer, dataset, max_positions):
"""Train the model for one epoch."""
seed = args.seed + epoch
......@@ -152,7 +152,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
lr = trainer.get_lr()
with utils.build_progress_bar(args, itr, epoch) as t:
for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
for i, sample in data.skip_group_enumerator(t, args.num_gpus, batch_offset):
loss_dict = trainer.train_step(sample)
loss = loss_dict['loss']
del loss_dict['loss'] # don't include in extra_meters or extra_postfix
......@@ -222,7 +222,7 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
trainer.save_checkpoint(last_filename, extra_state)
def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
def validate(args, epoch, trainer, dataset, max_positions, subset):
"""Evaluate the model on the validation set and return the average loss."""
itr = dataset.eval_dataloader(
......@@ -236,7 +236,7 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
prefix = 'valid on \'{}\' subset'.format(subset)
with utils.build_progress_bar(args, itr, epoch, prefix) as t:
for _, sample in data.skip_group_enumerator(t, ngpus):
for _, sample in data.skip_group_enumerator(t, args.num_gpus):
loss_dict = trainer.valid_step(sample)
loss = loss_dict['loss']
del loss_dict['loss'] # don't include in extra_meters or extra_postfix
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment