Commit 1082ba35 authored by Sergey Edunov, committed by Myle Ott

Switch to DistributedDataParallelC10d and bump version 0.5.0 -> 0.6.0

- no more FP16Trainer, we just have an FP16Optimizer wrapper
- most of the distributed code is moved to a new wrapper class called DistributedFairseqModel, which behaves like DistributedDataParallel and a FairseqModel at the same time
- Trainer now requires an extra dummy_batch argument at initialization; we run forward/backward on it whenever workers have an uneven number of batches, and hide its gradients by multiplying the loss by 0 (a minimal sketch of this trick follows the commit metadata below)
- Trainer.train_step now takes a list of samples, which enables cleaner handling of --update-freq (see the GroupedIterator usage sketch later in this diff)
parent 311d2c6c
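
For illustration, a minimal sketch (not the actual Trainer code) of the dummy-batch trick and the new list-of-samples train_step described above; model, criterion, the batch structure, and the surrounding DistributedDataParallel setup are all assumed:

    def train_step(model, criterion, samples, dummy_batch):
        # Workers that have run out of real batches still do a forward/backward
        # pass (on the dummy batch) so the gradient all-reduce has the same
        # number of participants on every worker.
        is_dummy = len(samples) == 0
        if is_dummy:
            samples = [dummy_batch]
        logging_outputs = []
        for sample in samples:
            loss, sample_size, logging_output = criterion(model, sample)
            if is_dummy:
                loss = loss * 0.  # keep the autograd graph, hide the gradients
            loss.backward()       # DDP all-reduces gradients here
            logging_outputs.append(logging_output)
        return None if is_dummy else logging_outputs
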
......@@ -30,7 +30,7 @@ def main(args):
raise e
except FileNotFoundError as e: # Slurm is not installed
pass
if args.distributed_init_method is None:
if args.distributed_init_method is None and args.distributed_port is None:
raise ValueError('--distributed-init-method or --distributed-port '
'must be specified for distributed training')
......
......@@ -60,9 +60,9 @@ github_doc_root = 'https://github.com/pytorch/fairseq/tree/master/docs/'
# built documents.
#
# The short X.Y version.
version = '0.5.0'
version = '0.6.0'
# The full version, including alpha/beta/rc tags.
release = '0.5.0'
release = '0.6.0'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
......
......@@ -36,5 +36,7 @@ Iterators
:members:
.. autoclass:: fairseq.data.EpochBatchIterator
:members:
.. autoclass:: fairseq.data.GroupedIterator
:members:
.. autoclass:: fairseq.data.ShardedIterator
:members:
......@@ -54,6 +54,7 @@ class AdaptiveLoss(FairseqCriterion):
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
......@@ -63,9 +64,12 @@ class AdaptiveLoss(FairseqCriterion):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if sample_size != ntokens:
......
......@@ -37,6 +37,7 @@ class CrossEntropyCriterion(FairseqCriterion):
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
......@@ -46,9 +47,12 @@ class CrossEntropyCriterion(FairseqCriterion):
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
if sample_size != ntokens:
......
......@@ -40,6 +40,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
}
return loss, sample_size, logging_output
......@@ -58,14 +59,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
return loss, nll_loss
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
......@@ -12,18 +12,24 @@ from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
from .iterators import CountingIterator, EpochBatchIterator, ShardedIterator
from .iterators import (
CountingIterator,
EpochBatchIterator,
GroupedIterator,
ShardedIterator,
)
__all__ = [
'CountingIterator',
'Dictionary',
'EpochBatchIterator',
'FairseqDataset',
'GroupedIterator',
'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'MonolingualDataset',
'TokenBlockDataset',
'ShardedIterator',
'TokenBlockDataset',
]
......@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import itertools
import math
import numpy as np
import torch
......@@ -150,6 +151,36 @@ class EpochBatchIterator(object):
))
class GroupedIterator(object):
"""Wrapper around an iterable that returns groups (chunks) of items.
Args:
iterable (iterable): iterable to wrap
chunk_size (int): size of each chunk
"""
def __init__(self, iterable, chunk_size):
self._len = int(math.ceil(len(iterable) / float(chunk_size)))
self.itr = iter(iterable)
self.chunk_size = chunk_size
def __len__(self):
return self._len
def __iter__(self):
return self
def __next__(self):
chunk = []
try:
for _ in range(self.chunk_size):
chunk.append(next(self.itr))
except StopIteration as e:
if len(chunk) == 0:
raise e
return chunk
class ShardedIterator(object):
"""A sharded wrapper around an iterable, padded to length.
......
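
As a quick reference for the new GroupedIterator above, a small usage sketch:

    from fairseq.data import GroupedIterator

    grouped = GroupedIterator(range(10), chunk_size=4)
    print(len(grouped))   # 3, i.e. ceil(10 / 4)
    print(list(grouped))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

In train.py (further down in this diff) the chunk size is the per-epoch --update-freq, so each element yielded by the iterator is the list of samples passed to Trainer.train_step.
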
......@@ -7,7 +7,9 @@
import pickle
import torch.distributed
import torch
from torch import distributed
from torch.distributed import group
from fairseq import utils
......@@ -16,22 +18,39 @@ def is_master(args):
return args.distributed_rank == 0
_use_c10d = [None]
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
if _use_c10d[0] is None:
_use_c10d[0] = not args.no_c10d
if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'):
_use_c10d[0] = False
print('WARNING: cannot find DistributedDataParallelC10d, '
'falling back to standard DistributedDataParallel')
print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True)
if args.distributed_init_method.startswith('tcp://'):
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size, rank=args.distributed_rank)
if _use_c10d[0]:
distributed.c10d.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
else:
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size)
distributed.init_process_group(
backend=args.distributed_backend,
init_method=args.distributed_init_method,
world_size=args.distributed_world_size,
rank=args.distributed_rank,
)
args.distributed_rank = torch.distributed.get_rank()
if not is_master(args):
suppress_output()
......@@ -52,35 +71,77 @@ def suppress_output():
__builtin__.print = print
def all_gather_list(data, max_size=16384):
"""Gathers arbitrary data from all nodes into a list."""
world_size = torch.distributed.get_world_size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != all_gather_list._in_buffer.size():
all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size)
for i in range(world_size)
]
in_buffer = all_gather_list._in_buffer
out_buffers = all_gather_list._out_buffers
def get_rank():
if _use_c10d[0]:
return distributed.c10d.get_rank()
else:
return distributed.get_rank()
def get_world_size():
if _use_c10d[0]:
return distributed.c10d.get_world_size()
else:
return distributed.get_world_size()
def get_default_group():
if _use_c10d[0]:
return distributed.c10d.group.WORLD
else:
return distributed.group.WORLD
def all_reduce(tensor, group=None):
if group is None:
group = get_default_group()
if _use_c10d[0]:
return distributed.c10d.all_reduce(tensor, group=group)
else:
return distributed.all_reduce(tensor, group=group)
def all_gather_list(data, group=None, max_size=16384):
"""Gathers arbitrary data from all nodes into a list.
Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
data. Note that *data* must be picklable.
Args:
data (Any): data from the local worker to be gathered on other workers
group (optional): group of the collective
max_size (int, optional): maximum size of the data to be gathered
across workers
"""
rank = get_rank()
world_size = get_world_size()
buffer_size = max_size * world_size
if not hasattr(all_gather_list, '_buffer') or \
all_gather_list._buffer.numel() < buffer_size:
all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
buffer = all_gather_list._buffer
buffer.zero_()
enc = pickle.dumps(data)
enc_size = len(enc)
if enc_size + 2 > max_size:
raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
assert max_size < 255*256
in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k
in_buffer[1] = enc_size % 255
in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))
torch.distributed.all_gather(out_buffers, in_buffer.cuda())
buffer_rank = buffer[rank * max_size : (rank + 1) * max_size]
buffer_rank[0] = enc_size // 255 # this encoding works for max_size < 65k
buffer_rank[1] = enc_size % 255
buffer_rank[2:enc_size+2] = torch.ByteTensor(list(enc))
all_reduce(buffer, group=group)
result = []
for i in range(world_size):
out_buffer = out_buffers[i]
out_buffer = buffer[i * max_size : (i + 1) * max_size]
size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
if size > 0:
result.append(
pickle.loads(bytes(out_buffer[2:size+2].tolist()))
)
return result
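
A hedged usage sketch for the reworked all_gather_list; it assumes the process group has already been initialized (e.g. via distributed_init(args)) and that each worker has a CUDA device, since the gather buffer lives in GPU memory:

    from fairseq import distributed_utils

    rank = distributed_utils.get_rank()
    stats = {'rank': rank, 'ntokens': 1024}               # any picklable object
    gathered = distributed_utils.all_gather_list(stats)   # one entry per worker, in rank order
    assert gathered[rank]['rank'] == rank
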
......@@ -15,6 +15,7 @@ from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel # noqa: F401
from .composite_encoder import CompositeEncoder # noqa: F401
from .distributed_fairseq_model import DistributedFairseqModel # noqa: F401
MODEL_REGISTRY = {}
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from torch.distributed import c10d
from torch.nn import parallel
from . import BaseFairseqModel
class DistributedFairseqModel(BaseFairseqModel):
"""
A wrapper around a :class:`BaseFairseqModel` instance that adds support for
distributed training.
Anytime a method or attribute is called on this class we first try to
forward it to the underlying DistributedDataParallel instance, otherwise we
forward it to the original :class:`BaseFairseqModel` instance.
Args:
args (argparse.Namespace): fairseq args
model (BaseFairseqModel): model to wrap
"""
def __init__(self, args, model):
super().__init__()
assert isinstance(model, BaseFairseqModel)
if args.no_c10d:
self.ddp_model = parallel.DistributedDataParallel(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=False,
)
else:
self.ddp_model = parallel._DistributedDataParallelC10d(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
broadcast_buffers=False,
bucket_cap_mb=args.c10d_bucket_cap_mb,
)
def __call__(self, *args, **kwargs):
return self.ddp_model(*args, **kwargs)
def forward(self, *args, **kwargs):
return self.ddp_model.forward(*args, **kwargs)
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
pass
try:
return self.ddp_model.__getattr__(name)
except AttributeError:
pass
return self.ddp_model.module.__getattr__(name)
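
A hedged sketch of how the wrapper is used; args is assumed to carry device_id, no_c10d, and c10d_bucket_cap_mb (the latter two are new options added further down), model is a FairseqModel already on that device, and sample is an ordinary training batch:

    from fairseq.models import DistributedFairseqModel

    model = DistributedFairseqModel(args, model)
    print(model.max_positions())               # attribute access falls through to the wrapped model
    net_output = model(**sample['net_input'])  # forward runs through DistributedDataParallel(C10d)
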
......@@ -9,6 +9,7 @@ import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
from .fp16_optimizer import FP16Optimizer
OPTIMIZER_REGISTRY = {}
......@@ -16,7 +17,7 @@ OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params):
params = filter(lambda p: p.requires_grad, params)
params = list(filter(lambda p: p.requires_grad, params))
return OPTIMIZER_REGISTRY[args.optimizer](args, params)
......
......@@ -5,7 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim
import math
import torch
class FairseqOptimizer(object):
......@@ -13,7 +15,7 @@ class FairseqOptimizer(object):
def __init__(self, args, params):
super().__init__()
self.args = args
self.params = params
self.params = list(params)
@staticmethod
def add_args(parser):
......@@ -67,10 +69,25 @@ class FairseqOptimizer(object):
for group in self.optimizer.param_groups:
group.update(optimizer_overrides)
def backward(self, loss):
loss.backward()
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
for p in self.params:
p.grad.data.mul_(c)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm."""
if max_norm > 0:
return torch.nn.utils.clip_grad_norm_(self.params, max_norm)
else:
return math.sqrt(sum(p.grad.data.norm()**2 for p in self.params))
def step(self, closure=None):
"""Performs a single optimization step."""
return self.optimizer.step(closure)
self.optimizer.step(closure)
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
return self.optimizer.zero_grad()
self.optimizer.zero_grad()
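
The new backward / multiply_grads / clip_grad_norm hooks let a trainer accumulate gradients over several samples and then normalize once; a hedged sketch of that pattern (model, criterion, samples, and args are assumed from a typical training setup):

    from fairseq import optim

    optimizer = optim.build_optimizer(args, model.parameters())
    optimizer.zero_grad()
    sample_size = 0
    for sample in samples:                # e.g. one chunk yielded by GroupedIterator
        loss, size, _ = criterion(model, sample)
        optimizer.backward(loss)
        sample_size += size
    optimizer.multiply_grads(1.0 / float(sample_size))   # normalize accumulated grads
    grad_norm = optimizer.clip_grad_norm(args.clip_norm)
    optimizer.step()
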
......@@ -5,16 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
"""
import torch
from fairseq import optim, utils
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer
class DynamicLossScaler:
......@@ -42,89 +35,97 @@ class DynamicLossScaler:
return False
class FP16Trainer(Trainer):
"""Modified trainer for FP16.
We maintain two copies of the model's parameters, both in FP16 and FP32.
We do forward/backward with FP16 and compute the loss + optimize with FP32.
"""
def __init__(self, args, task, model, criterion):
super().__init__(args, task, model, criterion)
# convert model to FP16 (but keep criterion FP32)
self.model.half()
class FP16Optimizer(optim.FairseqOptimizer):
# dynamically scale loss to reduce overflow
self.scaler = DynamicLossScaler(init_scale=2.**7)
self.meters['loss_scale'] = AverageMeter()
def __init__(self, args, params, fp32_optimizer, fp32_params):
super().__init__(args, params)
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
self.scaler = DynamicLossScaler(
init_scale=2.**7,
scale_window=(2**14 / args.distributed_world_size),
)
def _build_optimizer(self):
@staticmethod
def build_optimizer(args, params):
# create FP32 copy of parameters and grads
params = [p for p in self.model.parameters() if p.requires_grad]
total_param_size = sum(p.data.numel() for p in params)
self.fp32_params = params[0].new(0).float().new(total_param_size)
fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel
self.fp32_params = torch.nn.Parameter(self.fp32_params)
self.fp32_params.grad = self.fp32_params.data.new(total_param_size)
# create optimizer using the copied FP32 params
self._optimizer = optim.build_optimizer(self.args, [self.fp32_params])
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
extra_state['loss_scale'] = self.scaler.loss_scale
super().save_checkpoint(filename, extra_state)
def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
"""Load all training state from a checkpoint file."""
extra_state = super().load_checkpoint(filename, reset_optimizer, reset_lr_scheduler, optimizer_overrides)
if extra_state is not None and 'loss_scale' in extra_state:
self.scaler.loss_scale = extra_state['loss_scale']
return extra_state
def zero_grad(self):
# zero both the FP16 and FP32 grads
self.model.zero_grad() # FP16
self.optimizer.zero_grad() # FP32
def _backward(self, loss):
self.meters['loss_scale'].reset()
self.meters['loss_scale'].update(self.scaler.loss_scale)
if loss is not None:
# dynamically rescale loss to stay in FP16 range
loss = loss * self.scaler.loss_scale
return super()._backward(loss)
def _all_reduce_and_rescale(self, grad_denom):
# undo effect of dynamic loss scaling on gradients
grad_denom *= self.scaler.loss_scale
if self.args.distributed_world_size > 1:
# flatten grads into a single buffer
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
# scale gradients to avoid overflow in all-reduce
flat_grads.div_(self.args.distributed_world_size)
grad_denom /= self.args.distributed_world_size
# all-reduce flat grads
torch.distributed.all_reduce(flat_grads)
# copy grads back to FP32
self.fp32_params.grad.data.copy_(flat_grads)
fp32_params = torch.nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.data.new(total_param_size)
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
return FP16Optimizer(args, params, fp32_optimizer, fp32_params)
@property
def optimizer(self):
return self.fp32_optimizer.optimizer
@property
def optimizer_config(self):
return self.fp32_optimizer.optimizer_config
def get_lr(self):
return self.fp32_optimizer.get_lr()
def set_lr(self, lr):
self.fp32_optimizer.set_lr(lr)
def state_dict(self):
"""Return the optimizer's state dict."""
state_dict = self.fp32_optimizer.state_dict()
state_dict['loss_scale'] = self.scaler.loss_scale
return state_dict
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
if 'loss_scale' in state_dict:
self.scaler.loss_scale = state_dict['loss_scale']
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
def backward(self, loss):
loss = loss * self.scaler.loss_scale
loss.backward()
self._needs_sync = True
def _sync_fp16_grads_to_fp32(self, multiply_grads=1.):
if self._needs_sync:
# copy FP16 grads to FP32
offset = 0
for p in self.params:
if not p.requires_grad:
continue
numel = p.grad.data.numel()
self.fp32_params.grad.data[offset:offset+numel].copy_(p.grad.data.view(-1))
offset += numel
# correct for dynamic loss scaler
self.fp32_params.grad.data.mul_(multiply_grads / self.scaler.loss_scale)
self._needs_sync = False
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
if self._needs_sync:
self._sync_fp16_grads_to_fp32(c)
else:
# single worker: copy grads directly to FP32
self._get_flat_grads(out=self.fp32_params.grad.data)
self.fp32_params.grad.data.mul_(c)
# rescale and clip grads
self.fp32_params.grad.data.div_(grad_denom)
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
def clip_grad_norm(self, max_norm):
"""Clips gradient norm and updates dynamic loss scaler."""
self._sync_fp16_grads_to_fp32()
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, max_norm)
# detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm)
......@@ -137,18 +138,27 @@ class FP16Trainer(Trainer):
'increasing the batch size.'
).format(self.args.min_loss_scale))
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm
def _opt(self):
# take an optimization step using the FP32 params and grads
super()._opt()
def step(self, closure=None):
"""Performs a single optimization step."""
self._sync_fp16_grads_to_fp32()
self.fp32_optimizer.step(closure)
# copy FP32 params back into FP16 model
offset = 0
for p in self.model.parameters():
for p in self.params:
if not p.requires_grad:
continue
numel = p.data.numel()
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
offset += numel
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
self.fp32_optimizer.zero_grad()
for p in self.params:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
self._needs_sync = False
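
A hedged sketch of the FP16 path that replaces the old FP16Trainer: the model runs in FP16 while the optimizer keeps an FP32 master copy of the parameters (built in build_optimizer above). model, criterion, sample, and the usual fairseq args are assumed:

    from fairseq.optim import FP16Optimizer

    model.half()                                    # FP16 forward/backward
    optimizer = FP16Optimizer.build_optimizer(args, list(model.parameters()))

    loss, sample_size, _ = criterion(model, sample)
    optimizer.backward(loss)                        # loss is scaled by the dynamic loss scaler
    optimizer.multiply_grads(1.0 / float(sample_size))
    optimizer.clip_grad_norm(args.clip_norm)        # syncs FP16 grads to FP32; raises OverflowError if the scale bottoms out
    optimizer.step()                                # FP32 update, copied back into the FP16 params
    optimizer.zero_grad()
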
......@@ -183,6 +183,10 @@ def add_distributed_training_args(parser):
help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', default=0, type=int,
help='which GPU to use (usually configured automatically)')
group.add_argument('--no-c10d', action='store_true',
help='don\'t use c10d distributed backend')
group.add_argument('--c10d-bucket-cap-mb', default=150, metavar='MB',
help='bucket size for c10d backend')
return group
......
(The diff for this file is collapsed and not shown.)
......@@ -19,8 +19,10 @@ from train import main as single_process_main
def main(args):
# Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count()
args.distributed_init_method = 'tcp://localhost:{port}'.format(
port=random.randint(10000, 20000))
port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_init_host = 'localhost'
args.distributed_port = port + 1
mp = torch.multiprocessing.get_context('spawn')
......
......@@ -35,7 +35,7 @@ bleu = Extension(
setup(
name='fairseq',
version='0.5.0',
version='0.6.0',
description='Facebook AI Research Sequence-to-Sequence Toolkit',
long_description=readme,
license=license,
......
......@@ -16,7 +16,7 @@ import math
import torch
from fairseq import distributed_utils, options, progress_bar, tasks, utils
from fairseq.fp16_trainer import FP16Trainer
from fairseq.data import iterators
from fairseq.trainer import Trainer
from fairseq.meters import AverageMeter, StopwatchMeter
......@@ -43,16 +43,17 @@ def main(args):
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
# Make a dummy batch to (i) warm the caching allocator and (ii) as a
# placeholder DistributedDataParallel when there's an uneven number of
# batches per worker.
max_positions = utils.resolve_max_positions(
task.max_positions(),
model.max_positions(),
)
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
# Build trainer
if args.fp16:
if torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16,'
' please switch to FP32 which is likely to be faster')
trainer = FP16Trainer(args, task, model, criterion)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
trainer = Trainer(args, task, model, criterion)
trainer = Trainer(args, task, model, criterion, dummy_batch)
print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens,
......@@ -60,10 +61,6 @@ def main(args):
))
# Initialize dataloader
max_positions = utils.resolve_max_positions(
task.max_positions(),
trainer.get_model().max_positions(),
)
epoch_itr = task.get_batch_iterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
......@@ -78,9 +75,7 @@ def main(args):
# Load the latest checkpoint if one is available
if not load_checkpoint(args, trainer, epoch_itr):
# Send a dummy batch to warm the caching allocator
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
trainer.dummy_train_step(dummy_batch)
trainer.dummy_train_step([dummy_batch])
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
......@@ -110,32 +105,32 @@ def main(args):
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Initialize data iterator
itr = epoch_itr.next_epoch_itr()
progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
# update parameters every N batches
# Update parameters every N batches
if epoch_itr.epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch_itr.epoch - 1]
else:
update_freq = args.update_freq[-1]
# Initialize data iterator
itr = epoch_itr.next_epoch_itr()
itr = iterators.GroupedIterator(itr, update_freq)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch, no_progress_bar='simple',
)
extra_meters = collections.defaultdict(lambda: AverageMeter())
first_valid = args.valid_subset.split(',')[0]
max_update = args.max_update or math.inf
num_batches = len(epoch_itr)
for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
if i < num_batches - 1 and (i + 1) % update_freq > 0:
# buffer updates according to --update-freq
trainer.train_step(sample, update_params=False)
for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
log_output = trainer.train_step(samples)
if log_output is None:
continue
else:
log_output = trainer.train_step(sample, update_params=True)
# log mid-epoch stats
stats = get_training_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
continue # these are already logged above
if 'loss' in k:
extra_meters[k].update(v, log_output['sample_size'])
......@@ -163,7 +158,9 @@ def train(args, trainer, task, epoch_itr):
progress.print(stats)
# reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip', 'gnorm']:
for k in [
'train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'gnorm', 'clip',
]:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
......@@ -230,7 +227,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
log_output = trainer.valid_step(sample)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
continue
extra_meters[k].update(v)
......