Commit 2d27ae08 authored by Sergey Edunov, committed by Myle Ott

Simulated big batches

parent 60c4081b
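
This commit adds a --update-freq option: gradients from N consecutive mini-batches are accumulated and the optimizer steps only once every N batches, which simulates training with a batch N times larger than what fits in memory (hence "simulated big batches"). Below is a minimal, self-contained sketch of the same gradient-accumulation idea in plain PyTorch. It is not fairseq code: the model, data, and the division by update_freq are illustrative assumptions (the commit itself normalizes by the summed sample sizes via grad_denom instead).

import torch
import torch.nn as nn

# Sketch: accumulate gradients over `update_freq` mini-batches, then take one
# optimizer step, so the effective batch is update_freq times larger.
torch.manual_seed(0)
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

update_freq = 4   # corresponds to --update-freq 4
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

optimizer.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = criterion(model(x), y)
    # Scale so the accumulated gradient matches the average over one big batch.
    # Assumption: equal-sized batches; the commit normalizes by grad_denom instead.
    (loss / update_freq).backward()
    if (i + 1) % update_freq == 0:
        optimizer.step()        # one parameter update per update_freq batches
        optimizer.zero_grad()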
@@ -176,6 +176,8 @@ def add_optimization_args(parser):
                             ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--update-freq', default=1, type=int, metavar='N',
+                       help='update parameters every N batches')
     return group
...
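
For reference, the new flag behaves like any other argparse integer option and surfaces as args.update_freq. The toy parser below reconstructs only this one argument (the surrounding group and the rest of the real option definitions are omitted).

import argparse

# Toy parser containing only the option added above.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('Optimization')
group.add_argument('--update-freq', default=1, type=int, metavar='N',
                   help='update parameters every N batches')

args = parser.parse_args(['--update-freq', '16'])
print(args.update_freq)  # 16; omitting the flag keeps the default of 1 (old behavior)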
@@ -9,7 +9,8 @@
 Train a network on multiple GPUs.
 """

-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
+from itertools import chain
 import math

 import torch
@@ -55,6 +56,7 @@ class Trainer(object):
         self.meters['clip'] = AverageMeter()   # % of updates clipped
         self.meters['oom'] = AverageMeter()    # out of memory

+        self._buffered_stats = defaultdict(lambda: [])
         self._max_bsz_seen = 0
         self._num_updates = 0
         self._optim_history = None
@@ -86,22 +88,46 @@ class Trainer(object):
         return extra_state

-    def train_step(self, sample):
+    def train_step(self, sample, update_params=True):
         """Do forward, backward and parameter update."""
         sample = self._prepare_sample(sample, volatile=False)

-        # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample)
-
-        # aggregate stats and logging outputs
-        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
-        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
-        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
-
-        # backward pass, all-reduce gradients and take an optimization step
-        grad_norm, ooms_bwd = self._backward_and_opt(loss, grad_denom)
-
-        # update meters
-        self.meters['wps'].update(ntokens)
+        # forward and backward pass
+        loss, sample_size, logging_output, oom_fwd = self._forward(sample)
+        oom_bwd = self._backward(loss)
+
+        # buffer stats and logging outputs
+        self._buffered_stats['sample_sizes'].append(sample_size)
+        self._buffered_stats['logging_outputs'].append(logging_output)
+        self._buffered_stats['ooms_fwd'].append(oom_fwd)
+        self._buffered_stats['ooms_bwd'].append(oom_bwd)
+
+        # update parameters
+        if update_params:
+            # gather logging outputs from all GPUs
+            sample_sizes = self._buffered_stats['sample_sizes']
+            logging_outputs = self._buffered_stats['logging_outputs']
+            ooms_fwd = self._buffered_stats['ooms_fwd']
+            ooms_bwd = self._buffered_stats['ooms_bwd']
+            if self.args.distributed_world_size > 1:
+                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
+                    lambda l: list(chain.from_iterable(l)),
+                    zip(*distributed_utils.all_gather_list(
+                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
+                    ))
+                )
+            ooms_fwd = sum(ooms_fwd)
+            ooms_bwd = sum(ooms_bwd)
+
+            # aggregate stats and logging outputs
+            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+            nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
+            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+
+            # all-reduce gradients and take an optimization step
+            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+            grad_norm = self._opt(grad_denom)
+
+            # update meters
+            self.meters['wps'].update(ntokens)
@@ -119,7 +145,11 @@ class Trainer(object):
-        if 'nll_loss' in agg_logging_output:
-            self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
+            if 'nll_loss' in agg_logging_output:
+                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)

-        return agg_logging_output
+            self._buffered_stats.clear()
+
+            return agg_logging_output
+        else:
+            return None  # buffering updates

     def _forward(self, sample, eval=False):
         # prepare model and optimizer
@@ -127,7 +157,6 @@
             self.model.eval()
         else:
             self.model.train()
-            self.optimizer.zero_grad()

         loss = None
         sample_size = 0
@@ -152,19 +181,9 @@
             else:
                 raise e

-        # synchronize logging outputs for multi-GPU training
-        if self.args.distributed_world_size > 1:
-            sample_sizes, logging_outputs, ooms = zip(*list(
-                distributed_utils.all_gather_list((sample_size, logging_output, oom))))
-            ooms = sum(ooms)
-        else:
-            sample_sizes = [sample_size]
-            logging_outputs = [logging_output]
-            ooms = oom
-
-        return loss, sample_sizes, logging_outputs, ooms
+        return loss, sample_size, logging_output, oom

-    def _backward_and_opt(self, loss, grad_denom):
+    def _backward(self, loss):
         oom = 0
         if loss is not None:
             try:
@@ -179,7 +198,9 @@ class Trainer(object):
                 self.optimizer.zero_grad()
             else:
                 raise e
+        return oom

+    def _opt(self, grad_denom):
         # all-reduce grads and rescale by grad_denom
         if self.args.distributed_world_size > 1:
             grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
@@ -197,12 +218,13 @@ class Trainer(object):
         # take an optimization step
         self.optimizer.step()
+        self.optimizer.zero_grad()
         self._num_updates += 1

         # update learning rate
         self.lr_scheduler.step_update(self._num_updates)

-        return grad_norm, oom
+        return grad_norm

     def valid_step(self, sample):
         """Do forward pass in evaluation mode."""
@@ -210,8 +232,17 @@ class Trainer(object):
         sample = self._prepare_sample(sample, volatile=True)

         # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample, eval=True)
-        assert not ooms_fwd, 'Ran out of memory during validation'
+        _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
+        assert not oom_fwd, 'Ran out of memory during validation'
+
+        # gather logging outputs from all GPUs
+        if self.args.distributed_world_size > 1:
+            sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
+                (sample_size, logging_output)
+            ))
+        else:
+            sample_sizes = [sample_size]
+            logging_outputs = [logging_output]

         # aggregate stats and logging outputs
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
...
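
When update_params is True and training is distributed, each worker contributes its whole buffer of per-batch stats, so all_gather_list returns one tuple of lists per worker; zip(*...) regroups those tuples by field and chain.from_iterable flattens the per-worker lists into a single list per field. The toy snippet below replays that reshaping with a hard-coded gather result from two hypothetical workers (the numbers are made up).

from itertools import chain

# Stand-in for the result of distributed_utils.all_gather_list(...): one entry
# per worker, each holding (sample_sizes, logging_outputs) for two buffered batches.
gathered = [
    ([192, 187], [{'ntokens': 192}, {'ntokens': 187}]),  # worker 0
    ([201, 178], [{'ntokens': 201}, {'ntokens': 178}]),  # worker 1
]

sample_sizes, logging_outputs = map(
    lambda l: list(chain.from_iterable(l)),
    zip(*gathered),
)
print(sample_sizes)                                    # [192, 187, 201, 178]
print(sum(log['ntokens'] for log in logging_outputs))  # 758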
@@ -132,6 +132,7 @@ def train(args, trainer, dataset, epoch, batch_offset):
         num_shards=args.distributed_world_size,
     )
     progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
+    epoch_size = len(itr)
     itr = itertools.islice(progress, batch_offset, None)

     # reset training meters
@@ -143,7 +144,12 @@ def train(args, trainer, dataset, epoch, batch_offset):
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     max_update = args.max_update or math.inf
     for i, sample in enumerate(itr, start=batch_offset):
-        log_output = trainer.train_step(sample)
+        if i < epoch_size - 1 and (i + 1) % args.update_freq > 0:
+            # buffer updates according to --update-freq
+            trainer.train_step(sample, update_params=False)
+            continue
+        else:
+            log_output = trainer.train_step(sample, update_params=True)

         # log mid-epoch stats
         stats = get_training_stats(trainer)
@@ -157,9 +163,8 @@ def train(args, trainer, dataset, epoch, batch_offset):
                 stats[k] = extra_meters[k].avg
             progress.log(stats)

-        # save mid-epoch checkpoints
+        # ignore the first mini-batch in words-per-second calculation
         if i == batch_offset:
-            # ignore the first mini-batch in words-per-second calculation
             trainer.get_meter('wps').reset()

         # save mid-epoch checkpoints
...
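
The i < epoch_size - 1 guard in the buffering condition above forces a parameter update on the last batch of the epoch even when the buffer is not full, so no accumulated gradients are left unapplied at an epoch boundary. A tiny check of which iterations take an optimizer step, for a hypothetical 10-batch epoch with --update-freq 4:

# Iterations that trigger an update; mirrors the condition in the training loop above.
epoch_size, update_freq = 10, 4
updates = [i for i in range(epoch_size)
           if not (i < epoch_size - 1 and (i + 1) % update_freq > 0)]
print(updates)  # [3, 7, 9] -- the final batch (i=9) updates even though only
                # two batches are buffered at that point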