Commit 7ee1d284 authored by Myle Ott

Add FP16 support

parent 73a87327
......@@ -12,7 +12,9 @@ import math
import numbers
import numpy as np
import os
import torch
from torch.autograd import Variable
import torch.utils.data
from fairseq.dictionary import Dictionary
......@@ -435,3 +437,21 @@ def numpy_seed(seed):
        yield
    finally:
        np.random.set_state(state)


def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
    bsz = int(ntokens / max(src_len, tgt_len))
    bsz = (bsz // 8) * 8
    assert src_dict.pad() == dst_dict.pad()
    pad_idx = src_dict.pad()
    src_vocab, dst_vocab = len(src_dict), len(dst_dict)
    dummy_batch = {}
    dummy_batch['id'] = Variable(torch.arange(bsz).long().cuda())
    dummy_batch['ntokens'] = tgt_len * bsz
    dummy_batch['target'] = Variable(torch.Tensor(bsz, tgt_len).uniform_(pad_idx + 1, dst_vocab - 1).long().cuda())
    input = {}
    input['prev_output_tokens'] = Variable(dummy_batch['target'].data.clone())
    input['src_lengths'] = Variable(torch.LongTensor(bsz).fill_(src_len).cuda())
    input['src_tokens'] = Variable(torch.Tensor(bsz, src_len).uniform_(pad_idx + 1, src_vocab - 1).long().cuda())
    dummy_batch['net_input'] = input
    return dummy_batch
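
get_dummy_batch builds a worst-case-sized batch so the trainer can run one throwaway forward/backward before real training, letting the CUDA caching allocator reserve its largest blocks up front instead of fragmenting memory mid-epoch. A minimal, fairseq-free sketch of the same warm-up idea (the function name, shapes, and the assumption of an available CUDA device are mine, not the commit's):

import torch

def warm_caching_allocator(model, criterion, max_tokens=4096, seq_len=128, vocab=1000):
    # build the largest batch we ever expect to see and run it once;
    # the targets are arbitrary, we only care about the memory footprint
    bsz = max(1, max_tokens // seq_len)
    tokens = torch.randint(1, vocab, (bsz, seq_len), device='cuda')
    logits = model(tokens)
    loss = criterion(logits.view(-1, logits.size(-1)), tokens.view(-1))
    loss.backward()
    # drop the gradients; the allocator keeps the freed blocks cached for reuse
    model.zero_grad()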
......@@ -53,58 +53,6 @@ def suppress_output():
__builtin__.print = print
def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    # buffer size is in bytes, determine equiv. # of elements based on data type
    buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def all_reduce_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel

        # all-reduce and rescale
        torch.distributed.all_reduce(buffer_t[:offset])
        buffer_t.div_(rescale_denom)

        # copy all-reduced buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, all-reduce and rescale directly
            torch.distributed.all_reduce(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # buffer is full, all-reduce and replace buffer with grad
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz
    if len(buffer) > 0:
        all_reduce_buffer()


def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
"""
import math
import torch
from fairseq import optim
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer
class DynamicLossScaler:

    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._iter = 0
        self._last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            self.loss_scale /= self.scale_factor
            self._last_overflow_iter = self._iter
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
        self._iter += 1

    @staticmethod
    def has_overflow(grad_norm):
        # detect inf and nan
        if grad_norm == float('inf') or grad_norm != grad_norm:
            return True
        return False
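
DynamicLossScaler implements the standard dynamic loss-scaling recipe: multiply the loss by loss_scale before backward so small FP16 gradients do not flush to zero, halve the scale whenever the gradient norm comes back inf/NaN, and double it again after scale_window consecutive clean steps. A rough sketch of how such a scaler is driven from an ordinary training loop (the step function and its arguments are illustrative, not part of this file):

scaler = DynamicLossScaler(init_scale=2.**7)

def scaled_train_step(model, optimizer, criterion, inputs, targets):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    (loss * scaler.loss_scale).backward()       # scale up before backward
    # unscale the gradients and measure their norm in FP32
    total_norm = 0.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.data.div_(scaler.loss_scale)
            total_norm += float(p.grad.data.float().norm()) ** 2
    total_norm = total_norm ** 0.5
    overflow = DynamicLossScaler.has_overflow(total_norm)
    scaler.update_scale(overflow)               # shrink or grow loss_scale
    if overflow:
        return None                             # skip this update entirely
    optimizer.step()
    return loss.item()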
class FP16Trainer(Trainer):
    """Modified trainer for FP16.

    We maintain two copies of the model's parameters, both in FP16 and FP32.
    We do forward/backward with FP16 and compute the loss + optimize with FP32.
    """

    def __init__(self, args, model, criterion):
        super().__init__(args, model, criterion)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()

    def _build_optimizer(self):
        # create FP32 copy of parameters and grads
        params = [p for p in self.model.parameters() if p.requires_grad]
        total_param_size = sum(p.data.numel() for p in params)
        self.fp32_params = params[0].new(0).float().new(total_param_size)
        offset = 0
        for p in params:
            numel = p.data.numel()
            self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
            offset += numel
        self.fp32_params = torch.nn.Parameter(self.fp32_params)
        self.fp32_params.grad = self.fp32_params.data.new(total_param_size)

        # create optimizer using the copied FP32 params
        self.optimizer = optim.build_optimizer(self.args, [self.fp32_params])
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state['loss_scale'] = self.scaler.loss_scale
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(self, filename):
        """Load all training state from a checkpoint file."""
        extra_state = super().load_checkpoint(filename)
        if extra_state is not None and 'loss_scale' in extra_state:
            self.scaler.loss_scale = extra_state['loss_scale']
        return extra_state

    def zero_grad(self):
        # zero both the FP16 and FP32 grads
        self.model.zero_grad()      # FP16
        self.optimizer.zero_grad()  # FP32

    def _backward(self, loss):
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.scaler.loss_scale)
        if loss is not None:
            # dynamically rescale loss to stay in FP16 range
            loss = loss * self.scaler.loss_scale
        return super()._backward(loss)

    def _all_reduce_and_rescale(self, grad_denom):
        # undo effect of dynamic loss scaling on gradients
        grad_denom *= self.scaler.loss_scale

        # all-reduce and rescale gradients
        grad_norm = super()._all_reduce_and_rescale(grad_denom)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

        return grad_norm

    def _get_flat_grads(self, out=None):
        if out is None:
            out = self.fp32_params.grad
        return super()._get_flat_grads(out)

    def _set_flat_grads(self, new_grads):
        # no-op
        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()

    def _opt(self):
        # take an optimization step using the FP32 params and grads
        super()._opt()

        # copy FP32 params back into FP16 model
        offset = 0
        for p in self.model.parameters():
            if not p.requires_grad:
                continue
            numel = p.data.numel()
            p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
            offset += numel
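
The core of FP16Trainer is the master-weights pattern: the model stays in FP16 for fast forward/backward, while the optimizer updates a single flat FP32 copy that is written back into the FP16 parameters after every step. Stripped of the Trainer plumbing, the pattern looks roughly like this (helper names are mine, not the commit's):

import torch

def build_fp32_master(params):
    # flatten all FP16 params into one FP32 buffer; the optimizer owns this copy
    total = sum(p.data.numel() for p in params)
    master = params[0].data.new(0).float().new(total)
    offset = 0
    for p in params:
        n = p.data.numel()
        master[offset:offset + n].copy_(p.data.view(-1))
        offset += n
    master = torch.nn.Parameter(master)
    master.grad = master.data.new(total).zero_()
    return master

def copy_master_to_model(master, params):
    # after optimizer.step() on the FP32 master, write the result back into FP16
    offset = 0
    for p in params:
        n = p.data.numel()
        p.data.copy_(master.data[offset:offset + n].view_as(p.data))
        offset += n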
......@@ -21,7 +21,7 @@ class FairseqDecoder(nn.Module):
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output[0]
logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
......
......@@ -155,6 +155,9 @@ def add_optimization_args(parser):
                            ' (default is to normalize by number of tokens)')
    group.add_argument('--update-freq', default='1', metavar='N',
                       help='update parameters every N_i batches, when in epoch i')
    group.add_argument('--fp16', action='store_true',
                       help='use FP16 during training')

    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
......
......@@ -6,7 +6,7 @@
# can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
Train a network across multiple GPUs.
"""
from collections import defaultdict, OrderedDict
......@@ -20,11 +20,11 @@ from fairseq.optim import lr_scheduler
class Trainer(object):
"""Main class for multi-GPU training.
"""Main class for data parallel training.
Each GPU has a full copy of the model and is assigned to its own Python
process. Gradients are accumulated with torch.distributed.all_reduce and all
model replicas are updated synchronously after each batch.
This class supports data parallel training, where multiple workers each
have a full model replica and gradients are accumulated synchronously via
torch.distributed.all_reduce.
"""
def __init__(self, args, model, criterion):
......@@ -39,8 +39,7 @@ class Trainer(object):
self.criterion = criterion.cuda()
# initialize optimizer and LR scheduler
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
self._build_optimizer()
# initialize meters
self.meters = OrderedDict()
......@@ -55,12 +54,17 @@ class Trainer(object):
self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory
self.meters['wall'] = TimeMeter() # wall time in seconds
self._buffered_stats = defaultdict(lambda: [])
self._max_bsz_seen = 0
self._flat_grads = None
self._num_updates = 0
self._optim_history = None
def _build_optimizer(self):
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if distributed_utils.is_master(self.args): # only save one checkpoint
......@@ -69,13 +73,12 @@ class Trainer(object):
def load_checkpoint(self, filename):
"""Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = utils.load_model_state(
filename, self.model, cuda_device=torch.cuda.current_device())
extra_state, self._optim_history, last_optim_state = \
utils.load_model_state(filename, self.model)
if last_optim_state is not None:
# rebuild optimizer after loading model, since params may have changed
self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
self._build_optimizer()
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
......@@ -105,7 +108,7 @@ class Trainer(object):
# update parameters
if update_params:
# gather logging outputs from all GPUs
# gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
......@@ -124,28 +127,34 @@ class Trainer(object):
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
# all-reduce gradients and take an optimization step
grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
grad_norm = self._opt(grad_denom)
# update meters
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['oom'].update(ooms_fwd + ooms_bwd)
# update loss meters for training
if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
self._buffered_stats.clear()
try:
# all-reduce and rescale gradients, then take an optimization step
grad_norm = self._all_reduce_and_rescale(grad_denom)
self._opt()
# update meters
self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences)
if grad_norm is not None:
self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
self.meters['oom'].update(ooms_fwd + ooms_bwd)
# update loss meters for training
if 'loss' in agg_logging_output:
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
except OverflowError as e:
self.zero_grad()
print('| WARNING: overflow detected, ' + str(e))
self.clear_buffered_stats()
return agg_logging_output
else:
......@@ -157,7 +166,6 @@ class Trainer(object):
self.model.eval()
else:
self.model.train()
loss = None
sample_size = 0
logging_output = {
......@@ -176,11 +184,8 @@ class Trainer(object):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
loss = None
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
else:
raise e
return loss, sample_size, logging_output, oom
def _backward(self, loss):
......@@ -193,39 +198,66 @@ class Trainer(object):
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
if hasattr(torch.cuda, 'empty_cache'):
torch.cuda.empty_cache()
self.optimizer.zero_grad()
self.zero_grad()
else:
raise e
return oom
def _opt(self, grad_denom):
# all-reduce grads and rescale by grad_denom
def _all_reduce_and_rescale(self, grad_denom):
# flatten grads into a single buffer and all-reduce
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
if self.args.distributed_world_size > 1:
grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
distributed_utils.all_reduce_and_rescale_tensors(grads, grad_denom)
else:
for p in self.model.parameters():
if p.requires_grad:
p.grad.data.div_(grad_denom)
torch.distributed.all_reduce(flat_grads)
# clip grads
if self.args.clip_norm > 0:
grad_norm = utils.item(torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm))
else:
grad_norm = math.sqrt(sum(p.grad.data.norm()**2 for p in self.model.parameters()))
# rescale and clip gradients
flat_grads.div_(grad_denom)
grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)
# copy grads back into model parameters
self._set_flat_grads(flat_grads)
return grad_norm
    def _get_grads(self):
        grads = []
        for name, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
                                   'Use the param in the forward pass or set requires_grad=False')
            grads.append(p.grad.data)
        return grads

    def _get_flat_grads(self, out=None):
        grads = self._get_grads()
        if out is None:
            grads_size = sum(g.numel() for g in grads)
            out = grads[0].new(grads_size).zero_()
        offset = 0
        for g in grads:
            numel = g.numel()
            out[offset:offset+numel].copy_(g.view(-1))
            offset += numel
        return out[:offset]

    def _set_flat_grads(self, new_grads):
        grads = self._get_grads()
        offset = 0
        for g in grads:
            numel = g.numel()
            g.copy_(new_grads[offset:offset+numel].view_as(g))
            offset += numel
def _opt(self):
# take an optimization step
self.optimizer.step()
self.optimizer.zero_grad()
self.zero_grad()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
return grad_norm
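
The gradient path above replaces the old per-parameter chunked all-reduce with a single flat buffer: all gradients are copied into one contiguous tensor, all-reduced in one collective call, rescaled and clipped, then scattered back into the parameters. A standalone sketch of that idea against plain torch.distributed (the function name and arguments are mine):

import torch
import torch.distributed as dist

def all_reduce_flat_grads(params, grad_denom, max_norm=0.0):
    grads = [p.grad.data for p in params if p.requires_grad and p.grad is not None]
    flat = torch.cat([g.view(-1) for g in grads])       # one contiguous buffer
    dist.all_reduce(flat)                                # single collective call
    flat.div_(grad_denom)                                # e.g. total sample size
    norm = flat.norm().item()
    if max_norm > 0 and norm > max_norm:
        flat.mul_(max_norm / (norm + 1e-6))              # clip in place
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset:offset + n].view_as(g))      # scatter back
        offset += n
    return norm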
def valid_step(self, sample):
"""Do forward pass in evaluation mode."""
......@@ -258,6 +290,18 @@ class Trainer(object):
return agg_logging_output
    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)
......@@ -283,9 +327,4 @@ class Trainer(object):
def _prepare_sample(self, sample, volatile):
if sample is None or len(sample) == 0:
return None
if hasattr(torch.cuda, 'empty_cache'):
# clear the caching allocator if this is the largest sample we've seen
if sample['target'].size(0) > self._max_bsz_seen:
self._max_bsz_seen = sample['target'].size(0)
torch.cuda.empty_cache()
return utils.make_variable(sample, volatile=volatile, cuda=True)
......@@ -5,7 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import defaultdict
from collections import defaultdict, OrderedDict
import contextlib
import logging
import os
......@@ -25,6 +25,20 @@ def torch_persistent_save(*args, **kwargs):
logging.error(traceback.format_exc())
def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v, ttype)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v, ttype) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict
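
convert_state_dict_type walks nested dicts and lists and casts every tensor it finds, which is how FP16 training can still emit FP32 checkpoints that load cleanly into FP32 models later. A hypothetical call site outside save_state might look like:

import torch

def save_fp32_checkpoint(model, path):
    # cast an FP16 model's weights to FP32 before writing them to disk
    fp32_state = convert_state_dict_type(model.state_dict(), ttype=torch.FloatTensor)
    torch.save({'model': fp32_state}, path)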
def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
num_updates, optim_history=None, extra_state=None):
if optim_history is None:
......@@ -33,7 +47,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
extra_state = {}
state_dict = {
'args': args,
'model': model.state_dict(),
'model': convert_state_dict_type(model.state_dict()),
'optimizer_history': optim_history + [
{
'criterion_name': criterion.__class__.__name__,
......@@ -42,22 +56,16 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
'num_updates': num_updates,
}
],
'last_optimizer_state': optimizer.state_dict(),
'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
'extra_state': extra_state,
}
torch_persistent_save(state_dict, filename)
def load_model_state(filename, model, cuda_device=None):
def load_model_state(filename, model):
if not os.path.exists(filename):
return None, [], None
if cuda_device is None:
state = torch.load(filename)
else:
state = torch.load(
filename,
map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
)
state = torch.load(filename)
state = _upgrade_state_dict(state)
state['model'] = model.upgrade_state_dict(state['model'])
......@@ -377,6 +385,14 @@ def item(tensor):
return tensor
def clip_grad_norm_(tensor, max_norm):
    grad_norm = item(torch.norm(tensor))
    if grad_norm > max_norm > 0:
        clip_coef = max_norm / (grad_norm + 1e-6)
        tensor.mul_(clip_coef)
    return grad_norm
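
Unlike torch.nn.utils.clip_grad_norm, which walks a list of parameters, this helper clips one flat tensor in place and returns the pre-clipping norm; passing max_norm <= 0 leaves the tensor untouched but still reports the norm. A quick illustration with arbitrary values:

import torch

flat_grads = torch.randn(1024) * 10
gnorm = clip_grad_norm_(flat_grads, max_norm=0.1)   # clipped in place
print(gnorm, flat_grads.norm())                     # pre-clip norm vs. ~0.1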
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)
......@@ -44,7 +44,7 @@ def average_checkpoints(inputs):
for k in params_keys:
if k not in params_dict:
params_dict[k] = []
params_dict[k].append(model_params[k])
params_dict[k].append(model_params[k].float())
averaged_params = collections.OrderedDict()
# v should be a list of torch Tensor.
......
......@@ -13,8 +13,9 @@ import math
import torch
from fairseq import criterions, data, models, options, progress_bar
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
def main(args):
......@@ -48,7 +49,10 @@ def main(args):
print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))
# Build trainer
trainer = Trainer(args, model, criterion)
if args.fp16:
trainer = FP16Trainer(args, model, criterion)
else:
trainer = Trainer(args, model, criterion)
print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens,
......@@ -84,6 +88,10 @@ def main(args):
_ = next(train_dataloader)
epoch += 1
# Send a dummy batch to warm the caching allocator
dummy_batch = data.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
trainer.dummy_train_step(dummy_batch)
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
......@@ -153,7 +161,7 @@ def train(args, trainer, itr, epoch):
# log mid-epoch stats
stats = get_training_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss']:
if k in ['loss', 'nll_loss', 'sample_size']:
continue # these are already logged above
if 'loss' in k:
extra_meters[k].update(v, log_output['sample_size'])
......@@ -194,6 +202,9 @@ def get_training_stats(trainer):
stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
stats['oom'] = trainer.get_meter('oom').avg
if trainer.get_meter('loss_scale') is not None:
stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
return stats
......@@ -234,7 +245,7 @@ def validate(args, trainer, dataset, subset, epoch):
# log mid-validation stats
stats = get_valid_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss']:
if k in ['loss', 'nll_loss', 'sample_size']:
continue
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
......