Commit 4593ebfa authored by Myle Ott, committed by GitHub

Fix handling of partially-empty initial batch (#11)

parent 03c4a716
@@ -10,6 +10,7 @@
 Train a network on multiple GPUs using multiprocessing.
 """
 
+from itertools import cycle, islice
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 
@@ -48,6 +49,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             for rank in range(self.num_replicas)
         ])
 
+        self._grads_initialized = False
+
     def _async_init(self, rank, device_id, args, model, nccl_uid):
        """Initialize child processes."""
        self.args = args
@@ -121,8 +124,15 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         """Do forward, backward and gradient step in parallel."""
         assert isinstance(criterion, FairseqCriterion)
 
+        # PyTorch initializes gradient buffers lazily, so the first
+        # train step needs to send non-empty samples to all replicas
+        replace_empty_samples = False
+        if not self._grads_initialized:
+            replace_empty_samples = True
+            self._grads_initialized = True
+
         # scatter sample across GPUs
-        self._scatter_samples(samples)
+        self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
         criterion.prepare(samples)
 
         # forward pass, backward pass and gradient step
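
The comment added in this hunk rests on a PyTorch detail: gradient buffers are allocated lazily, on the first backward pass. A replica handed an empty sample on step one would therefore never allocate its .grad tensors, presumably leaving the cross-replica gradient reduction with missing buffers on that GPU. A minimal standalone sketch of that lazy allocation (illustration only, not part of this commit):

    import torch

    model = torch.nn.Linear(4, 2)
    # Before any backward pass, no gradient buffers exist yet.
    print([p.grad for p in model.parameters()])        # [None, None]

    loss = model(torch.randn(1, 4)).sum()
    loss.backward()
    # After the first backward pass, every parameter has a gradient tensor.
    print([p.grad.shape for p in model.parameters()])  # [torch.Size([2, 4]), torch.Size([2])]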
@@ -234,10 +244,14 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             self.lr_scheduler.step(val_loss, epoch)
         return self.optimizer.param_groups[0]['lr']
 
-    def _scatter_samples(self, samples, volatile=False):
+    def _scatter_samples(self, samples, volatile=False, replace_empty_samples=False):
         """Split and distribute a sample across GPUs."""
-        # Pad with None until its size is equal to the number of replicas.
-        samples = samples + [None]*(self.num_replicas - len(samples))
+        if not replace_empty_samples:
+            # pad with None until its size is equal to the number of replicas
+            samples = samples + [None]*(self.num_replicas - len(samples))
+        else:
+            # pad by cycling through the given samples
+            samples = list(islice(cycle(samples), self.num_replicas))
 
         Future.gen_list([
             self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
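
The two padding modes of _scatter_samples are easiest to see in isolation. The sketch below uses a hypothetical pad_samples helper that mirrors the logic above: the default mode pads the batch list with None (empty samples for idle replicas), while replace_empty_samples=True recycles the real samples via itertools.cycle and islice so every replica receives data on the first step.

    from itertools import cycle, islice

    def pad_samples(samples, num_replicas, replace_empty_samples=False):
        """Pad a list of per-replica samples out to num_replicas entries."""
        if not replace_empty_samples:
            # normal case: idle replicas get an empty (None) sample
            return samples + [None] * (num_replicas - len(samples))
        # first step: cycle through the real samples so every replica
        # runs a forward/backward pass and allocates its gradient buffers
        return list(islice(cycle(samples), num_replicas))

    print(pad_samples(['a', 'b'], 4))                              # ['a', 'b', None, None]
    print(pad_samples(['a', 'b'], 4, replace_empty_samples=True))  # ['a', 'b', 'a', 'b']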