Commit 1082ba35 authored by Sergey Edunov, committed by Myle Ott

Switch to DistributedDataParallelC10d and bump version 0.5.0 -> 0.6.0

- no more FP16Trainer; FP16 training is now handled by an FP16Optimizer wrapper
- most of the distributed code moves into a new wrapper class, DistributedFairseqModel, which behaves like DistributedDataParallel and a FairseqModel at the same time
- Trainer now requires an extra dummy_batch argument at initialization; we run forward/backward on it when workers have an uneven number of batches per epoch, and hide the resulting gradients by multiplying the loss by 0
- Trainer.train_step now takes a list of samples, which allows a cleaner --update-freq implementation (illustrated in the sketch below the commit header)
parent 311d2c6c
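
Editor's note: as an illustration of the last two bullets above, here is a minimal, hypothetical sketch (not the actual Trainer code) of how one parameter update can consume a list of samples and neutralize dummy batches; all names in it are illustrative only.

def train_step_sketch(model, criterion, optimizer, samples, dummy_batch):
    # 'samples' is a list, so --update-freq > 1 is simply a longer list;
    # a None entry means this worker has run out of real batches for this update.
    optimizer.zero_grad()
    for sample in samples:
        ignore_grad = sample is None
        if ignore_grad:
            sample = dummy_batch        # keep fwd/bwd (and grad communication) in sync
        loss = criterion(model, sample)
        if ignore_grad:
            loss = loss * 0.            # dummy batch contributes no gradient
        loss.backward()
    optimizer.step()
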
...@@ -30,7 +30,7 @@ def main(args): ...@@ -30,7 +30,7 @@ def main(args):
raise e raise e
except FileNotFoundError as e: # Slurm is not installed except FileNotFoundError as e: # Slurm is not installed
pass pass
if args.distributed_init_method is None: if args.distributed_init_method is None and args.distributed_port is None:
raise ValueError('--distributed-init-method or --distributed-port ' raise ValueError('--distributed-init-method or --distributed-port '
'must be specified for distributed training') 'must be specified for distributed training')
......
...@@ -60,9 +60,9 @@ github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/' ...@@ -60,9 +60,9 @@ github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = '0.5.0' version = '0.6.0'
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = '0.5.0' release = '0.6.0'
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
......
...@@ -36,5 +36,7 @@ Iterators ...@@ -36,5 +36,7 @@ Iterators
:members: :members:
.. autoclass:: fairseq.data.EpochBatchIterator .. autoclass:: fairseq.data.EpochBatchIterator
:members: :members:
.. autoclass:: fairseq.data.GroupedIterator
:members:
.. autoclass:: fairseq.data.ShardedIterator .. autoclass:: fairseq.data.ShardedIterator
:members: :members:
...@@ -54,6 +54,7 @@ class AdaptiveLoss(FairseqCriterion): ...@@ -54,6 +54,7 @@ class AdaptiveLoss(FairseqCriterion):
logging_output = { logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data, 'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'], 'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size, 'sample_size': sample_size,
} }
return loss, sample_size, logging_output return loss, sample_size, logging_output
...@@ -63,9 +64,12 @@ class AdaptiveLoss(FairseqCriterion): ...@@ -63,9 +64,12 @@ class AdaptiveLoss(FairseqCriterion):
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs) loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = { agg_output = {
'loss': loss_sum / sample_size / math.log(2), 'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size, 'sample_size': sample_size,
} }
if sample_size != ntokens: if sample_size != ntokens:
......
...@@ -37,6 +37,7 @@ class CrossEntropyCriterion(FairseqCriterion): ...@@ -37,6 +37,7 @@ class CrossEntropyCriterion(FairseqCriterion):
logging_output = { logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data, 'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'], 'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size, 'sample_size': sample_size,
} }
return loss, sample_size, logging_output return loss, sample_size, logging_output
...@@ -46,9 +47,12 @@ class CrossEntropyCriterion(FairseqCriterion): ...@@ -46,9 +47,12 @@ class CrossEntropyCriterion(FairseqCriterion):
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs) loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = { agg_output = {
'loss': loss_sum / sample_size / math.log(2), 'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size, 'sample_size': sample_size,
} }
if sample_size != ntokens: if sample_size != ntokens:
......
...@@ -40,6 +40,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -40,6 +40,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
'loss': utils.item(loss.data) if reduce else loss.data, 'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data, 'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
'ntokens': sample['ntokens'], 'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size, 'sample_size': sample_size,
} }
return loss, sample_size, logging_output return loss, sample_size, logging_output
...@@ -58,14 +59,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -58,14 +59,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
return loss, nll_loss return loss, nll_loss
@staticmethod @staticmethod
def aggregate_logging_outputs(logging_outputs): def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training.""" """Aggregate logging outputs from data parallel training."""
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs) sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return { return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2), 'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2), 'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size, 'sample_size': sample_size,
} }
...@@ -12,18 +12,24 @@ from .language_pair_dataset import LanguagePairDataset ...@@ -12,18 +12,24 @@ from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset from .token_block_dataset import TokenBlockDataset
from .iterators import CountingIterator, EpochBatchIterator, ShardedIterator from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
__all__ = [ __all__ = [
'CountingIterator', 'CountingIterator',
'Dictionary', 'Dictionary',
'EpochBatchIterator', 'EpochBatchIterator',
'FairseqDataset', 'FairseqDataset',
'GroupedIterator',
'IndexedDataset', 'IndexedDataset',
'IndexedInMemoryDataset', 'IndexedInMemoryDataset',
'IndexedRawTextDataset', 'IndexedRawTextDataset',
'LanguagePairDataset', 'LanguagePairDataset',
'MonolingualDataset', 'MonolingualDataset',
'TokenBlockDataset',
'ShardedIterator', 'ShardedIterator',
'TokenBlockDataset',
] ]
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import itertools import itertools
import math
import numpy as np import numpy as np
import torch import torch
...@@ -150,6 +151,36 @@ class EpochBatchIterator(object): ...@@ -150,6 +151,36 @@ class EpochBatchIterator(object):
)) ))
class GroupedIterator(object):
"""Wrapper around an iterable that returns groups (chunks) of items.
Args:
iterable (iterable): iterable to wrap
chunk_size (int): size of each chunk
"""
def __init__(self, iterable, chunk_size):
self._len = int(math.ceil(len(iterable) / float(chunk_size)))
self.itr = iter(iterable)
self.chunk_size = chunk_size
def __len__(self):
return self._len
def __iter__(self):
return self
def __next__(self):
chunk = []
try:
for _ in range(self.chunk_size):
chunk.append(next(self.itr))
except StopIteration as e:
if len(chunk) == 0:
raise e
return chunk
class ShardedIterator(object): class ShardedIterator(object):
"""A sharded wrapper around an iterable, padded to length. """A sharded wrapper around an iterable, padded to length.
......
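
A quick usage sketch for the new GroupedIterator (assuming this version of fairseq is installed); the numbers are only illustrative:

from fairseq.data import iterators

itr = iterators.GroupedIterator(range(7), chunk_size=3)
assert len(itr) == 3                    # ceil(7 / 3)
for group in itr:
    print(group)                        # [0, 1, 2], then [3, 4, 5], then [6]
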
...@@ -7,7 +7,9 @@ ...@@ -7,7 +7,9 @@
import pickle import pickle
import torch.distributed import torch
from torch import distributed
from torch.distributed import group
from fairseq import utils from fairseq import utils
...@@ -16,22 +18,39 @@ def is_master(args): ...@@ -16,22 +18,39 @@ def is_master(args):
return args.distributed_rank == 0 return args.distributed_rank == 0
_use_c10d = [None]
def distributed_init(args): def distributed_init(args):
if args.distributed_world_size == 1: if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1') raise ValueError('Cannot initialize distributed with distributed_world_size=1')
if _use_c10d[0] is None:
_use_c10d[0] = not args.no_c10d
if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'):
_use_c10d[0] = False
print('WARNING: cannot find DistributedDataParallelC10d, '
'falling back to standard DistributedDataParallel')
print('| distributed init (rank {}): {}'.format( print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True) args.distributed_rank, args.distributed_init_method), flush=True)
if args.distributed_init_method.startswith('tcp://'):
torch.distributed.init_process_group( if _use_c10d[0]:
backend=args.distributed_backend, init_method=args.distributed_init_method, distributed.c10d.init_process_group(
world_size=args.distributed_world_size, rank=args.distributed_rank) backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
else: else:
torch.distributed.init_process_group( distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method, backend=args.distributed_backend,
world_size=args.distributed_world_size) init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
args.distributed_rank = torch.distributed.get_rank()
if not is_master(args): if not is_master(args):
suppress_output() suppress_output()
...@@ -52,34 +71,76 @@ def suppress_output(): ...@@ -52,34 +71,76 @@ def suppress_output():
__builtin__.print = print __builtin__.print = print
def all_gather_list(data, max_size=16384): def get_rank():
"""Gathers arbitrary data from all nodes into a list.""" if _use_c10d[0]:
world_size = torch.distributed.get_world_size() return distributed.c10d.get_rank()
if not hasattr(all_gather_list, '_in_buffer') or \ else:
max_size != all_gather_list._in_buffer.size(): return distributed.get_rank()
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size) def get_world_size():
for i in range(world_size) if _use_c10d[0]:
] return distributed.c10d.get_world_size()
in_buffer = all_gather_list._in_buffer else:
out_buffers = all_gather_list._out_buffers return distributed.get_world_size()
def get_default_group():
if _use_c10d[0]:
return distributed.c10d.group.WORLD
else:
return distributed.group.WORLD
def all_reduce(tensor, group=None):
if group is None:
group = get_default_group()
if _use_c10d[0]:
return distributed.c10d.all_reduce(tensor, group=group)
else:
return distributed.all_reduce(tensor, group=group)
def all_gather_list(data, group=None, max_size=16384):
"""Gathers arbitrary data from all nodes into a list.
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
data. Note that *data* must be picklable.
Args:
data (Any): data from the local worker to be gathered on other workers
group (optional): group of the collective
max_size (int, optional): maximum size of the data to be gathered
across workers
"""
rank = get_rank()
world_size = get_world_size()
buffer_size = max_size * world_size
if not hasattr(all_gather_list, '_buffer') or \
all_gather_list._buffer.numel() < buffer_size:
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
buffer = all_gather_list._buffer
buffer.zero_()
enc = pickle.dumps(data) enc = pickle.dumps(data)
enc_size = len(enc) enc_size = len(enc)
if enc_size + 2 > max_size: if enc_size + 2 > max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2)) raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256 assert max_size < 255*256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
torch.distributed.all_gather(out_buffers, in_buffer.cuda()) buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
buffer_rank[0] = enc_size // 255 # this encoding works for max_size < 65k
buffer_rank[1] = enc_size % 255
buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))
all_reduce(buffer, group=group)
result = [] result = []
for i in range(world_size): for i in range(world_size):
out_buffer = out_buffers[i] out_buffer = buffer[i * max_size : (i + 1) * max_size]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1]) size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
if size > 0:
result.append( result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist())) pickle.loads(bytes(out_buffer[2:size+2].tolist()))
) )
......
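
A usage sketch for the reworked helpers, assuming distributed_init(args) has already been called on every worker; the payload dict below is purely illustrative:

from fairseq import distributed_utils

local_stats = {'rank': distributed_utils.get_rank(), 'ntokens': 1234}
# returns one entry per worker, each an arbitrary picklable object
all_stats = distributed_utils.all_gather_list(local_stats)
total_tokens = sum(s['ntokens'] for s in all_stats)
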
...@@ -15,6 +15,7 @@ from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401 ...@@ -15,6 +15,7 @@ from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401 from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401
from .composite_encoder import CompositeEncoder # noqa: F401 from .composite_encoder import CompositeEncoder # noqa: F401
from .distributed_fairseq_model import DistributedFairseqModel # noqa: F401
MODEL_REGISTRY = {} MODEL_REGISTRY = {}
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.distributed import c10d
from torch.nn import parallel
from . import BaseFairseqModel
class DistributedFairseqModel(BaseFairseqModel):
"""
A wrapper around a :class:`BaseFairseqModel` instance that adds support for
distributed training.
Anytime a method or attribute is called on this class we first try to
forward it to the underlying DistributedDataParallel instance, otherwise we
forward it to the original :class:`BaseFairseqModel` instance.
Args:
args (argparse.Namespace): fairseq args
model (BaseFairseqModel): model to wrap
"""
def __init__(self, args, model):
super().__init__()
assert isinstance(model, BaseFairseqModel)
if args.no_c10d:
self.ddp_model = parallel.DistributedDataParallel(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=False,
)
else:
self.ddp_model = parallel._DistributedDataParallelC10d(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=False,
bucket_cap_mb=args.c10d_bucket_cap_mb,
)
def __call__(self, *args, **kwargs):
return self.ddp_model(*args, **kwargs)
def forward(self, *args, **kwargs):
return self.ddp_model.forward(*args, **kwargs)
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
pass
try:
return self.ddp_model.__getattr__(name)
except AttributeError:
pass
return self.ddp_model.module.__getattr__(name)
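
A hedged sketch of how the wrapper is intended to be used: calls go through the DistributedDataParallel instance, while attributes the wrapper and DDP do not define fall back to the wrapped FairseqModel. The args, model and sample objects below are assumed to already exist:

from fairseq import models

wrapped = models.DistributedFairseqModel(args, model)
wrapped.train()                              # ordinary nn.Module methods still work
net_output = wrapped(**sample['net_input'])  # forward pass routed through DDP
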
...@@ -9,6 +9,7 @@ import importlib ...@@ -9,6 +9,7 @@ import importlib
import os import os
from .fairseq_optimizer import FairseqOptimizer from .fairseq_optimizer import FairseqOptimizer
from .fp16_optimizer import FP16Optimizer
OPTIMIZER_REGISTRY = {} OPTIMIZER_REGISTRY = {}
...@@ -16,7 +17,7 @@ OPTIMIZER_CLASS_NAMES = set() ...@@ -16,7 +17,7 @@ OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params): def build_optimizer(args, params):
params = filter(lambda p: p.requires_grad, params) params = list(filter(lambda p: p.requires_grad, params))
return OPTIMIZER_REGISTRY[args.optimizer](args, params) return OPTIMIZER_REGISTRY[args.optimizer](args, params)
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import torch.optim import math
import torch
class FairseqOptimizer(object): class FairseqOptimizer(object):
...@@ -13,7 +15,7 @@ class FairseqOptimizer(object): ...@@ -13,7 +15,7 @@ class FairseqOptimizer(object):
def __init__(self, args, params): def __init__(self, args, params):
super().__init__() super().__init__()
self.args = args self.args = args
self.params = params self.params = list(params)
@staticmethod @staticmethod
def add_args(parser): def add_args(parser):
...@@ -67,10 +69,25 @@ class FairseqOptimizer(object): ...@@ -67,10 +69,25 @@ class FairseqOptimizer(object):
for group in self.optimizer.param_groups: for group in self.optimizer.param_groups:
group.update(optimizer_overrides) group.update(optimizer_overrides)
def backward(self, loss):
loss.backward()
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
for p in self.params:
p.grad.data.mul_(c)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm."""
if max_norm > 0:
return torch.nn.utils.clip_grad_norm_(self.params, max_norm)
else:
return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params))
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step.""" """Performs a single optimization step."""
return self.optimizer.step(closure) self.optimizer.step(closure)
def zero_grad(self): def zero_grad(self):
"""Clears the gradients of all optimized parameters.""" """Clears the gradients of all optimized parameters."""
return self.optimizer.zero_grad() self.optimizer.zero_grad()
...@@ -5,16 +5,9 @@ ...@@ -5,16 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
"""
import torch import torch
from fairseq import optim, utils from fairseq import optim, utils
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer
class DynamicLossScaler: class DynamicLossScaler:
...@@ -42,89 +35,97 @@ class DynamicLossScaler: ...@@ -42,89 +35,97 @@ class DynamicLossScaler:
return False return False
class FP16Trainer(Trainer): class FP16Optimizer(optim.FairseqOptimizer):
"""Modified trainer for FP16.
We maintain two copies of the model's parameters, both in FP16 and FP32.
We do forward/backward with FP16 and compute the loss + optimize with FP32.
"""
def __init__(self, args, task, model, criterion):
super().__init__(args, task, model, criterion)
# convert model to FP16 (but keep criterion FP32)
self.model.half()
# dynamically scale loss to reduce overflow def __init__(self, args, params, fp32_optimizer, fp32_params):
self.scaler = DynamicLossScaler(init_scale=2.**7) super().__init__(args, params)
self.meters['loss_scale'] = AverageMeter() self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
self.scaler = DynamicLossScaler(
init_scale=2.**7,
scale_window=(2**14 / args.distributed_world_size),
)
def _build_optimizer(self): @staticmethod
def build_optimizer(args, params):
# create FP32 copy of parameters and grads # create FP32 copy of parameters and grads
params = [p for p in self.model.parameters() if p.requires_grad]
total_param_size = sum(p.data.numel() for p in params) total_param_size = sum(p.data.numel() for p in params)
self.fp32_params = params[0].new(0).float().new(total_param_size) fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0 offset = 0
for p in params: for p in params:
numel = p.data.numel() numel = p.data.numel()
self.fp32_params[offset:offset+numel].copy_(p.data.view(-1)) fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel offset += numel
self.fp32_params = torch.nn.Parameter(self.fp32_params) fp32_params = torch.nn.Parameter(fp32_params)
self.fp32_params.grad = self.fp32_params.data.new(total_param_size) fp32_params.grad = fp32_params.data.new(total_param_size)
# create optimizer using the copied FP32 params fp32_optimizer = optim.build_optimizer(args, [fp32_params])
self._optimizer = optim.build_optimizer(self.args, [self.fp32_params]) return FP16Optimizer(args, params, fp32_optimizer, fp32_params)
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
def save_checkpoint(self, filename, extra_state): @property
"""Save all training state in a checkpoint file.""" def optimizer(self):
extra_state['loss_scale'] = self.scaler.loss_scale return self.fp32_optimizer.optimizer
super().save_checkpoint(filename, extra_state)
def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None): @property
"""Load all training state from a checkpoint file.""" def optimizer_config(self):
extra_state = super().load_checkpoint(filename, reset_optimizer, reset_lr_scheduler, optimizer_overrides) return self.fp32_optimizer.optimizer_config
if extra_state is not None and 'loss_scale' in extra_state:
self.scaler.loss_scale = extra_state['loss_scale']
return extra_state
def zero_grad(self): def get_lr(self):
# zero both the FP16 and FP32 grads return self.fp32_optimizer.get_lr()
self.model.zero_grad() # FP16
self.optimizer.zero_grad() # FP32
def _backward(self, loss):
self.meters['loss_scale'].reset()
self.meters['loss_scale'].update(self.scaler.loss_scale)
if loss is not None:
# dynamically rescale loss to stay in FP16 range
loss = loss * self.scaler.loss_scale
return super()._backward(loss)
def _all_reduce_and_rescale(self, grad_denom): def set_lr(self, lr):
# undo effect of dynamic loss scaling on gradients self.fp32_optimizer.set_lr(lr)
grad_denom *= self.scaler.loss_scale
if self.args.distributed_world_size > 1: def state_dict(self):
# flatten grads into a single buffer """Return the optimizer's state dict."""
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads) state_dict = self.fp32_optimizer.state_dict()
state_dict['loss_scale'] = self.scaler.loss_scale
return state_dict
# scale gradients to avoid overflow in all-reduce def load_state_dict(self, state_dict, optimizer_overrides=None):
flat_grads.div_(self.args.distributed_world_size) """Load an optimizer state dict.
grad_denom /= self.args.distributed_world_size
# all-reduce flat grads In general we should prefer the configuration of the existing optimizer
torch.distributed.all_reduce(flat_grads) instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
if 'loss_scale' in state_dict:
self.scaler.loss_scale = state_dict['loss_scale']
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
def backward(self, loss):
loss = loss * self.scaler.loss_scale
loss.backward()
self._needs_sync = True
# copy grads back to FP32 def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
self.fp32_params.grad.data.copy_(flat_grads) if self._needs_sync:
# copy FP16 grads to FP32
offset = 0
for p in self.params:
if not p.requires_grad:
continue
numel = p.grad.data.numel()
self.fp32_params.grad.data[offset:offset+numel].copy_(p.grad.data.view(-1))
offset += numel
# correct for dynamic loss scaler
self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
self._needs_sync = False
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
if self._needs_sync:
self._sync_fp16_grads_to_fp32(c)
else: else:
# single worker: copy grads directly to FP32 self.fp32_params.grad.data.mul_(c)
self._get_flat_grads(out=self.fp32_params.grad.data)
# rescale and clip grads def clip_grad_norm(self, max_norm):
self.fp32_params.grad.data.div_(grad_denom) """Clips gradient norm and updates dynamic loss scaler."""
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm) self._sync_fp16_grads_to_fp32()
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
# detect overflow and adjust loss scale # detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm) overflow = DynamicLossScaler.has_overflow(grad_norm)
...@@ -137,18 +138,27 @@ class FP16Trainer(Trainer): ...@@ -137,18 +138,27 @@ class FP16Trainer(Trainer):
'increasing the batch size.' 'increasing the batch size.'
).format(self.args.min_loss_scale)) ).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale)) raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm return grad_norm
def _opt(self): def step(self, closure=None):
# take an optimization step using the FP32 params and grads """Performs a single optimization step."""
super()._opt() self._sync_fp16_grads_to_fp32()
self.fp32_optimizer.step(closure)
# copy FP32 params back into FP16 model # copy FP32 params back into FP16 model
offset = 0 offset = 0
for p in self.model.parameters(): for p in self.params:
if not p.requires_grad: if not p.requires_grad:
continue continue
numel = p.data.numel() numel = p.data.numel()
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data)) p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
offset += numel offset += numel
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self.fp32_optimizer.zero_grad()
for p in self.params:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
self._needs_sync = False
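
For orientation, a condensed, hypothetical sketch of the FP16 recipe the new FP16Optimizer implements: FP16 parameters for forward/backward, one flat FP32 master copy for the update, and gradients rescaled by 1/loss_scale before the step. The real class additionally handles dynamic loss scaling state, overflow detection, and checkpointing.

import torch

def fp16_step_sketch(fp16_params, fp32_master, fp32_optimizer, loss, loss_scale, max_norm):
    # forward/backward run in FP16 with an up-scaled loss to avoid gradient underflow
    (loss * loss_scale).backward()
    # copy FP16 grads into the flat FP32 master grad buffer and undo the scaling
    offset = 0
    for p in fp16_params:
        n = p.grad.data.numel()
        fp32_master.grad.data[offset:offset + n].copy_(p.grad.data.view(-1))
        offset += n
    fp32_master.grad.data.div_(loss_scale)
    # clip and step on the FP32 master copy
    torch.nn.utils.clip_grad_norm_([fp32_master], max_norm)
    fp32_optimizer.step()
    # write the updated FP32 weights back into the FP16 model parameters
    offset = 0
    for p in fp16_params:
        n = p.data.numel()
        p.data.copy_(fp32_master.data[offset:offset + n].view_as(p.data))
        offset += n
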
...@@ -183,6 +183,10 @@ def add_distributed_training_args(parser): ...@@ -183,6 +183,10 @@ def add_distributed_training_args(parser):
help='port number (not required if using --distributed-init-method)') help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int, group.add_argument('--device-id', default=0, type=int,
help='which GPU to use (usually configured automatically)') help='which GPU to use (usually configured automatically)')
group.add_argument('--no-c10d', action='store_true',
help='don\'t use c10d distributed backend')
group.add_argument('--c10d-bucket-cap-mb', default=150, metavar='MB',
help='bucket size for c10d backend')
return group return group
......
...@@ -15,7 +15,7 @@ from itertools import chain ...@@ -15,7 +15,7 @@ from itertools import chain
import torch import torch
from fairseq import distributed_utils, optim, utils from fairseq import distributed_utils, models, optim, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.optim import lr_scheduler from fairseq.optim import lr_scheduler
...@@ -23,22 +23,27 @@ from fairseq.optim import lr_scheduler ...@@ -23,22 +23,27 @@ from fairseq.optim import lr_scheduler
class Trainer(object): class Trainer(object):
"""Main class for data parallel training. """Main class for data parallel training.
This class supports data parallel training, where multiple workers each This class supports synchronous distributed data parallel training,
have a full model replica and gradients are accumulated synchronously via where multiple workers each have a full model replica and gradients
torch.distributed.all_reduce. are accumulated across workers before each update. We use
:class:`~torch.nn.parallel.DistributedDataParallel` to handle
communication of the gradients across workers.
""" """
def __init__(self, args, task, model, criterion): def __init__(self, args, task, model, criterion, dummy_batch):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported') raise NotImplementedError('Training on CPU is not supported')
self.args = args self.args = args
self.task = task
# copy model and criterion to current device # copy model and criterion to current device
self.task = task
self.model = model.cuda()
self.criterion = criterion.cuda() self.criterion = criterion.cuda()
if args.fp16:
self._model = model.half().cuda()
else:
self._model = model.cuda()
# initialize meters # initialize meters
self.meters = OrderedDict() self.meters = OrderedDict()
...@@ -53,14 +58,27 @@ class Trainer(object): ...@@ -53,14 +58,27 @@ class Trainer(object):
self.meters['gnorm'] = AverageMeter() # gradient norm self.meters['gnorm'] = AverageMeter() # gradient norm
self.meters['clip'] = AverageMeter() # % of updates clipped self.meters['clip'] = AverageMeter() # % of updates clipped
self.meters['oom'] = AverageMeter() # out of memory self.meters['oom'] = AverageMeter() # out of memory
if args.fp16:
self.meters['loss_scale'] = AverageMeter() # dynamic loss scale
self.meters['wall'] = TimeMeter() # wall time in seconds self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
self._buffered_stats = defaultdict(lambda: []) self._dummy_batch = dummy_batch
self._flat_grads = None
self._num_updates = 0 self._num_updates = 0
self._optim_history = None self._optim_history = None
self._optimizer = None self._optimizer = None
self._wrapped_model = None
@property
def model(self):
if self._wrapped_model is None:
if self.args.distributed_world_size > 1:
self._wrapped_model = models.DistributedFairseqModel(
self.args, self._model,
)
else:
self._wrapped_model = self._model
return self._wrapped_model
@property @property
def optimizer(self): def optimizer(self):
...@@ -69,7 +87,17 @@ class Trainer(object): ...@@ -69,7 +87,17 @@ class Trainer(object):
return self._optimizer return self._optimizer
def _build_optimizer(self): def _build_optimizer(self):
if self.args.fp16:
if torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16, '
'please switch to FP32 which is likely to be faster')
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
self._optimizer = optim.build_optimizer(self.args, self.model.parameters()) self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer) self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
def save_checkpoint(self, filename, extra_state): def save_checkpoint(self, filename, extra_state):
...@@ -77,31 +105,27 @@ class Trainer(object): ...@@ -77,31 +105,27 @@ class Trainer(object):
if distributed_utils.is_master(self.args): # only save one checkpoint if distributed_utils.is_master(self.args): # only save one checkpoint
extra_state['train_meters'] = self.meters extra_state['train_meters'] = self.meters
utils.save_state( utils.save_state(
filename, self.args, self.model, self.criterion, self.optimizer, filename, self.args, self.get_model(), self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state, self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
) )
def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None): def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
"""Load all training state from a checkpoint file.""" """Load all training state from a checkpoint file."""
extra_state, self._optim_history, last_optim_state = \ extra_state, self._optim_history, last_optim_state = \
utils.load_model_state(filename, self.model) utils.load_model_state(filename, self.get_model())
if last_optim_state is not None and not reset_optimizer: if last_optim_state is not None and not reset_optimizer:
# rebuild optimizer after loading model, since params may have changed # rebuild optimizer after loading model, since params may have changed
self._build_optimizer() self._build_optimizer()
# only reload optimizer and lr_scheduler if they match # only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1] last_optim = self._optim_history[-1]
assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \ assert last_optim['criterion_name'] == self.criterion.__class__.__name__, \
'criterion does not match; please reset the optimizer (--reset-optimizer)' 'criterion does not match; please reset the optimizer (--reset-optimizer)'
assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \ assert last_optim['optimizer_name'] == self.optimizer.__class__.__name__, \
'optimizer does not match; please reset the optimizer (--reset-optimizer)' 'optimizer does not match; please reset the optimizer (--reset-optimizer)'
if not reset_lr_scheduler: if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state']) self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
self._num_updates = last_optim['num_updates'] self._num_updates = last_optim['num_updates']
...@@ -117,7 +141,7 @@ class Trainer(object): ...@@ -117,7 +141,7 @@ class Trainer(object):
return extra_state return extra_state
def train_step(self, sample, update_params=True, dummy_batch=False): def train_step(self, samples, dummy_batch=False):
"""Do forward, backward and parameter update.""" """Do forward, backward and parameter update."""
# Set seed based on args.seed and the update number so that we get # Set seed based on args.seed and the update number so that we get
# reproducible results when resuming from checkpoints # reproducible results when resuming from checkpoints
...@@ -125,230 +149,164 @@ class Trainer(object): ...@@ -125,230 +149,164 @@ class Trainer(object):
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
self.model.train()
self.zero_grad()
if not dummy_batch: if not dummy_batch:
self.meters['train_wall'].start() self.meters['train_wall'].start()
# forward and backward pass # forward and backward pass
logging_outputs, sample_sizes, ooms = [], [], 0
for i, sample in enumerate(samples):
sample = self._prepare_sample(sample) sample = self._prepare_sample(sample)
loss, sample_size, logging_output, oom_fwd = self._forward(sample) if sample is None:
oom_bwd = self._backward(loss) # when sample is None, run forward/backward on a dummy batch
# and ignore the resulting gradients
# buffer stats and logging outputs sample = self._prepare_sample(self._dummy_batch)
self._buffered_stats['sample_sizes'].append(sample_size) ignore_grad = True
self._buffered_stats['logging_outputs'].append(logging_output)
self._buffered_stats['ooms_fwd'].append(oom_fwd)
self._buffered_stats['ooms_bwd'].append(oom_bwd)
# update parameters
if update_params:
agg_logging_output = self._update_params()
else: else:
agg_logging_output = None # buffering updates ignore_grad = False
if not dummy_batch: try:
self.meters['train_wall'].stop() # forward
loss, sample_size, logging_output = self.task.get_loss(
self.model, self.criterion, sample,
)
if ignore_grad:
loss *= 0
if self.args.distributed_world_size > 1:
# only all-reduce gradients in the last backwards pass
if i < len(samples) - 1:
self.model.need_reduction = False
else:
self.model.need_reduction = True
# backward
self.optimizer.backward(loss)
if not ignore_grad:
logging_outputs.append(logging_output)
sample_sizes.append(sample_size)
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
ooms += 1
self.zero_grad()
else:
raise e
return agg_logging_output if dummy_batch:
return None
def _update_params(self):
# gather logging outputs from all replicas # gather logging outputs from all replicas
sample_sizes = self._buffered_stats['sample_sizes']
logging_outputs = self._buffered_stats['logging_outputs']
ooms_fwd = self._buffered_stats['ooms_fwd']
ooms_bwd = self._buffered_stats['ooms_bwd']
if self.args.distributed_world_size > 1: if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map( logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
lambda l: list(chain.from_iterable(l)), [logging_outputs, sample_sizes, ooms],
zip(*distributed_utils.all_gather_list(
(sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
)) ))
) logging_outputs = list(chain.from_iterable(logging_outputs))
ooms_fwd = sum(ooms_fwd) sample_sizes = list(chain.from_iterable(sample_sizes))
ooms_bwd = sum(ooms_bwd) ooms = sum(ooms)
if ooms_fwd == self.args.distributed_world_size: if ooms == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch') print('| WARNING: OOM in all workers, skipping update')
self.zero_grad() self.zero_grad()
return None return None
# aggregate stats and logging outputs # aggregate logging outputs and sample sizes
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs) sample_size = self.criterion.__class__.grad_denom(sample_sizes)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes) if not all(k in logging_output for k in ['ntokens', 'nsentences']):
raise Exception((
'Please update the {}.aggregate_logging_outputs() method to '
'return ntokens and nsentences'
).format(self.criterion.__class__.__name__))
try: try:
# all-reduce and rescale gradients, then take an optimization step # normalize grads by sample size
grad_norm = self._all_reduce_and_rescale(grad_denom) self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))
self._opt()
# clip grads
grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
# take an optimization step
self.optimizer.step()
self._num_updates += 1
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
# update meters # update meters
ntokens = logging_output.get('ntokens', 0)
nsentences = logging_output.get('nsentences', 0)
self.meters['wps'].update(ntokens) self.meters['wps'].update(ntokens)
self.meters['ups'].update(1.) self.meters['ups'].update(1.)
self.meters['wpb'].update(ntokens) self.meters['wpb'].update(ntokens)
self.meters['bsz'].update(nsentences) self.meters['bsz'].update(nsentences)
if grad_norm is not None:
self.meters['gnorm'].update(grad_norm) self.meters['gnorm'].update(grad_norm)
self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.) self.meters['clip'].update(
self.meters['oom'].update(ooms_fwd + ooms_bwd) 1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
)
# update loss meters for training self.meters['oom'].update(ooms)
if 'loss' in agg_logging_output: self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom) self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
# criterions can optionally log the NLL loss too
if 'nll_loss' in agg_logging_output:
self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
except OverflowError as e: except OverflowError as e:
self.zero_grad()
print('| WARNING: overflow detected, ' + str(e)) print('| WARNING: overflow detected, ' + str(e))
self.clear_buffered_stats()
return agg_logging_output
def _forward(self, sample, eval=False):
loss = None
sample_size = 0
logging_output = {
'ntokens': sample['ntokens'] if sample is not None else 0,
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
oom = 0
try:
# prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
if sample is not None:
with torch.no_grad() if eval else contextlib.ExitStack():
# calculate loss and sample size
loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
logging_output.update(logging_output_)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
loss = None
else:
raise e
return loss, sample_size, logging_output, oom
def _backward(self, loss):
oom = 0
if loss is not None:
try:
# backward pass
loss.backward()
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
self.zero_grad() self.zero_grad()
else: logging_output = None
raise e
return oom
def _all_reduce_and_rescale(self, grad_denom): if self.args.fp16:
# flatten grads into a single buffer and all-reduce self.meters['loss_scale'].reset()
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads) self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)
if self.args.distributed_world_size > 1:
torch.distributed.all_reduce(flat_grads)
# rescale and clip gradients
flat_grads.div_(grad_denom)
grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)
# copy grads back into model parameters
self._set_flat_grads(flat_grads)
return grad_norm
def _get_grads(self):
grads = []
for name, p in self.model.named_parameters():
if not p.requires_grad:
continue
if p.grad is None:
print('WARNING: model parameter did not receive gradient: ' + name + '. '
'Check that you\'re using the param in the forward pass or set requires_grad=False')
grads.append(p.new_zeros(p.shape))
else:
grads.append(p.grad.data)
return grads
def _get_flat_grads(self, out=None):
grads = self._get_grads()
if out is None:
grads_size = sum(g.numel() for g in grads)
out = grads[0].new(grads_size).zero_()
offset = 0
for g in grads:
numel = g.numel()
out[offset:offset+numel].copy_(g.view(-1))
offset += numel
return out[:offset]
def _set_flat_grads(self, new_grads):
grads = self._get_grads()
offset = 0
for g in grads:
numel = g.numel()
g.copy_(new_grads[offset:offset+numel].view_as(g))
offset += numel
def _opt(self):
# take an optimization step
self.optimizer.step()
self.zero_grad()
self._num_updates += 1
# update learning rate self.meters['train_wall'].stop()
self.lr_scheduler.step_update(self._num_updates)
return logging_output
def valid_step(self, sample): def valid_step(self, sample):
"""Do forward pass in evaluation mode.""" """Do forward pass in evaluation mode."""
# forward pass self.model.eval()
logging_output, sample_size = {}, 0
with torch.no_grad():
sample = self._prepare_sample(sample) sample = self._prepare_sample(sample)
_loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True) if sample is None:
assert not oom_fwd, 'Ran out of memory during validation' sample = self._prepare_sample(self._dummy_batch)
_loss, sample_size, logging_output = self.task.get_loss(
self.model, self.criterion, sample,
)
# gather logging outputs from all GPUs # gather logging outputs from all replicas
if self.args.distributed_world_size > 1: if self.args.distributed_world_size > 1:
sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list( logging_output, sample_size = zip(*distributed_utils.all_gather_list(
(sample_size, logging_output) [logging_output, sample_size],
)) ))
logging_output = list(logging_output)
sample_size = list(sample_size)
else: else:
sample_sizes = [sample_size] logging_output = [logging_output]
logging_outputs = [logging_output] sample_size = [sample_size]
# aggregate stats and logging outputs # aggregate logging outputs and sample sizes
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs) logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_output)
grad_denom = self.criterion.__class__.grad_denom(sample_sizes) sample_size = self.criterion.__class__.grad_denom(sample_size)
agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
# update loss meters for validation # update meters for validation
if 'loss' in agg_logging_output: ntokens = logging_output.get('ntokens', 0)
self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom) self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
# criterions can optionally log the NLL loss too self.meters['valid_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
if 'nll_loss' in agg_logging_output:
self.meters['valid_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
return agg_logging_output return logging_output
def dummy_train_step(self, dummy_batch): def dummy_train_step(self, dummy_batch):
"""Dummy training step for warming caching allocator.""" """Dummy training step for warming caching allocator."""
self.train_step(dummy_batch, update_params=False, dummy_batch=True) self.train_step(dummy_batch, dummy_batch=True)
self.zero_grad() self.zero_grad()
self.clear_buffered_stats()
def zero_grad(self): def zero_grad(self):
self.optimizer.zero_grad() self.optimizer.zero_grad()
def clear_buffered_stats(self):
self._buffered_stats.clear()
def lr_step(self, epoch, val_loss=None): def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate based on the validation loss.""" """Adjust the learning rate based on the validation loss."""
return self.lr_scheduler.step(epoch, val_loss) return self.lr_scheduler.step(epoch, val_loss)
...@@ -362,8 +320,8 @@ class Trainer(object): ...@@ -362,8 +320,8 @@ class Trainer(object):
return self.optimizer.get_lr() return self.optimizer.get_lr()
def get_model(self): def get_model(self):
"""Get the model replica.""" """Get the (non-wrapped) model instance."""
return self.model return self._model
def get_meter(self, name): def get_meter(self, name):
"""Get a specific meter by name.""" """Get a specific meter by name."""
......
...@@ -19,8 +19,10 @@ from train import main as single_process_main ...@@ -19,8 +19,10 @@ from train import main as single_process_main
def main(args): def main(args):
# Set distributed training parameters for a single node. # Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count() args.distributed_world_size = torch.cuda.device_count()
args.distributed_init_method = 'tcp://localhost:{port}'.format( port = random.randint(10000, 20000)
port=random.randint(10000, 20000)) args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_init_host = 'localhost'
args.distributed_port = port + 1
mp = torch.multiprocessing.get_context('spawn') mp = torch.multiprocessing.get_context('spawn')
......
...@@ -35,7 +35,7 @@ bleu = Extension( ...@@ -35,7 +35,7 @@ bleu = Extension(
setup( setup(
name='fairseq', name='fairseq',
version='0.5.0', version='0.6.0',
description='Facebook AI Research Sequence-to-Sequence Toolkit', description='Facebook AI Research Sequence-to-Sequence Toolkit',
long_description=readme, long_description=readme,
license=license, license=license,
......
...@@ -16,7 +16,7 @@ import math ...@@ -16,7 +16,7 @@ import math
import torch import torch
from fairseq import distributed_utils, options, progress_bar, tasks, utils from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq.fp16_trainer import FP16Trainer from fairseq.data import iterators
from fairseq.trainer import Trainer from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter from fairseq.meters import AverageMeter, StopwatchMeter
...@@ -43,16 +43,17 @@ def main(args): ...@@ -43,16 +43,17 @@ def main(args):
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__)) print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters()))) print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
# Make a dummy batch to (i) warm the caching allocator and (ii) as a
# placeholder DistributedDataParallel when there's an uneven number of
# batches per worker.
max_positions = utils.resolve_max_positions(
task.max_positions(),
model.max_positions(),
)
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
# Build trainer # Build trainer
if args.fp16: trainer = Trainer(args, task, model, criterion, dummy_batch)
if torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16,'
' please switch to FP32 which is likely to be faster')
trainer = FP16Trainer(args, task, model, criterion)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
trainer = Trainer(args, task, model, criterion)
print('| training on {} GPUs'.format(args.distributed_world_size)) print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format( print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens, args.max_tokens,
...@@ -60,10 +61,6 @@ def main(args): ...@@ -60,10 +61,6 @@ def main(args):
)) ))
# Initialize dataloader # Initialize dataloader
max_positions = utils.resolve_max_positions(
task.max_positions(),
trainer.get_model().max_positions(),
)
epoch_itr = task.get_batch_iterator( epoch_itr = task.get_batch_iterator(
dataset=task.dataset(args.train_subset), dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
...@@ -78,9 +75,7 @@ def main(args): ...@@ -78,9 +75,7 @@ def main(args):
# Load the latest checkpoint if one is available # Load the latest checkpoint if one is available
if not load_checkpoint(args, trainer, epoch_itr): if not load_checkpoint(args, trainer, epoch_itr):
# Send a dummy batch to warm the caching allocator trainer.dummy_train_step([dummy_batch])
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
trainer.dummy_train_step(dummy_batch)
# Train until the learning rate gets too small # Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf max_epoch = args.max_epoch or math.inf
...@@ -110,32 +105,32 @@ def main(args): ...@@ -110,32 +105,32 @@ def main(args):
def train(args, trainer, task, epoch_itr): def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch.""" """Train the model for one epoch."""
# Initialize data iterator # Update parameters every N batches
itr = epoch_itr.next_epoch_itr()
progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
# update parameters every N batches
if epoch_itr.epoch <= len(args.update_freq): if epoch_itr.epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch_itr.epoch - 1] update_freq = args.update_freq[epoch_itr.epoch - 1]
else: else:
update_freq = args.update_freq[-1] update_freq = args.update_freq[-1]
# Initialize data iterator
itr = epoch_itr.next_epoch_itr()
itr = iterators.GroupedIterator(itr, update_freq)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch, no_progress_bar='simple',
)
extra_meters = collections.defaultdict(lambda: AverageMeter()) extra_meters = collections.defaultdict(lambda: AverageMeter())
first_valid = args.valid_subset.split(',')[0] first_valid = args.valid_subset.split(',')[0]
max_update = args.max_update or math.inf max_update = args.max_update or math.inf
num_batches = len(epoch_itr) num_batches = len(epoch_itr)
for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch): for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
if i < num_batches - 1 and (i + 1) % update_freq > 0: log_output = trainer.train_step(samples)
# buffer updates according to --update-freq if log_output is None:
trainer.train_step(sample, update_params=False)
continue continue
else:
log_output = trainer.train_step(sample, update_params=True)
# log mid-epoch stats # log mid-epoch stats
stats = get_training_stats(trainer) stats = get_training_stats(trainer)
for k, v in log_output.items(): for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']: if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
continue # these are already logged above continue # these are already logged above
if 'loss' in k: if 'loss' in k:
extra_meters[k].update(v, log_output['sample_size']) extra_meters[k].update(v, log_output['sample_size'])
...@@ -163,7 +158,9 @@ def train(args, trainer, task, epoch_itr): ...@@ -163,7 +158,9 @@ def train(args, trainer, task, epoch_itr):
progress.print(stats) progress.print(stats)
# reset training meters # reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip', 'gnorm']: for k in [
'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
]:
meter = trainer.get_meter(k) meter = trainer.get_meter(k)
if meter is not None: if meter is not None:
meter.reset() meter.reset()
...@@ -230,7 +227,7 @@ def validate(args, trainer, task, epoch_itr, subsets): ...@@ -230,7 +227,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
log_output = trainer.valid_step(sample) log_output = trainer.valid_step(sample)
for k, v in log_output.items(): for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']: if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
continue continue
extra_meters[k].update(v) extra_meters[k].update(v)
......