Commit 4593ebfa, authored by Myle Ott, committed by GitHub
Browse files

Fix handling of partially-empty initial batch (#11)

parent 03c4a716
......@@ -10,6 +10,7 @@
Train a network on multiple GPUs using multiprocessing.
"""
from itertools import cycle, islice
import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
......@@ -48,6 +49,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
for rank in range(self.num_replicas)
])
self._grads_initialized = False
def _async_init(self, rank, device_id, args, model, nccl_uid):
"""Initialize child processes."""
self.args = args
......@@ -121,8 +124,15 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
"""Do forward, backward and gradient step in parallel."""
assert isinstance(criterion, FairseqCriterion)
# PyTorch initializes gradient buffers lazily, so the first
# train step needs to send non-empty samples to all replicas
replace_empty_samples = False
if not self._grads_initialized:
replace_empty_samples = True
self._grads_initialized = True
# scatter sample across GPUs
self._scatter_samples(samples)
self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
criterion.prepare(samples)
# forward pass, backward pass and gradient step
......@@ -234,10 +244,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.lr_scheduler.step(val_loss, epoch)
return self.optimizer.param_groups[0]['lr']
def _scatter_samples(self, samples, volatile=False):
def _scatter_samples(self, samples, volatile=False, replace_empty_samples=False):
"""Split and distribute a sample across GPUs."""
# Pad with None until its size is equal to the number of replicas.
samples = samples + [None]*(self.num_replicas - len(samples))
if not replace_empty_samples:
# pad with None until its size is equal to the number of replicas
samples = samples + [None]*(self.num_replicas - len(samples))
else:
# pad by cycling through the given samples
samples = list(islice(cycle(samples), self.num_replicas))
Future.gen_list([
self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment