"docs/vscode:/vscode.git/clone" did not exist on "61fba9da0fcae0ba5f3fc426442302e9fef0443c"
Commit a6155337 authored by Myle Ott

Better training support when GPUs are in "exclusive mode"

parent a8bc4d0a
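In "exclusive process" compute mode (typically set with nvidia-smi -c EXCLUSIVE_PROCESS), each GPU accepts a CUDA context from only one process at a time, so the parent trainer process can no longer allocate tensors or record interprocess events on the workers' devices. The diff below moves sample preparation into the worker processes themselves: the parent keeps samples on the CPU and each replica copies its own sample onto the GPU it owns. A minimal sketch of that pattern (illustrative only, not fairseq code; the function and variable names are mine):

# Minimal sketch of the worker-local data pattern (illustrative, not fairseq code).
# The parent keeps batches on the CPU and never initializes CUDA; each spawned
# worker creates the context on its own GPU and moves its own batch there.
import torch
import torch.multiprocessing as mp

def worker(rank, batches):
    torch.cuda.set_device(rank)     # this worker is the only process touching GPU `rank`
    batch = batches[rank]
    if batch is not None:           # a replica past the end of the batch list would get None
        batch = batch.cuda(rank)
        print('rank', rank, 'sum', batch.sum().item())

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    batches = [torch.randn(4, 8) for _ in range(world_size)]  # CPU tensors only in the parent
    mp.spawn(worker, args=(batches,), nprocs=world_size)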
@@ -105,7 +105,6 @@ class LanguageDatasets(object):
         return torch.utils.data.DataLoader(
             dataset,
             num_workers=num_workers,
-            pin_memory=torch.cuda.is_available(),
             collate_fn=PaddingCollater(self.src_dict.pad()),
             batch_sampler=batch_sampler)
@@ -122,14 +122,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         assert isinstance(criterion, FairseqCriterion)
         # scatter sample across GPUs
-        samples, data_events = self._scatter_samples(samples)
+        self._scatter_samples(samples)
         criterion.prepare(samples)
         # forward pass, backward pass and gradient step
         losses = [
-            self.call_async(rank, '_async_train_step', sample=samples[rank],
-                            criterion=criterion, data_event=event)
-            for rank, event in enumerate(data_events)
+            self.call_async(rank, '_async_train_step', criterion=criterion)
+            for rank in range(self.num_replicas)
         ]
         # aggregate losses and gradient norms
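Because each worker now caches its sample locally (see _async_prepare_sample at the end of the diff), train_step passes only the criterion through call_async and lets every replica read its own self._sample. A rough single-process analogue of this submit-then-gather pattern, using only the standard library as a stand-in for fairseq's MultiprocessingEventLoop futures:

# Toy analogue of call_async(...) plus future gathering: one task per replica,
# results reduced in the parent. A thread pool stands in for the worker processes.
from concurrent.futures import ThreadPoolExecutor

NUM_REPLICAS = 4

def async_train_step(rank, criterion=None):
    # stand-in for _async_train_step: each replica would read its own cached sample
    return 0.25 * (rank + 1)        # pretend per-replica loss

with ThreadPoolExecutor(max_workers=NUM_REPLICAS) as pool:
    futures = [pool.submit(async_train_step, rank, criterion=None)
               for rank in range(NUM_REPLICAS)]
    losses = [f.result() for f in futures]
print('mean loss:', sum(losses) / len(losses))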
@@ -138,8 +137,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return loss, grad_norms[0]
-    def _async_train_step(self, rank, device_id, sample, criterion, data_event):
-        data_event.wait()
+    def _async_train_step(self, rank, device_id, criterion):
         self.model.train()
         # zero grads even if net_input is None, since we will all-reduce them
@@ -147,9 +145,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # calculate loss and grads
         loss = 0
-        if sample is not None:
-            net_output = self.model(**sample['net_input'])
-            loss_ = criterion(net_output, sample)
+        if self._sample is not None:
+            net_output = self.model(**self._sample['net_input'])
+            loss_ = criterion(net_output, self._sample)
             loss_.backward()
             loss = loss_.data[0]
@@ -191,14 +189,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def valid_step(self, samples, criterion):
         """Do forward pass in parallel."""
         # scatter sample across GPUs
-        samples, data_events = self._scatter_samples(samples, volatile=True)
+        self._scatter_samples(samples, volatile=True)
         criterion.prepare(samples)
         # forward pass
         losses = [
-            self.call_async(rank, '_async_valid_step', sample=samples[rank],
-                            criterion=criterion, data_event=event)
-            for rank, event in enumerate(data_events)
+            self.call_async(rank, '_async_valid_step', criterion=criterion)
+            for rank in range(self.num_replicas)
         ]
         # aggregate losses
@@ -206,14 +203,12 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return loss
-    def _async_valid_step(self, rank, device_id, sample, criterion, data_event):
-        if sample is None:
+    def _async_valid_step(self, rank, device_id, criterion):
+        if self._sample is None:
             return 0
-        data_event.wait()
         self.model.eval()
-        net_output = self.model(**sample['net_input'])
-        loss = criterion(net_output, sample)
+        net_output = self.model(**self._sample['net_input'])
+        loss = criterion(net_output, self._sample)
         return loss.data[0]
     def get_lr(self):
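Validation mirrors the training change: _async_valid_step reads the worker-local self._sample, which was prepared with volatile=True so no autograd graph is built. volatile and loss.data[0] are idioms of the PyTorch version this commit targets; a hypothetical modern equivalent of the same per-worker step would look like this:

# Hypothetical modern-PyTorch version of the per-worker validation step
# (illustrative only; the commit itself targets the older volatile-Variable API).
import torch

def async_valid_step(model, criterion, sample):
    if sample is None:
        return 0
    model.eval()
    with torch.no_grad():           # replaces volatile=True
        net_output = model(**sample['net_input'])
        loss = criterion(net_output, sample)
    return loss.item()              # replaces loss.data[0]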
@@ -241,20 +236,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _scatter_samples(self, samples, volatile=False):
         """Split and distribute a sample across GPUs."""
-        res = [utils.prepare_sample(sample, volatile=volatile,
-                                    cuda_device=device_id)
-               for sample, device_id in zip(samples, self.device_ids)]
         # Pad with None until its size is equal to the number of replicas.
-        res = res + [None]*(self.num_replicas - len(samples))
-        # Synchronize GPU devices after data is sent to prevent
-        # race conditions.
-        events = []
-        for d in self.device_ids:
-            with torch.cuda.device(d):
-                event = torch.cuda.Event(interprocess=True)
-                event.record()
-                events.append(event)
-        return res, events
+        samples = samples + [None]*(self.num_replicas - len(samples))
+        Future.gen_list([
+            self.call_async(rank, '_async_prepare_sample', sample=samples[rank], volatile=volatile)
+            for rank in range(self.num_replicas)
+        ])
+    def _async_prepare_sample(self, rank, device_id, sample, volatile):
+        if sample is None:
+            self._sample = None
+        else:
+            self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
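This last hunk is the core of the change. Previously the parent built CUDA tensors for every replica itself and recorded interprocess torch.cuda.Event objects so workers would not read the data too early; under exclusive mode the parent cannot even open those device contexts. The new version only pads the sample list with None and asks each worker, via call_async, to run _async_prepare_sample on its own device and stash the result as self._sample. A schematic comparison of the two flows (helper names are illustrative, not fairseq's API):

# Old flow: the parent prepares every replica's sample itself, which requires a
# CUDA context on each device and therefore fails once the GPUs run exclusively.
def scatter_in_parent(samples, device_ids, prepare_on_device):
    return [prepare_on_device(sample, device)
            for sample, device in zip(samples, device_ids)]

# New flow: the parent only pads the CPU-side list; each worker prepares its own
# sample on the device it already owns and keeps it for the next train/valid step.
def scatter_via_workers(samples, num_replicas, call_async_prepare):
    samples = samples + [None] * (num_replicas - len(samples))
    futures = [call_async_prepare(rank, samples[rank]) for rank in range(num_replicas)]
    return futures  # the caller waits on these before running the forward pass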