"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "105b77fe346e7a1267e8319073a9353a1b45f395"
Commit 2d27ae08 authored by Sergey Edunov, committed by Myle Ott

Simulated big batches

parent 60c4081b
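
The change implements gradient accumulation ("simulated big batches"): the trainer buffers the forward/backward results of several mini-batches and only takes an optimizer step every --update-freq batches, so the effective batch size grows without using more GPU memory per pass. A minimal standalone sketch of the pattern, assuming a generic PyTorch model, optimizer, loss and loader (the names are illustrative, not fairseq APIs):

    def train_epoch(model, optimizer, criterion, loader, update_freq=4):
        model.train()
        optimizer.zero_grad()
        for i, (inputs, targets) in enumerate(loader):
            loss = criterion(model(inputs), targets)
            # simple 1/update_freq rescaling; fairseq instead divides the
            # accumulated gradient by grad_denom (the total sample size)
            (loss / update_freq).backward()   # gradients accumulate in .grad
            if (i + 1) % update_freq == 0:
                optimizer.step()              # one "big batch" update
                optimizer.zero_grad()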
@@ -176,6 +176,8 @@ def add_optimization_args(parser):
                             ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--update-freq', default=1, type=int, metavar='N',
+                       help='update parameters every N batches')
     return group
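
With the new flag, each parameter update consumes N consecutive batches per GPU, so the tokens per update scale with --update-freq and the number of workers. A back-of-the-envelope with assumed values (the numbers below are illustrative, not defaults):

    max_tokens = 4000         # tokens per batch per GPU (--max-tokens, assumed)
    update_freq = 16          # batches buffered per update (--update-freq)
    num_gpus = 8              # distributed_world_size (assumed)
    tokens_per_update = max_tokens * update_freq * num_gpus
    print(tokens_per_update)  # 512000 tokens contribute to each update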
@@ -9,7 +9,8 @@
 Train a network on multiple GPUs.
 """
 
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
+from itertools import chain
 import math
 
 import torch
@@ -55,6 +56,7 @@ class Trainer(object):
         self.meters['clip'] = AverageMeter()  # % of updates clipped
         self.meters['oom'] = AverageMeter()   # out of memory
 
+        self._buffered_stats = defaultdict(lambda: [])
         self._max_bsz_seen = 0
         self._num_updates = 0
         self._optim_history = None
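
The new _buffered_stats attribute is a defaultdict of lists: train_step can append per-batch statistics under any key without initialising it first, and the whole buffer is cleared once an actual update is taken. A tiny sketch of the behaviour being relied on:

    from collections import defaultdict

    buffered = defaultdict(lambda: [])    # missing keys start as empty lists
    buffered['sample_sizes'].append(30)
    buffered['sample_sizes'].append(28)
    buffered['ooms_fwd'].append(0)
    print(buffered['sample_sizes'])       # [30, 28]
    buffered.clear()                      # emptied after each real update
    print(len(buffered))                  # 0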
@@ -86,40 +88,68 @@ class Trainer(object):
         return extra_state
 
-    def train_step(self, sample):
+    def train_step(self, sample, update_params=True):
         """Do forward, backward and parameter update."""
         sample = self._prepare_sample(sample, volatile=False)
 
-        # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample)
-
-        # aggregate stats and logging outputs
-        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
-        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
-        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
-
-        # backward pass, all-reduce gradients and take an optimization step
-        grad_norm, ooms_bwd = self._backward_and_opt(loss, grad_denom)
-
-        # update meters
-        self.meters['wps'].update(ntokens)
-        self.meters['ups'].update(1.)
-        self.meters['wpb'].update(ntokens)
-        self.meters['bsz'].update(nsentences)
-        self.meters['gnorm'].update(grad_norm)
-        self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
-        self.meters['oom'].update(ooms_fwd + ooms_bwd)
-
-        # update loss meters for training
-        if 'loss' in agg_logging_output:
-            self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
-
-        # criterions can optionally log the NLL loss too
-        if 'nll_loss' in agg_logging_output:
-            self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
-
-        return agg_logging_output
+        # forward and backward pass
+        loss, sample_size, logging_output, oom_fwd = self._forward(sample)
+        oom_bwd = self._backward(loss)
+
+        # buffer stats and logging outputs
+        self._buffered_stats['sample_sizes'].append(sample_size)
+        self._buffered_stats['logging_outputs'].append(logging_output)
+        self._buffered_stats['ooms_fwd'].append(oom_fwd)
+        self._buffered_stats['ooms_bwd'].append(oom_bwd)
+
+        # update parameters
+        if update_params:
+            # gather logging outputs from all GPUs
+            sample_sizes = self._buffered_stats['sample_sizes']
+            logging_outputs = self._buffered_stats['logging_outputs']
+            ooms_fwd = self._buffered_stats['ooms_fwd']
+            ooms_bwd = self._buffered_stats['ooms_bwd']
+            if self.args.distributed_world_size > 1:
+                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
+                    lambda l: list(chain.from_iterable(l)),
+                    zip(*distributed_utils.all_gather_list(
+                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
+                    ))
+                )
+            ooms_fwd = sum(ooms_fwd)
+            ooms_bwd = sum(ooms_bwd)
+
+            # aggregate stats and logging outputs
+            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+            nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
+            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+
+            # all-reduce gradients and take an optimization step
+            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+            grad_norm = self._opt(grad_denom)
+
+            # update meters
+            self.meters['wps'].update(ntokens)
+            self.meters['ups'].update(1.)
+            self.meters['wpb'].update(ntokens)
+            self.meters['bsz'].update(nsentences)
+            self.meters['gnorm'].update(grad_norm)
+            self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
+            self.meters['oom'].update(ooms_fwd + ooms_bwd)
+
+            # update loss meters for training
+            if 'loss' in agg_logging_output:
+                self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
+
+            # criterions can optionally log the NLL loss too
+            if 'nll_loss' in agg_logging_output:
+                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
+
+            self._buffered_stats.clear()
+            return agg_logging_output
+        else:
+            return None  # buffering updates
 
     def _forward(self, sample, eval=False):
         # prepare model and optimizer
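
In the multi-GPU case, distributed_utils.all_gather_list returns one buffered (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd) tuple per worker; zip(*...) regroups them field by field and chain.from_iterable flattens each field into a single list covering every buffered batch on every worker. A sketch of that reshaping with the gathered result faked as a plain list (two workers, two buffered batches each; the actual collective call is not reproduced here):

    from itertools import chain

    # stand-in for the value returned by distributed_utils.all_gather_list
    gathered = [
        ([30, 28], [{'ntokens': 30}, {'ntokens': 28}], [0, 0], [0, 0]),  # worker 0
        ([31, 29], [{'ntokens': 31}, {'ntokens': 29}], [0, 1], [0, 0]),  # worker 1
    ]

    sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
        lambda l: list(chain.from_iterable(l)),
        zip(*gathered),
    )
    print(sample_sizes)   # [30, 28, 31, 29]
    print(sum(ooms_fwd))  # 1 forward-pass OOM across all workers and batches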
@@ -127,7 +157,6 @@ class Trainer(object):
             self.model.eval()
         else:
             self.model.train()
-            self.optimizer.zero_grad()
 
         loss = None
         sample_size = 0
@@ -152,19 +181,9 @@
             else:
                 raise e
 
-        # synchronize logging outputs for multi-GPU training
-        if self.args.distributed_world_size > 1:
-            sample_sizes, logging_outputs, ooms = zip(*list(
-                distributed_utils.all_gather_list((sample_size, logging_output, oom))))
-            ooms = sum(ooms)
-        else:
-            sample_sizes = [sample_size]
-            logging_outputs = [logging_output]
-            ooms = oom
-
-        return loss, sample_sizes, logging_outputs, ooms
+        return loss, sample_size, logging_output, oom
 
-    def _backward_and_opt(self, loss, grad_denom):
+    def _backward(self, loss):
         oom = 0
         if loss is not None:
             try:
@@ -179,7 +198,9 @@
                     self.optimizer.zero_grad()
                 else:
                     raise e
+        return oom
 
+    def _opt(self, grad_denom):
         # all-reduce grads and rescale by grad_denom
         if self.args.distributed_world_size > 1:
             grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
@@ -197,12 +218,13 @@
         # take an optimization step
         self.optimizer.step()
+        self.optimizer.zero_grad()
         self._num_updates += 1
 
         # update learning rate
         self.lr_scheduler.step_update(self._num_updates)
 
-        return grad_norm, oom
+        return grad_norm
 
     def valid_step(self, sample):
         """Do forward pass in evaluation mode."""
@@ -210,8 +232,17 @@
         sample = self._prepare_sample(sample, volatile=True)
 
         # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample, eval=True)
-        assert not ooms_fwd, 'Ran out of memory during validation'
+        _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
+        assert not oom_fwd, 'Ran out of memory during validation'
+
+        # gather logging outputs from all GPUs
+        if self.args.distributed_world_size > 1:
+            sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
+                (sample_size, logging_output)
+            ))
+        else:
+            sample_sizes = [sample_size]
+            logging_outputs = [logging_output]
 
         # aggregate stats and logging outputs
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
@@ -132,6 +132,7 @@ def train(args, trainer, dataset, epoch, batch_offset):
         num_shards=args.distributed_world_size,
     )
     progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
+    epoch_size = len(itr)
     itr = itertools.islice(progress, batch_offset, None)
 
     # reset training meters
@@ -143,7 +144,12 @@ def train(args, trainer, dataset, epoch, batch_offset):
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     max_update = args.max_update or math.inf
     for i, sample in enumerate(itr, start=batch_offset):
-        log_output = trainer.train_step(sample)
+        if i < epoch_size - 1 and (i + 1) % args.update_freq > 0:
+            # buffer updates according to --update-freq
+            trainer.train_step(sample, update_params=False)
+            continue
+        else:
+            log_output = trainer.train_step(sample, update_params=True)
 
         # log mid-epoch stats
         stats = get_training_stats(trainer)
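
The condition in the loop flushes the buffer every --update-freq batches and, whatever the alignment, on the last batch of the epoch; that is why epoch_size is recorded before the iterator is sliced. A small dry run of the same test with assumed values:

    update_freq = 4
    epoch_size = 10       # len(itr) for the epoch
    batch_offset = 0

    for i in range(batch_offset, epoch_size):
        buffer_only = i < epoch_size - 1 and (i + 1) % update_freq > 0
        print(i, 'buffer' if buffer_only else 'UPDATE')
    # updates fire at i = 3, 7 and at the final batch i = 9, so the two
    # leftover batches still contribute to a (smaller) final update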
@@ -157,9 +163,8 @@ def train(args, trainer, dataset, epoch, batch_offset):
             stats[k] = extra_meters[k].avg
         progress.log(stats)
 
-        # save mid-epoch checkpoints
-        # ignore the first mini-batch in words-per-second calculation
         if i == batch_offset:
+            # ignore the first mini-batch in words-per-second calculation
             trainer.get_meter('wps').reset()
 
+        # save mid-epoch checkpoints