Commit 6c006a34 authored by Halil Akin, committed by Facebook Github Bot

Take a dummy train step under OOM to keep multiprocessing in sync

Summary: This is not a guaranteed solution (processes may still get out of sync if an OOM happens after an all_gather/all_reduce has already been done), but it should still make multiprocessing training more robust in practice, since we usually seem to OOM early enough.

Reviewed By: myleott

Differential Revision: D13086018

fbshipit-source-id: feb1b01c2eb8818797cfdabc0faac8056ba1b4ee
parent ccd22212
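
For illustration, here is a minimal sketch of the pattern this change relies on (not the fairseq code itself): a worker that hits a CUDA OOM redoes its forward/backward on a tiny fallback batch, so every worker still performs the same number of backward passes and therefore the same number of c10d gradient all-reduces. The names model, criterion, and make_tiny_batch are hypothetical placeholders.

    import torch

    def train_step_with_oom_fallback(model, criterion, batch, make_tiny_batch):
        """Sketch only: one forward/backward; on CUDA OOM, retry with a tiny batch.

        Under c10d DistributedDataParallel, gradients are all-reduced inside
        backward(), so a rank that silently skips its backward pass leaves the
        other ranks waiting on a collective that never completes.
        """
        try:
            loss = criterion(model(batch['input']), batch['target'])
            loss.backward()  # gradient all-reduce fires during backward()
            return 0         # no OOM on this rank
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            print('| WARNING: ran out of memory, retrying with a tiny batch')
            model.zero_grad()            # drop any partial gradients
            torch.cuda.empty_cache()     # release cached blocks before retrying
            tiny = make_tiny_batch()     # e.g. a single short sentence
            loss = criterion(model(tiny['input']), tiny['target'])
            loss.backward()  # keeps this rank's all-reduce count in sync with peers
            return 1         # report the OOM so the caller can discard the update

The actual commit implements this at a higher level: the Trainer counts OOMs during train_step and then replays that many extra steps on a one-sentence oom_batch (see handle_ooms in the diff below).
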
@@ -29,7 +29,7 @@ class Trainer(object):
     communication of the gradients across workers.
     """
-    def __init__(self, args, task, model, criterion, dummy_batch):
+    def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
         if not torch.cuda.is_available():
             raise NotImplementedError('Training on CPU is not supported')
@@ -45,6 +45,7 @@ class Trainer(object):
         self._model = model.cuda()
         self._dummy_batch = dummy_batch
+        self._oom_batch = oom_batch
         self._num_updates = 0
         self._optim_history = None
         self._optimizer = None
@@ -198,6 +199,9 @@ class Trainer(object):
             else:
                 raise e
 
+        if ooms > 0 and self._oom_batch is not None:
+            self.handle_ooms(ooms)
+
         if dummy_batch:
             return None
@@ -331,6 +335,15 @@ class Trainer(object):
             self.train_step(dummy_batch, dummy_batch=True)
             self.zero_grad()
 
+    def handle_ooms(self, number_of_ooms):
+        """
+        c10d accumulates/syncs gradients between gpus during backward pass.
+        In case of OOMs, gpus may fail to sync, so we manually iterate
+        extra to make sure each gpu makes same number of iterations.
+        """
+        for _ in range(number_of_ooms):
+            self.train_step([self._oom_batch], True)
+
     def zero_grad(self):
         self.optimizer.zero_grad()
@@ -51,9 +51,10 @@ def main(args):
         model.max_positions(),
     )
     dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
+    oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)
 
     # Build trainer
-    trainer = Trainer(args, task, model, criterion, dummy_batch)
+    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
     print('| training on {} GPUs'.format(args.distributed_world_size))
     print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
         args.max_tokens,