Commit a6155337 authored by Myle Ott

Better training support when GPUs are in "exclusive mode"

parent a8bc4d0a
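In CUDA "exclusive mode" (compute mode EXCLUSIVE_PROCESS), each GPU accepts a context from only one process at a time. Before this commit, the parent process prepared and scattered samples itself, which created a CUDA context on every GPU and locked the per-GPU worker processes out. The diff below moves sample preparation into the workers (LanguageDatasets stops pinning memory in the parent; MultiprocessingTrainer stashes each shard as self._sample inside the owning worker) and drops the inter-process event synchronization the old scheme required. A minimal sketch of the failure mode, assuming EXCLUSIVE_PROCESS is set on every device (illustrative names, not fairseq code):

    import torch

    # Hypothetical illustration: touching every device from the parent, as
    # the old _scatter_samples did, creates a CUDA context per GPU in *this*
    # process. Under EXCLUSIVE_PROCESS compute mode, the worker processes
    # can then no longer create their own contexts on those GPUs.
    def claim_all_gpus_from_parent():
        for d in range(torch.cuda.device_count()):
            torch.zeros(1, device='cuda:{}'.format(d))  # allocation creates a context on device d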
@@ -105,7 +105,6 @@ class LanguageDatasets(object):
         return torch.utils.data.DataLoader(
             dataset,
             num_workers=num_workers,
-            pin_memory=torch.cuda.is_available(),
             collate_fn=PaddingCollater(self.src_dict.pad()),
             batch_sampler=batch_sampler)
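The pin_memory removal follows from the same constraint: asking the DataLoader to copy batches into page-locked host memory initializes CUDA in the process that owns the loader, i.e. the parent. A hedged sketch of the resulting pattern (illustrative code, not part of the commit):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Sketch: without pinning, the host-to-device copy is a plain pageable
    # copy, performed by whichever process consumes the batch. The parent
    # process stays off the GPUs entirely.
    dataset = TensorDataset(torch.randn(128, 16))
    loader = DataLoader(dataset, batch_size=32, pin_memory=False)
    for (batch,) in loader:
        if torch.cuda.is_available():
            batch = batch.cuda()  # copy happens in the consuming process
        break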
@@ -122,14 +122,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         assert isinstance(criterion, FairseqCriterion)

         # scatter sample across GPUs
-        samples, data_events = self._scatter_samples(samples)
+        self._scatter_samples(samples)
         criterion.prepare(samples)

         # forward pass, backward pass and gradient step
         losses = [
-            self.call_async(rank, '_async_train_step', sample=samples[rank],
-                            criterion=criterion, data_event=event)
-            for rank, event in enumerate(data_events)
+            self.call_async(rank, '_async_train_step', criterion=criterion)
+            for rank in range(self.num_replicas)
         ]

         # aggregate losses and gradient norms
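After this change, train_step no longer ships the sample through call_async; each worker reads the self._sample it stashed during _scatter_samples. A minimal sketch of the persistent-worker call_async/Future pattern this relies on (hypothetical names; fairseq's MultiprocessingEventLoop is more involved):

    import multiprocessing as mp

    def _worker(rank, conn):
        state = {}  # stands in for per-replica trainer state, e.g. self._sample
        while True:
            msg = conn.recv()
            if msg is None:
                break
            method, kwargs = msg
            if method == '_async_prepare_sample':
                state['sample'] = kwargs['sample']  # stash the shard for later calls
                conn.send(None)
            elif method == '_async_train_step':
                sample = state.get('sample')
                conn.send(0.0 if sample is None else float(sum(sample)))  # fake "loss"

    class Loop:
        def __init__(self, num_replicas):
            self.conns, self.procs = [], []
            for rank in range(num_replicas):
                parent, child = mp.Pipe()
                proc = mp.Process(target=_worker, args=(rank, child))
                proc.start()
                self.conns.append(parent)
                self.procs.append(proc)

        def call_async(self, rank, method, **kwargs):
            # Fire-and-return: the pipe acts as the "future"; a later
            # recv() yields the worker's result in FIFO order.
            self.conns[rank].send((method, kwargs))
            return self.conns[rank]

    if __name__ == '__main__':
        loop = Loop(num_replicas=2)
        futs = [loop.call_async(r, '_async_prepare_sample', sample=[1, 2, 3]) for r in range(2)]
        [f.recv() for f in futs]  # analogue of Future.gen_list: wait for all replicas
        losses = [loop.call_async(r, '_async_train_step').recv() for r in range(2)]
        print(losses)
        for conn in loop.conns:
            conn.send(None)  # shut the workers down
        for proc in loop.procs:
            proc.join()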
@@ -138,8 +137,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return loss, grad_norms[0]

-    def _async_train_step(self, rank, device_id, sample, criterion, data_event):
-        data_event.wait()
+    def _async_train_step(self, rank, device_id, criterion):
         self.model.train()

         # zero grads even if net_input is None, since we will all-reduce them
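The deleted data_event.wait() existed because the parent issued the host-to-device copies and each worker had to wait on an inter-process CUDA event before touching the data. Once each worker performs its own copy, ordering is guaranteed by that worker's own CUDA stream. For orientation, a single-process sketch of event-based ordering (illustrative, not the removed code):

    import torch

    if torch.cuda.is_available():
        evt = torch.cuda.Event()
        x = torch.zeros(1 << 20, device='cuda')  # work queued on the current stream
        evt.record()       # mark this point in the stream
        evt.synchronize()  # block the host until everything before the mark is done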
@@ -147,9 +145,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # calculate loss and grads
         loss = 0
-        if sample is not None:
-            net_output = self.model(**sample['net_input'])
-            loss_ = criterion(net_output, sample)
+        if self._sample is not None:
+            net_output = self.model(**self._sample['net_input'])
+            loss_ = criterion(net_output, self._sample)
             loss_.backward()
             loss = loss_.data[0]
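A side note on loss_.data[0]: this is the PyTorch 0.x idiom for reading a Python scalar out of a Variable. On modern PyTorch the equivalent read is .item() (a note about today's API, not a change made by this commit):

    import torch

    loss_ = torch.tensor(0.5)
    print(loss_.item())  # modern equivalent of the 0.x idiom loss_.data[0]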
@@ -191,14 +189,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def valid_step(self, samples, criterion):
         """Do forward pass in parallel."""
         # scatter sample across GPUs
-        samples, data_events = self._scatter_samples(samples, volatile=True)
+        self._scatter_samples(samples, volatile=True)
         criterion.prepare(samples)

         # forward pass
         losses = [
-            self.call_async(rank, '_async_valid_step', sample=samples[rank],
-                            criterion=criterion, data_event=event)
-            for rank, event in enumerate(data_events)
+            self.call_async(rank, '_async_valid_step', criterion=criterion)
+            for rank in range(self.num_replicas)
         ]

         # aggregate losses
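valid_step scatters with volatile=True, the PyTorch 0.x flag that disabled autograd history on inference-only Variables. The modern counterpart, shown here only for orientation, is the no_grad context:

    import torch

    model = torch.nn.Linear(4, 2)
    x = torch.randn(1, 4)
    with torch.no_grad():  # replacement for volatile=True in current PyTorch
        out = model(x)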
@@ -206,14 +203,12 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return loss

-    def _async_valid_step(self, rank, device_id, sample, criterion, data_event):
-        if sample is None:
+    def _async_valid_step(self, rank, device_id, criterion):
+        if self._sample is None:
             return 0
-        data_event.wait()
         self.model.eval()
-        net_output = self.model(**sample['net_input'])
-        loss = criterion(net_output, sample)
+        net_output = self.model(**self._sample['net_input'])
+        loss = criterion(net_output, self._sample)
         return loss.data[0]

     def get_lr(self):
@@ -241,20 +236,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _scatter_samples(self, samples, volatile=False):
         """Split and distribute a sample across GPUs."""
-        res = [utils.prepare_sample(sample, volatile=volatile,
-                                    cuda_device=device_id)
-               for sample, device_id in zip(samples, self.device_ids)]
-
         # Pad with None until its size is equal to the number of replicas.
-        res = res + [None]*(self.num_replicas - len(samples))
+        samples = samples + [None]*(self.num_replicas - len(samples))

-        # Synchronize GPU devices after data is sent to prevent
-        # race conditions.
-        events = []
-        for d in self.device_ids:
-            with torch.cuda.device(d):
-                event = torch.cuda.Event(interprocess=True)
-                event.record()
-            events.append(event)
-
-        return res, events
+        Future.gen_list([
+            self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
+            for rank in range(self.num_replicas)
+        ])
+
+    def _async_prepare_sample(self, rank, device_id, sample, volatile):
+        if sample is None:
+            self._sample = None
+        else:
+            self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
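_async_prepare_sample runs inside the worker that owns device_id, so the only process that ever touches a GPU is the one holding its exclusive-mode context, and the result is stashed on self._sample for the train and valid steps above. A hedged sketch of what utils.prepare_sample is assumed to do, moving every tensor in a nested sample onto the target device (fairseq's actual helper may differ, e.g. it wrapped tensors in volatile Variables on PyTorch 0.x):

    import torch

    def prepare_sample(sample, volatile=False, cuda_device=None):
        # volatile is accepted for signature parity with the diff above;
        # autograd-mode handling is omitted in this sketch.
        def _move(x):
            if torch.is_tensor(x):
                return x.cuda(cuda_device) if cuda_device is not None else x
            if isinstance(x, dict):
                return {k: _move(v) for k, v in x.items()}
            if isinstance(x, list):
                return [_move(v) for v in x]
            return x
        return _move(sample)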