Commit 7ee1d284 authored by Myle Ott

Add FP16 support

parent 73a87327
@@ -12,7 +12,9 @@ import math
import numbers
import numpy as np
import os
import torch
from torch.autograd import Variable
import torch.utils.data
from fairseq.dictionary import Dictionary

@@ -435,3 +437,21 @@ def numpy_seed(seed):
        yield
    finally:
        np.random.set_state(state)

def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
    """Create a dummy batch of random tokens (used e.g. to warm the CUDA caching allocator)."""
    bsz = int(ntokens / max(src_len, tgt_len))
    bsz = (bsz // 8) * 8  # round the batch size down to a multiple of 8 (tensor-core friendly)
    assert src_dict.pad() == dst_dict.pad()
    pad_idx = src_dict.pad()
    src_vocab, dst_vocab = len(src_dict), len(dst_dict)
    dummy_batch = {}
    dummy_batch['id'] = Variable(torch.arange(bsz).long().cuda())
    dummy_batch['ntokens'] = tgt_len * bsz
    dummy_batch['target'] = Variable(torch.Tensor(bsz, tgt_len).uniform_(pad_idx + 1, dst_vocab - 1).long().cuda())
    input = {}
    input['prev_output_tokens'] = Variable(dummy_batch['target'].data.clone())
    input['src_lengths'] = Variable(torch.LongTensor(bsz).fill_(src_len).cuda())
    input['src_tokens'] = Variable(torch.Tensor(bsz, src_len).uniform_(pad_idx + 1, src_vocab - 1).long().cuda())
    dummy_batch['net_input'] = input
    return dummy_batch
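
A rough standalone sketch of how this helper behaves (illustrative only; it assumes a CUDA device and two already-built fairseq Dictionary objects bound to src_dict and dst_dict):

# Illustrative usage: build a dummy batch sized for roughly 4000 tokens.
dummy = get_dummy_batch(4000, src_dict, dst_dict, src_len=128, tgt_len=128)
print(dummy['net_input']['src_tokens'].size())   # (24, 128): 4000 // 128 = 31, rounded down to a multiple of 8
print(dummy['ntokens'])                          # 128 * 24 = 3072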
@@ -53,58 +53,6 @@ def suppress_output():
    __builtin__.print = print


def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    # buffer size is in bytes, determine equiv. # of elements based on data type
    buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def all_reduce_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel

        # all-reduce and rescale
        torch.distributed.all_reduce(buffer_t[:offset])
        buffer_t.div_(rescale_denom)

        # copy all-reduced buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, all-reduce and rescale directly
            torch.distributed.all_reduce(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # buffer is full, all-reduce and replace buffer with grad
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz
    if len(buffer) > 0:
        all_reduce_buffer()
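
For context, a hedged sketch of how this chunked helper is typically invoked (assuming torch.distributed has already been initialized and `model` is a CUDA nn.Module; the denominator here simply averages gradients over workers):

# Illustrative only: average gradients across workers in ~10 MB chunks.
grads = [p.grad.data for p in model.parameters() if p.requires_grad and p.grad is not None]
all_reduce_and_rescale_tensors(grads, rescale_denom=torch.distributed.get_world_size())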

def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
"""
import math
import torch
from fairseq import optim
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer
class DynamicLossScaler:

    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._iter = 0
        self._last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            self.loss_scale /= self.scale_factor
            self._last_overflow_iter = self._iter
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
        self._iter += 1

    @staticmethod
    def has_overflow(grad_norm):
        # detect inf and nan
        if grad_norm == float('inf') or grad_norm != grad_norm:
            return True
        return False
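
A small self-contained sketch of the intended dynamics (the numbers are made up): the scale is halved whenever an overflow is detected and doubled again after scale_window consecutive overflow-free steps.

# Illustrative only: simulate the loss-scale schedule without any real training.
scaler = DynamicLossScaler(init_scale=2.**7, scale_window=4)
for step, grad_norm in enumerate([1.0, float('inf'), 1.0, 1.0, 1.0, 1.0]):
    scaler.update_scale(DynamicLossScaler.has_overflow(grad_norm))
    print(step, scaler.loss_scale)   # drops from 128.0 to 64.0 after the inf, back to 128.0 at step 5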

class FP16Trainer(Trainer):
    """Modified trainer for FP16.

    We maintain two copies of the model's parameters, both in FP16 and FP32.
    We do forward/backward with FP16 and compute the loss + optimize with FP32.
    """

    def __init__(self, args, model, criterion):
        super().__init__(args, model, criterion)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()

    def _build_optimizer(self):
        # create FP32 copy of parameters and grads
        params = [p for p in self.model.parameters() if p.requires_grad]
        total_param_size = sum(p.data.numel() for p in params)
        self.fp32_params = params[0].new(0).float().new(total_param_size)
        offset = 0
        for p in params:
            numel = p.data.numel()
            self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
            offset += numel
        self.fp32_params = torch.nn.Parameter(self.fp32_params)
        self.fp32_params.grad = self.fp32_params.data.new(total_param_size)

        # create optimizer using the copied FP32 params
        self.optimizer = optim.build_optimizer(self.args, [self.fp32_params])
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state['loss_scale'] = self.scaler.loss_scale
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(self, filename):
        """Load all training state from a checkpoint file."""
        extra_state = super().load_checkpoint(filename)
        if extra_state is not None and 'loss_scale' in extra_state:
            self.scaler.loss_scale = extra_state['loss_scale']
        return extra_state

    def zero_grad(self):
        # zero both the FP16 and FP32 grads
        self.model.zero_grad()      # FP16
        self.optimizer.zero_grad()  # FP32

    def _backward(self, loss):
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.scaler.loss_scale)
        if loss is not None:
            # dynamically rescale loss to stay in FP16 range
            loss = loss * self.scaler.loss_scale
        return super()._backward(loss)

    def _all_reduce_and_rescale(self, grad_denom):
        # undo effect of dynamic loss scaling on gradients
        grad_denom *= self.scaler.loss_scale

        # all-reduce and rescale gradients
        grad_norm = super()._all_reduce_and_rescale(grad_denom)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

        return grad_norm

    def _get_flat_grads(self, out=None):
        if out is None:
            out = self.fp32_params.grad
        return super()._get_flat_grads(out)

    def _set_flat_grads(self, new_grads):
        # no-op
        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()

    def _opt(self):
        # take an optimization step using the FP32 params and grads
        super()._opt()

        # copy FP32 params back into FP16 model
        offset = 0
        for p in self.model.parameters():
            if not p.requires_grad:
                continue
            numel = p.data.numel()
            p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
            offset += numel
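
Outside of the Trainer machinery, the same recipe (FP16 forward/backward, loss scaling, FP32 master weights) can be sketched in plain PyTorch. This is a simplified illustration rather than the fairseq code path; the Linear model, SGD optimizer, and loss-scale value are placeholders, and a CUDA device is assumed.

import torch

model = torch.nn.Linear(16, 16).cuda().half()                            # FP16 model
fp32_params = [p.detach().clone().float() for p in model.parameters()]   # FP32 master copy
for p in fp32_params:
    p.requires_grad = True
optimizer = torch.optim.SGD(fp32_params, lr=0.1)                         # optimizer only sees FP32 params
loss_scale = 2.**7

x = torch.randn(8, 16).cuda().half()
loss = model(x).sum() * loss_scale                                       # scaled loss, FP16 backward
loss.backward()
for p32, p16 in zip(fp32_params, model.parameters()):
    p32.grad = p16.grad.detach().float() / loss_scale                    # unscale into FP32 grads
optimizer.step()                                                         # FP32 update
for p32, p16 in zip(fp32_params, model.parameters()):
    p16.data.copy_(p32.data)                                             # copy FP32 params back into the FP16 model
model.zero_grad()
optimizer.zero_grad()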
@@ -21,7 +21,7 @@ class FairseqDecoder(nn.Module):
    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
...
@@ -155,6 +155,9 @@ def add_optimization_args(parser):
                            ' (default is to normalize by number of tokens)')
    group.add_argument('--update-freq', default='1', metavar='N',
                       help='update parameters every N_i batches, when in epoch i')
    has_tensor_cores = torch.cuda.device_count() > 0 and torch.cuda.get_device_capability(0)[0] >= 7
    group.add_argument('--fp16', action='store_true', default=has_tensor_cores,
                       help='use FP16 during training')

    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
...
@@ -6,7 +6,7 @@
# can be found in the PATENTS file in the same directory.

"""
Train a network across multiple GPUs.
"""

from collections import defaultdict, OrderedDict
@@ -20,11 +20,11 @@ from fairseq.optim import lr_scheduler

class Trainer(object):
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.
    """

    def __init__(self, args, model, criterion):
@@ -39,8 +39,7 @@ class Trainer(object):
        self.criterion = criterion.cuda()

        # initialize optimizer and LR scheduler
        self._build_optimizer()

        # initialize meters
        self.meters = OrderedDict()
@@ -55,12 +54,17 @@ class Trainer(object):
        self.meters['gnorm'] = AverageMeter()   # gradient norm
        self.meters['clip'] = AverageMeter()    # % of updates clipped
        self.meters['oom'] = AverageMeter()     # out of memory
        self.meters['wall'] = TimeMeter()       # wall time in seconds

        self._buffered_stats = defaultdict(lambda: [])
        self._flat_grads = None
        self._num_updates = 0
        self._optim_history = None

    def _build_optimizer(self):
        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):  # only save one checkpoint
@@ -69,13 +73,12 @@ class Trainer(object):
    def load_checkpoint(self, filename):
        """Load all training state from a checkpoint file."""
        extra_state, self._optim_history, last_optim_state = \
            utils.load_model_state(filename, self.model)

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            self._build_optimizer()

            # only reload optimizer and lr_scheduler if they match
            last_optim = self._optim_history[-1]
@@ -105,7 +108,7 @@ class Trainer(object):
        # update parameters
        if update_params:
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
@@ -124,28 +127,34 @@ class Trainer(object):
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

            try:
                # all-reduce and rescale gradients, then take an optimization step
                grad_norm = self._all_reduce_and_rescale(grad_denom)
                self._opt()

                # update meters
                self.meters['wps'].update(ntokens)
                self.meters['ups'].update(1.)
                self.meters['wpb'].update(ntokens)
                self.meters['bsz'].update(nsentences)
                if grad_norm is not None:
                    self.meters['gnorm'].update(grad_norm)
                    self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
                self.meters['oom'].update(ooms_fwd + ooms_bwd)

                # update loss meters for training
                if 'loss' in agg_logging_output:
                    self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
                # criterions can optionally log the NLL loss too
                if 'nll_loss' in agg_logging_output:
                    self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
            except OverflowError as e:
                self.zero_grad()
                print('| WARNING: overflow detected, ' + str(e))

            self.clear_buffered_stats()

            return agg_logging_output
        else:
@@ -157,7 +166,6 @@ class Trainer(object):
            self.model.eval()
        else:
            self.model.train()

        loss = None
        sample_size = 0
        logging_output = {
@@ -176,11 +184,8 @@ class Trainer(object):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom = 1
                    loss = None
                else:
                    raise e

        return loss, sample_size, logging_output, oom

    def _backward(self, loss):
@@ -193,39 +198,66 @@ class Trainer(object):
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory, skipping batch')
                oom = 1
                self.zero_grad()
            else:
                raise e
        return oom

    def _all_reduce_and_rescale(self, grad_denom):
        # flatten grads into a single buffer and all-reduce
        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
        if self.args.distributed_world_size > 1:
            torch.distributed.all_reduce(flat_grads)

        # rescale and clip gradients
        flat_grads.div_(grad_denom)
        grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)

        # copy grads back into model parameters
        self._set_flat_grads(flat_grads)

        return grad_norm

    def _get_grads(self):
        grads = []
        for name, p in self.model.named_parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
                                   'Use the param in the forward pass or set requires_grad=False')
            grads.append(p.grad.data)
        return grads

    def _get_flat_grads(self, out=None):
        grads = self._get_grads()
        if out is None:
            grads_size = sum(g.numel() for g in grads)
            out = grads[0].new(grads_size).zero_()
        offset = 0
        for g in grads:
            numel = g.numel()
            out[offset:offset+numel].copy_(g.view(-1))
            offset += numel
        return out[:offset]

    def _set_flat_grads(self, new_grads):
        grads = self._get_grads()
        offset = 0
        for g in grads:
            numel = g.numel()
            g.copy_(new_grads[offset:offset+numel].view_as(g))
            offset += numel

    def _opt(self):
        # take an optimization step
        self.optimizer.step()
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
@@ -258,6 +290,18 @@ class Trainer(object):
        return agg_logging_output

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)
@@ -283,9 +327,4 @@ class Trainer(object):
    def _prepare_sample(self, sample, volatile):
        if sample is None or len(sample) == 0:
            return None
        return utils.make_variable(sample, volatile=volatile, cuda=True)
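
The flattened-gradient path above replaces the old chunked per-tensor all-reduce: every gradient is copied into one contiguous buffer, reduced and rescaled/clipped once, then copied back into the parameters. A minimal single-process sketch of that copy-out/copy-back round trip, with toy tensors standing in for model gradients:

import torch

grads = [torch.ones(3), torch.full((2, 2), 2.)]           # stand-ins for per-parameter grads
flat = torch.cat([g.view(-1) for g in grads])             # one contiguous buffer
# torch.distributed.all_reduce(flat) would run here in multi-GPU training
flat.div_(4.)                                             # rescale by grad_denom
offset = 0
for g in grads:                                           # copy the reduced values back in place
    numel = g.numel()
    g.copy_(flat[offset:offset + numel].view_as(g))
    offset += numel
print(grads[0], grads[1])                                 # 0.25s and 0.5s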
@@ -5,7 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import defaultdict, OrderedDict
import contextlib
import logging
import os
@@ -25,6 +25,20 @@ def torch_persistent_save(*args, **kwargs):
            logging.error(traceback.format_exc())

def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict
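
A quick hedged example of what this conversion does (toy nested dict, not a real checkpoint): tensors come back as the requested type while non-tensor values pass through unchanged.

# Illustrative only: convert a nested FP16 state dict to FP32 before saving.
sd = {'weight': torch.zeros(2, 2).half(), 'steps': 10, 'history': [torch.ones(3).half()]}
fp32_sd = convert_state_dict_type(sd, ttype=torch.FloatTensor)
print(fp32_sd['weight'].type())        # torch.FloatTensor
print(fp32_sd['history'][0].type())    # torch.FloatTensor
print(fp32_sd['steps'])                # 10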

def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
               num_updates, optim_history=None, extra_state=None):
    if optim_history is None:
@@ -33,7 +47,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
        extra_state = {}
    state_dict = {
        'args': args,
        'model': convert_state_dict_type(model.state_dict()),
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
@@ -42,22 +56,16 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)

def load_model_state(filename, model):
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename)
    state = _upgrade_state_dict(state)
    state['model'] = model.upgrade_state_dict(state['model'])
@@ -377,6 +385,14 @@ def item(tensor):
        return tensor

def clip_grad_norm_(tensor, max_norm):
    grad_norm = item(torch.norm(tensor))
    if grad_norm > max_norm > 0:
        clip_coef = max_norm / (grad_norm + 1e-6)
        tensor.mul_(clip_coef)
    return grad_norm
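
A toy example of this single-tensor clipping helper (values chosen so the norm exceeds the threshold; assumes the item() helper above is in scope):

# Illustrative only: clip a flat gradient buffer to a max norm of 1.0.
flat = torch.full((4,), 3.)          # norm = 6.0
norm = clip_grad_norm_(flat, 1.0)
print(norm)                          # 6.0 (the pre-clipping norm is returned)
print(flat.norm())                   # ~1.0 after the in-place rescale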

def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)
@@ -44,7 +44,7 @@ def average_checkpoints(inputs):
        for k in params_keys:
            if k not in params_dict:
                params_dict[k] = []
            params_dict[k].append(model_params[k].float())
    averaged_params = collections.OrderedDict()
    # v should be a list of torch Tensor.
...
@@ -13,8 +13,9 @@ import math
import torch

from fairseq import criterions, data, models, options, progress_bar
from fairseq.fp16_trainer import FP16Trainer
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter


def main(args):
@@ -48,7 +49,10 @@ def main(args):
    print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))

    # Build trainer
    if args.fp16:
        trainer = FP16Trainer(args, model, criterion)
    else:
        trainer = Trainer(args, model, criterion)
    print('| training on {} GPUs'.format(args.distributed_world_size))
    print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
@@ -84,6 +88,10 @@ def main(args):
        _ = next(train_dataloader)
        epoch += 1

    # Send a dummy batch to warm the caching allocator
    dummy_batch = data.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
    trainer.dummy_train_step(dummy_batch)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
@@ -153,7 +161,7 @@ def train(args, trainer, itr, epoch):
        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue  # these are already logged above
            if 'loss' in k:
                extra_meters[k].update(v, log_output['sample_size'])
@@ -194,6 +202,9 @@ def get_training_stats(trainer):
    stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
    stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
    stats['oom'] = trainer.get_meter('oom').avg
    if trainer.get_meter('loss_scale') is not None:
        stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
    return stats
@@ -234,7 +245,7 @@ def validate(args, trainer, dataset, subset, epoch):
        # log mid-validation stats
        stats = get_valid_stats(trainer)
        for k, v in log_output.items():
            if k in ['loss', 'nll_loss', 'sample_size']:
                continue
            extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
...