"git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "bb455d7c1601c55b4d8b44a6ada8906bbb710551"
Commit 1082ba35 authored by Sergey Edunov, committed by Myle Ott

Switch to DistributedDataParallelC10d and bump version 0.5.0 -> 0.6.0

- no more FP16Trainer; we just have an FP16Optimizer wrapper now
- most of the distributed code is moved to a new wrapper class called DistributedFairseqModel, which behaves like DistributedDataParallel and a FairseqModel at the same time
- Trainer now requires an extra dummy_batch argument at initialization, which we run forward/backward on when there's an uneven number of batches per worker; we hide the gradients from these dummy batches by multiplying the loss by 0 (see the sketch below)
- Trainer.train_step now takes a list of samples, which will allow cleaner handling of --update-freq
parent 311d2c6c
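For reference, a minimal standalone sketch of the dummy-batch trick in plain PyTorch (illustrative shapes and names, not the fairseq Trainer API): a worker that has run out of real batches still runs forward/backward so collective gradient reductions stay in sync, but multiplies the loss by 0 so the dummy batch contributes no gradient.

import torch

model = torch.nn.Linear(8, 2)
criterion = torch.nn.CrossEntropyLoss()

def train_step(batch, is_dummy_batch):
    x, y = batch
    loss = criterion(model(x), y)
    if is_dummy_batch:
        # keep the graph (and any collectives) alive, but zero out the gradient
        loss = loss * 0
    loss.backward()
    return loss

dummy_batch = (torch.randn(4, 8), torch.zeros(4, dtype=torch.long))
train_step(dummy_batch, is_dummy_batch=True)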
@@ -30,7 +30,7 @@ def main(args):
                 raise e
             except FileNotFoundError as e:  # Slurm is not installed
                 pass
-    if args.distributed_init_method is None:
+    if args.distributed_init_method is None and args.distributed_port is None:
        raise ValueError('--distributed-init-method or --distributed-port '
                         'must be specified for distributed training')
...
@@ -60,9 +60,9 @@ github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
 # built documents.
 #
 # The short X.Y version.
-version = '0.5.0'
+version = '0.6.0'
 # The full version, including alpha/beta/rc tags.
-release = '0.5.0'
+release = '0.6.0'
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
...
@@ -36,5 +36,7 @@ Iterators
     :members:
 .. autoclass:: fairseq.data.EpochBatchIterator
     :members:
+.. autoclass:: fairseq.data.GroupedIterator
+    :members:
 .. autoclass:: fairseq.data.ShardedIterator
     :members:
@@ -54,6 +54,7 @@ class AdaptiveLoss(FairseqCriterion):
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
+            'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -63,9 +64,12 @@ class AdaptiveLoss(FairseqCriterion):
         """Aggregate logging outputs from data parallel training."""
         loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         agg_output = {
             'loss': loss_sum / sample_size / math.log(2),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
             'sample_size': sample_size,
         }
         if sample_size != ntokens:
...
@@ -37,6 +37,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
+            'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -46,9 +47,12 @@ class CrossEntropyCriterion(FairseqCriterion):
         """Aggregate logging outputs from data parallel training."""
         loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         agg_output = {
             'loss': loss_sum / sample_size / math.log(2),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
             'sample_size': sample_size,
         }
         if sample_size != ntokens:
...
@@ -40,6 +40,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
             'loss': utils.item(loss.data) if reduce else loss.data,
             'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
             'ntokens': sample['ntokens'],
+            'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -58,14 +59,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
         return loss, nll_loss

     @staticmethod
     def aggregate_logging_outputs(logging_outputs):
         """Aggregate logging outputs from data parallel training."""
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
             'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
             'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
             'sample_size': sample_size,
         }
@@ -12,18 +12,24 @@ from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .token_block_dataset import TokenBlockDataset

-from .iterators import CountingIterator, EpochBatchIterator, ShardedIterator
+from .iterators import (
+    CountingIterator,
+    EpochBatchIterator,
+    GroupedIterator,
+    ShardedIterator,
+)

 __all__ = [
     'CountingIterator',
     'Dictionary',
     'EpochBatchIterator',
     'FairseqDataset',
+    'GroupedIterator',
     'IndexedDataset',
     'IndexedInMemoryDataset',
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'MonolingualDataset',
-    'TokenBlockDataset',
     'ShardedIterator',
+    'TokenBlockDataset',
 ]
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.

 import itertools
+import math

 import numpy as np
 import torch
@@ -150,6 +151,36 @@ class EpochBatchIterator(object):
         ))


+class GroupedIterator(object):
+    """Wrapper around an iterable that returns groups (chunks) of items.
+
+    Args:
+        iterable (iterable): iterable to wrap
+        chunk_size (int): size of each chunk
+    """
+
+    def __init__(self, iterable, chunk_size):
+        self._len = int(math.ceil(len(iterable) / float(chunk_size)))
+        self.itr = iter(iterable)
+        self.chunk_size = chunk_size
+
+    def __len__(self):
+        return self._len
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        chunk = []
+        try:
+            for _ in range(self.chunk_size):
+                chunk.append(next(self.itr))
+        except StopIteration as e:
+            if len(chunk) == 0:
+                raise e
+        return chunk
+
+
 class ShardedIterator(object):
     """A sharded wrapper around an iterable, padded to length.
...
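A short usage sketch of the new GroupedIterator, with a plain Python list standing in for the epoch's batch iterator; the chunk size corresponds to --update-freq:

from fairseq.data import iterators

batches = list(range(7))                         # stands in for the real batch iterator
grouped = iterators.GroupedIterator(batches, 3)  # chunk_size = --update-freq
print(len(grouped))   # 3 groups
print(list(grouped))  # [[0, 1, 2], [3, 4, 5], [6]]

Each element yielded by the grouped iterator is the list of samples that Trainer.train_step now consumes in a single parameter update.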
@@ -7,7 +7,9 @@

 import pickle

-import torch.distributed
+import torch
+from torch import distributed
+from torch.distributed import group

 from fairseq import utils

@@ -16,22 +18,39 @@ def is_master(args):
     return args.distributed_rank == 0


+_use_c10d = [None]
+
+
 def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')

+    if _use_c10d[0] is None:
+        _use_c10d[0] = not args.no_c10d
+
+    if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'):
+        _use_c10d[0] = False
+        print('WARNING: cannot find DistributedDataParallelC10d, '
+              'falling back to standard DistributedDataParallel')
+
     print('| distributed init (rank {}): {}'.format(
         args.distributed_rank, args.distributed_init_method), flush=True)
-    if args.distributed_init_method.startswith('tcp://'):
-        torch.distributed.init_process_group(
-            backend=args.distributed_backend, init_method=args.distributed_init_method,
-            world_size=args.distributed_world_size, rank=args.distributed_rank)
+
+    if _use_c10d[0]:
+        distributed.c10d.init_process_group(
+            backend=args.distributed_backend,
+            init_method=args.distributed_init_method,
+            world_size=args.distributed_world_size,
+            rank=args.distributed_rank,
+        )
     else:
-        torch.distributed.init_process_group(
-            backend=args.distributed_backend, init_method=args.distributed_init_method,
-            world_size=args.distributed_world_size)
-        args.distributed_rank = torch.distributed.get_rank()
+        distributed.init_process_group(
+            backend=args.distributed_backend,
+            init_method=args.distributed_init_method,
+            world_size=args.distributed_world_size,
+            rank=args.distributed_rank,
+        )

     if not is_master(args):
         suppress_output()
@@ -52,35 +71,77 @@ def suppress_output():
     __builtin__.print = print


-def all_gather_list(data, max_size=16384):
-    """Gathers arbitrary data from all nodes into a list."""
-    world_size = torch.distributed.get_world_size()
-    if not hasattr(all_gather_list, '_in_buffer') or \
-            max_size != all_gather_list._in_buffer.size():
-        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
-        all_gather_list._out_buffers = [
-            torch.cuda.ByteTensor(max_size)
-            for i in range(world_size)
-        ]
-    in_buffer = all_gather_list._in_buffer
-    out_buffers = all_gather_list._out_buffers
+def get_rank():
+    if _use_c10d[0]:
+        return distributed.c10d.get_rank()
+    else:
+        return distributed.get_rank()
+
+
+def get_world_size():
+    if _use_c10d[0]:
+        return distributed.c10d.get_world_size()
+    else:
+        return distributed.get_world_size()
+
+
+def get_default_group():
+    if _use_c10d[0]:
+        return distributed.c10d.group.WORLD
+    else:
+        return distributed.group.WORLD
+
+
+def all_reduce(tensor, group=None):
+    if group is None:
+        group = get_default_group()
+    if _use_c10d[0]:
+        return distributed.c10d.all_reduce(tensor, group=group)
+    else:
+        return distributed.all_reduce(tensor, group=group)
+
+
+def all_gather_list(data, group=None, max_size=16384):
+    """Gathers arbitrary data from all nodes into a list.
+
+    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
+    data. Note that *data* must be picklable.
+
+    Args:
+        data (Any): data from the local worker to be gathered on other workers
+        group (optional): group of the collective
+        max_size (int, optional): maximum size of the data to be gathered
+            across workers
+    """
+    rank = get_rank()
+    world_size = get_world_size()
+
+    buffer_size = max_size * world_size
+    if not hasattr(all_gather_list, '_buffer') or \
+            all_gather_list._buffer.numel() < buffer_size:
+        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
+    buffer = all_gather_list._buffer
+    buffer.zero_()

     enc = pickle.dumps(data)
     enc_size = len(enc)
     if enc_size + 2 > max_size:
         raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
     assert max_size < 255*256
-    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
-    in_buffer[1] = enc_size % 255
-    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
-    torch.distributed.all_gather(out_buffers, in_buffer.cuda())
+
+    buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
+    buffer_rank[0] = enc_size // 255  # this encoding works for max_size < 65k
+    buffer_rank[1] = enc_size % 255
+    buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))
+
+    all_reduce(buffer, group=group)

     result = []
     for i in range(world_size):
-        out_buffer = out_buffers[i]
+        out_buffer = buffer[i * max_size : (i + 1) * max_size]
         size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
-        result.append(
-            pickle.loads(bytes(out_buffer[2:size+2].tolist()))
-        )
+        if size > 0:
+            result.append(
+                pickle.loads(bytes(out_buffer[2:size+2].tolist()))
+            )
     return result
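A standalone sketch of the per-rank encoding that all_gather_list writes into its slot of the shared buffer (CPU tensors here for illustration; the real buffer is a CUDA tensor combined across ranks with all_reduce). Each rank stores a 2-byte length header followed by the pickled payload:

import pickle
import torch

def encode(data, max_size=16384):
    # 2-byte header (works for max_size < 65k) followed by the pickled bytes
    buf = torch.ByteTensor(max_size).zero_()
    enc = pickle.dumps(data)
    assert len(enc) + 2 <= max_size and max_size < 255 * 256
    buf[0] = len(enc) // 255
    buf[1] = len(enc) % 255
    buf[2:len(enc) + 2] = torch.ByteTensor(list(enc))
    return buf

def decode(buf):
    size = 255 * int(buf[0]) + int(buf[1])
    return pickle.loads(bytes(buf[2:size + 2].tolist()))

print(decode(encode({'ntokens': 3072, 'nsentences': 96})))

Because ranks write to disjoint slots of a zeroed buffer, a sum-based all_reduce is enough to gather everyone's payload, and empty slots (size 0) are skipped when decoding.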
@@ -15,6 +15,7 @@ from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
 from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401

 from .composite_encoder import CompositeEncoder  # noqa: F401
+from .distributed_fairseq_model import DistributedFairseqModel  # noqa: F401


 MODEL_REGISTRY = {}
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from torch.distributed import c10d
from torch.nn import parallel

from . import BaseFairseqModel


class DistributedFairseqModel(BaseFairseqModel):
    """
    A wrapper around a :class:`BaseFairseqModel` instance that adds support for
    distributed training.

    Anytime a method or attribute is called on this class we first try to
    forward it to the underlying DistributedDataParallel instance, otherwise we
    forward it to the original :class:`BaseFairseqModel` instance.

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
    """

    def __init__(self, args, model):
        super().__init__()
        assert isinstance(model, BaseFairseqModel)
        if args.no_c10d:
            self.ddp_model = parallel.DistributedDataParallel(
                module=model,
                device_ids=[args.device_id],
                output_device=args.device_id,
                broadcast_buffers=False,
            )
        else:
            self.ddp_model = parallel._DistributedDataParallelC10d(
                module=model,
                device_ids=[args.device_id],
                output_device=args.device_id,
                broadcast_buffers=False,
                bucket_cap_mb=args.c10d_bucket_cap_mb,
            )

    def __call__(self, *args, **kwargs):
        return self.ddp_model(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.ddp_model.forward(*args, **kwargs)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            pass
        try:
            return self.ddp_model.__getattr__(name)
        except AttributeError:
            pass
        return self.ddp_model.module.__getattr__(name)
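A toy sketch of the forwarding idea with plain classes and getattr (a simplified variant for illustration only; the real wrapper is an nn.Module that delegates to a DistributedDataParallel instance, which needs an initialized process group):

class ToyModel:
    def max_positions(self):
        return 1024

class ToyDDP:
    def __init__(self, module):
        self.module = module

class ToyDistributedModel:
    def __init__(self, model):
        self.ddp_model = ToyDDP(model)

    def __getattr__(self, name):
        # only reached when normal lookup fails: try the DDP wrapper first,
        # then the wrapped model, so callers can treat the wrapper like the model
        try:
            return getattr(self.ddp_model, name)
        except AttributeError:
            return getattr(self.ddp_model.module, name)

m = ToyDistributedModel(ToyModel())
print(m.max_positions())  # 1024, resolved on the wrapped ToyModel

The point of the design is that the Trainer can call model-specific methods (e.g. max_positions) on the distributed wrapper without unwrapping it.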
@@ -9,6 +9,7 @@ import importlib
 import os

 from .fairseq_optimizer import FairseqOptimizer
+from .fp16_optimizer import FP16Optimizer

 OPTIMIZER_REGISTRY = {}
@@ -16,7 +17,7 @@ OPTIMIZER_CLASS_NAMES = set()


 def build_optimizer(args, params):
-    params = filter(lambda p: p.requires_grad, params)
+    params = list(filter(lambda p: p.requires_grad, params))
     return OPTIMIZER_REGISTRY[args.optimizer](args, params)
...
@@ -5,7 +5,9 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import torch.optim
+import math
+
+import torch


 class FairseqOptimizer(object):
@@ -13,7 +15,7 @@ class FairseqOptimizer(object):
     def __init__(self, args, params):
         super().__init__()
         self.args = args
-        self.params = params
+        self.params = list(params)

     @staticmethod
     def add_args(parser):
@@ -67,10 +69,25 @@ class FairseqOptimizer(object):
         for group in self.optimizer.param_groups:
             group.update(optimizer_overrides)

+    def backward(self, loss):
+        loss.backward()
+
+    def multiply_grads(self, c):
+        """Multiplies grads by a constant ``c``."""
+        for p in self.params:
+            p.grad.data.mul_(c)
+
+    def clip_grad_norm(self, max_norm):
+        """Clips gradient norm."""
+        if max_norm > 0:
+            return torch.nn.utils.clip_grad_norm_(self.params, max_norm)
+        else:
+            return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params))
+
     def step(self, closure=None):
         """Performs a single optimization step."""
-        return self.optimizer.step(closure)
+        self.optimizer.step(closure)

     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
-        return self.optimizer.zero_grad()
+        self.optimizer.zero_grad()
@@ -5,16 +5,9 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-"""
-Train a network on multiple GPUs.
-"""
-
 import torch

 from fairseq import optim, utils
-from fairseq.meters import AverageMeter
-from fairseq.optim import lr_scheduler
-from fairseq.trainer import Trainer


 class DynamicLossScaler:
@@ -42,89 +35,97 @@ class DynamicLossScaler:
         return False


-class FP16Trainer(Trainer):
-    """Modified trainer for FP16.
-
-    We maintain two copies of the model's parameters, both in FP16 and FP32.
-    We do forward/backward with FP16 and compute the loss + optimize with FP32.
-    """
-
-    def __init__(self, args, task, model, criterion):
-        super().__init__(args, task, model, criterion)
-
-        # convert model to FP16 (but keep criterion FP32)
-        self.model.half()
-
-        # dynamically scale loss to reduce overflow
-        self.scaler = DynamicLossScaler(init_scale=2.**7)
-        self.meters['loss_scale'] = AverageMeter()
+class FP16Optimizer(optim.FairseqOptimizer):
+
+    def __init__(self, args, params, fp32_optimizer, fp32_params):
+        super().__init__(args, params)
+        self.fp32_optimizer = fp32_optimizer
+        self.fp32_params = fp32_params
+        self.scaler = DynamicLossScaler(
+            init_scale=2.**7,
+            scale_window=(2**14 / args.distributed_world_size),
+        )

-    def _build_optimizer(self):
+    @staticmethod
+    def build_optimizer(args, params):
         # create FP32 copy of parameters and grads
-        params = [p for p in self.model.parameters() if p.requires_grad]
         total_param_size = sum(p.data.numel() for p in params)
-        self.fp32_params = params[0].new(0).float().new(total_param_size)
+        fp32_params = params[0].new(0).float().new(total_param_size)
         offset = 0
         for p in params:
             numel = p.data.numel()
-            self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
+            fp32_params[offset:offset+numel].copy_(p.data.view(-1))
             offset += numel
-        self.fp32_params = torch.nn.Parameter(self.fp32_params)
-        self.fp32_params.grad = self.fp32_params.data.new(total_param_size)
-
-        # create optimizer using the copied FP32 params
-        self._optimizer = optim.build_optimizer(self.args, [self.fp32_params])
-        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
-
-    def save_checkpoint(self, filename, extra_state):
-        """Save all training state in a checkpoint file."""
-        extra_state['loss_scale'] = self.scaler.loss_scale
-        super().save_checkpoint(filename, extra_state)
-
-    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
-        """Load all training state from a checkpoint file."""
-        extra_state = super().load_checkpoint(filename, reset_optimizer, reset_lr_scheduler, optimizer_overrides)
-        if extra_state is not None and 'loss_scale' in extra_state:
-            self.scaler.loss_scale = extra_state['loss_scale']
-        return extra_state
-
-    def zero_grad(self):
-        # zero both the FP16 and FP32 grads
-        self.model.zero_grad()      # FP16
-        self.optimizer.zero_grad()  # FP32
-
-    def _backward(self, loss):
-        self.meters['loss_scale'].reset()
-        self.meters['loss_scale'].update(self.scaler.loss_scale)
-        if loss is not None:
-            # dynamically rescale loss to stay in FP16 range
-            loss = loss * self.scaler.loss_scale
-        return super()._backward(loss)
-
-    def _all_reduce_and_rescale(self, grad_denom):
-        # undo effect of dynamic loss scaling on gradients
-        grad_denom *= self.scaler.loss_scale
-
-        if self.args.distributed_world_size > 1:
-            # flatten grads into a single buffer
-            flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
-
-            # scale gradients to avoid overflow in all-reduce
-            flat_grads.div_(self.args.distributed_world_size)
-            grad_denom /= self.args.distributed_world_size
-
-            # all-reduce flat grads
-            torch.distributed.all_reduce(flat_grads)
-
-            # copy grads back to FP32
-            self.fp32_params.grad.data.copy_(flat_grads)
-        else:
-            # single worker: copy grads directly to FP32
-            self._get_flat_grads(out=self.fp32_params.grad.data)
-
-        # rescale and clip grads
-        self.fp32_params.grad.data.div_(grad_denom)
-        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
+        fp32_params = torch.nn.Parameter(fp32_params)
+        fp32_params.grad = fp32_params.data.new(total_param_size)
+        fp32_optimizer = optim.build_optimizer(args, [fp32_params])
+        return FP16Optimizer(args, params, fp32_optimizer, fp32_params)
+
+    @property
+    def optimizer(self):
+        return self.fp32_optimizer.optimizer
+
+    @property
+    def optimizer_config(self):
+        return self.fp32_optimizer.optimizer_config
+
+    def get_lr(self):
+        return self.fp32_optimizer.get_lr()
+
+    def set_lr(self, lr):
+        self.fp32_optimizer.set_lr(lr)
+
+    def state_dict(self):
+        """Return the optimizer's state dict."""
+        state_dict = self.fp32_optimizer.state_dict()
+        state_dict['loss_scale'] = self.scaler.loss_scale
+        return state_dict
+
+    def load_state_dict(self, state_dict, optimizer_overrides=None):
+        """Load an optimizer state dict.
+
+        In general we should prefer the configuration of the existing optimizer
+        instance (e.g., learning rate) over that found in the state_dict. This
+        allows us to resume training from a checkpoint using a new set of
+        optimizer args.
+        """
+        if 'loss_scale' in state_dict:
+            self.scaler.loss_scale = state_dict['loss_scale']
+        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
+
+    def backward(self, loss):
+        loss = loss * self.scaler.loss_scale
+        loss.backward()
+        self._needs_sync = True
+
+    def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
+        if self._needs_sync:
+            # copy FP16 grads to FP32
+            offset = 0
+            for p in self.params:
+                if not p.requires_grad:
+                    continue
+                numel = p.grad.data.numel()
+                self.fp32_params.grad.data[offset:offset+numel].copy_(p.grad.data.view(-1))
+                offset += numel
+
+            # correct for dynamic loss scaler
+            self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
+
+            self._needs_sync = False
+
+    def multiply_grads(self, c):
+        """Multiplies grads by a constant ``c``."""
+        if self._needs_sync:
+            self._sync_fp16_grads_to_fp32(c)
+        else:
+            self.fp32_params.grad.data.mul_(c)
+
+    def clip_grad_norm(self, max_norm):
+        """Clips gradient norm and updates dynamic loss scaler."""
+        self._sync_fp16_grads_to_fp32()
+        grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)

         # detect overflow and adjust loss scale
         overflow = DynamicLossScaler.has_overflow(grad_norm)
@@ -137,18 +138,27 @@ class FP16Trainer(Trainer):
                 'increasing the batch size.'
             ).format(self.args.min_loss_scale))
             raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
         return grad_norm

-    def _opt(self):
-        # take an optimization step using the FP32 params and grads
-        super()._opt()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+        self._sync_fp16_grads_to_fp32()
+        self.fp32_optimizer.step(closure)

         # copy FP32 params back into FP16 model
         offset = 0
-        for p in self.model.parameters():
+        for p in self.params:
             if not p.requires_grad:
                 continue
             numel = p.data.numel()
             p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
             offset += numel
+
+    def zero_grad(self):
+        """Clears the gradients of all optimized parameters."""
+        self.fp32_optimizer.zero_grad()
+        for p in self.params:
+            if p.grad is not None:
+                p.grad.detach_()
+                p.grad.zero_()
+        self._needs_sync = False
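A standalone sketch of the FP16 master-weights pattern that FP16Optimizer implements, in plain PyTorch (requires a GPU; the names and single SGD step are illustrative, not the fairseq API): FP16 params for forward/backward, a flat FP32 copy for the update, and a scaled loss.

import torch

loss_scale = 2.0 ** 7

model = torch.nn.Linear(8, 8).half().cuda()
params = [p for p in model.parameters() if p.requires_grad]

# flat FP32 master copy of the params, with its own grad buffer and optimizer
fp32_params = torch.nn.Parameter(torch.cat([p.data.float().view(-1) for p in params]))
fp32_params.grad = fp32_params.data.new_zeros(fp32_params.numel())
fp32_optimizer = torch.optim.SGD([fp32_params], lr=0.1)

x = torch.randn(4, 8).half().cuda()
loss = model(x).float().sum() * loss_scale   # scale the loss to keep FP16 grads in range
loss.backward()

# copy FP16 grads into the FP32 buffer and undo the loss scale
offset = 0
for p in params:
    numel = p.grad.numel()
    fp32_params.grad[offset:offset + numel].copy_(p.grad.view(-1))
    offset += numel
fp32_params.grad.mul_(1.0 / loss_scale)

fp32_optimizer.step()

# copy the updated FP32 params back into the FP16 model
offset = 0
for p in params:
    numel = p.numel()
    p.data.copy_(fp32_params.data[offset:offset + numel].view_as(p))
    offset += numel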
@@ -183,6 +183,10 @@ def add_distributed_training_args(parser):
                        help='port number (not required if using --distributed-init-method)')
     group.add_argument('--device-id', default=0, type=int,
                        help='which GPU to use (usually configured automatically)')
+    group.add_argument('--no-c10d', action='store_true',
+                       help='don\'t use c10d distributed backend')
+    group.add_argument('--c10d-bucket-cap-mb', default=150, metavar='MB',
+                       help='bucket size for c10d backend')
     return group
...
@@ -19,8 +19,10 @@ from train import main as single_process_main
 def main(args):
     # Set distributed training parameters for a single node.
     args.distributed_world_size = torch.cuda.device_count()
-    args.distributed_init_method = 'tcp://localhost:{port}'.format(
-        port=random.randint(10000, 20000))
+    port = random.randint(10000, 20000)
+    args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
+    args.distributed_init_host = 'localhost'
+    args.distributed_port = port + 1

     mp = torch.multiprocessing.get_context('spawn')
...
@@ -35,7 +35,7 @@ bleu = Extension(
 setup(
     name='fairseq',
-    version='0.5.0',
+    version='0.6.0',
     description='Facebook AI Research Sequence-to-Sequence Toolkit',
     long_description=readme,
     license=license,
...
@@ -16,7 +16,7 @@ import math
 import torch

 from fairseq import distributed_utils, options, progress_bar, tasks, utils
-from fairseq.fp16_trainer import FP16Trainer
+from fairseq.data import iterators
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
@@ -43,16 +43,17 @@ def main(args):
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
     print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))

+    # Make a dummy batch to (i) warm the caching allocator and (ii) as a
+    # placeholder DistributedDataParallel when there's an uneven number of
+    # batches per worker.
+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        model.max_positions(),
+    )
+    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
+
     # Build trainer
-    if args.fp16:
-        if torch.cuda.get_device_capability(0)[0] < 7:
-            print('| WARNING: your device does NOT support faster training with --fp16,'
-                  ' please switch to FP32 which is likely to be faster')
-        trainer = FP16Trainer(args, task, model, criterion)
-    else:
-        if torch.cuda.get_device_capability(0)[0] >= 7:
-            print('| NOTICE: your device may support faster training with --fp16')
-        trainer = Trainer(args, task, model, criterion)
+    trainer = Trainer(args, task, model, criterion, dummy_batch)
     print('| training on {} GPUs'.format(args.distributed_world_size))
     print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
         args.max_tokens,
@@ -60,10 +61,6 @@ def main(args):
     ))

     # Initialize dataloader
-    max_positions = utils.resolve_max_positions(
-        task.max_positions(),
-        trainer.get_model().max_positions(),
-    )
     epoch_itr = task.get_batch_iterator(
         dataset=task.dataset(args.train_subset),
         max_tokens=args.max_tokens,
@@ -78,9 +75,7 @@ def main(args):
     # Load the latest checkpoint if one is available
     if not load_checkpoint(args, trainer, epoch_itr):
-        # Send a dummy batch to warm the caching allocator
-        dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
-        trainer.dummy_train_step(dummy_batch)
+        trainer.dummy_train_step([dummy_batch])

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
@@ -110,32 +105,32 @@ def main(args):
 def train(args, trainer, task, epoch_itr):
     """Train the model for one epoch."""

-    # Initialize data iterator
-    itr = epoch_itr.next_epoch_itr()
-    progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
-
-    # update parameters every N batches
+    # Update parameters every N batches
     if epoch_itr.epoch <= len(args.update_freq):
         update_freq = args.update_freq[epoch_itr.epoch - 1]
     else:
         update_freq = args.update_freq[-1]

+    # Initialize data iterator
+    itr = epoch_itr.next_epoch_itr()
+    itr = iterators.GroupedIterator(itr, update_freq)
+    progress = progress_bar.build_progress_bar(
+        args, itr, epoch_itr.epoch, no_progress_bar='simple',
+    )
+
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     first_valid = args.valid_subset.split(',')[0]
     max_update = args.max_update or math.inf
     num_batches = len(epoch_itr)
-    for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
-        if i < num_batches - 1 and (i + 1) % update_freq > 0:
-            # buffer updates according to --update-freq
-            trainer.train_step(sample, update_params=False)
+    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
+        log_output = trainer.train_step(samples)
+        if log_output is None:
             continue
-        else:
-            log_output = trainer.train_step(sample, update_params=True)

         # log mid-epoch stats
         stats = get_training_stats(trainer)
         for k, v in log_output.items():
-            if k in ['loss', 'nll_loss', 'sample_size']:
+            if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                 continue  # these are already logged above
             if 'loss' in k:
                 extra_meters[k].update(v, log_output['sample_size'])
@@ -163,7 +158,9 @@ def train(args, trainer, task, epoch_itr):
     progress.print(stats)

     # reset training meters
-    for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip', 'gnorm']:
+    for k in [
+        'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
+    ]:
         meter = trainer.get_meter(k)
         if meter is not None:
             meter.reset()
@@ -230,7 +227,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
             log_output = trainer.valid_step(sample)

             for k, v in log_output.items():
-                if k in ['loss', 'nll_loss', 'sample_size']:
+                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                     continue
                 extra_meters[k].update(v)
...