Commit 6c006a34 authored by Halil Akin, committed by Facebook Github Bot

Take a dummy train step under OOM to keep multiprocessing in sync

Summary: This is not a guaranteed solution (processes can still get out of sync if the OOM happens after an all_gather/all_reduce has already been performed), but it should still make multiprocessing training more robust in practice, since we usually seem to hit OOM early enough in the step.

Reviewed By: myleott

Differential Revision: D13086018

fbshipit-source-id: feb1b01c2eb8818797cfdabc0faac8056ba1b4ee
parent ccd22212
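
For intuition, here is a rough sketch of the pattern this change relies on. It is not the fairseq code shown in the diff below; the names run_step, compute_loss, batch, and tiny_batch are made up for illustration, and the model is assumed to be wrapped in torch.nn.parallel.DistributedDataParallel so that gradients are all-reduced across ranks during backward():

# Hypothetical sketch, not the fairseq implementation: run_step, compute_loss,
# batch and tiny_batch are illustrative names. Assumes `model` is wrapped in
# torch.nn.parallel.DistributedDataParallel, so each backward() participates
# in a collective gradient all-reduce across ranks.
import torch


def run_step(model, compute_loss, optimizer, batch, tiny_batch):
    try:
        loss = compute_loss(model, batch)
        loss.backward()  # all ranks must reach this all-reduce together
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise
        # This rank skipped (part of) its backward pass, so its peers may be
        # blocked waiting for it in the gradient all-reduce. Free memory and
        # run a zero-weighted backward on a tiny batch to perform the missing
        # sync. (Not bulletproof: if the OOM hit after some gradient buckets
        # were already reduced, ranks can still end up out of step.)
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        dummy_loss = compute_loss(model, tiny_batch)
        (0.0 * dummy_loss).backward()  # fires the all-reduce with zero grads
    optimizer.step()
    optimizer.zero_grad()

As the summary notes, this only helps when the OOM strikes before the collective sync has started. The actual change below routes the recovery through Trainer.train_step instead, using a minimal oom_batch built with get_dummy_batch(1, max_positions).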
@@ -29,7 +29,7 @@ class Trainer(object):
         communication of the gradients across workers.
     """
 
-    def __init__(self, args, task, model, criterion, dummy_batch):
+    def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
         if not torch.cuda.is_available():
             raise NotImplementedError('Training on CPU is not supported')
@@ -45,6 +45,7 @@
         self._model = model.cuda()
         self._dummy_batch = dummy_batch
+        self._oom_batch = oom_batch
         self._num_updates = 0
         self._optim_history = None
         self._optimizer = None
@@ -198,6 +199,9 @@
                 else:
                     raise e
 
+        if ooms > 0 and self._oom_batch is not None:
+            self.handle_ooms(ooms)
+
         if dummy_batch:
             return None
@@ -331,6 +335,15 @@
         self.train_step(dummy_batch, dummy_batch=True)
         self.zero_grad()
 
+    def handle_ooms(self, number_of_ooms):
+        """
+        c10d accumulates/syncs gradients between gpus during backward pass.
+        In case of OOMs, gpus may fail to sync, so we manually iterate
+        extra to make sure each gpu makes same number of iterations.
+        """
+        for _ in range(number_of_ooms):
+            self.train_step([self._oom_batch], True)
+
     def zero_grad(self):
         self.optimizer.zero_grad()
@@ -51,9 +51,10 @@ def main(args):
         model.max_positions(),
     )
     dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
+    oom_batch = task.dataset('train').get_dummy_batch(1, max_positions)
 
     # Build trainer
-    trainer = Trainer(args, task, model, criterion, dummy_batch)
+    trainer = Trainer(args, task, model, criterion, dummy_batch, oom_batch)
     print('| training on {} GPUs'.format(args.distributed_world_size))
     print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
         args.max_tokens,