Commit a2901f98 authored by Yongqiang Wang, committed by Facebook Github Bot

Add an option to raise an exception if OOM happens during fairseq.trainer.train_step (#2)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairspeq/pull/2

Pull Request resolved: https://github.com/pytorch/fairseq/pull/689

We found that not raising OOM errors during trainer.train_step causes various
issues, including NCCL hangs / gloo sync errors, because gradients are not
synced properly. Until we find the root cause, let's give users an option to
raise OOMs.
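For illustration, a minimal caller sketch (not part of this change; `trainer` and `samples` are placeholders for an initialized fairseq Trainer and a list of batches): with raise_oom=True, the OOM surfaces as a ValueError instead of the batch being silently skipped, so the caller can stop all workers at the same point.

import sys

# Hypothetical caller; `trainer` and `samples` are placeholders,
# not part of this commit.
try:
    trainer.train_step(samples, raise_oom=True)
except ValueError as e:
    # The OOM is re-raised here rather than skipped, so every worker can
    # abort together instead of desyncing in NCCL/gloo collectives.
    print('| OOM during train_step, aborting: {}'.format(e), file=sys.stderr)
    raise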

Reviewed By: jmp84

Differential Revision: D15170357

fbshipit-source-id: 3e15e4e111a8380612157955509c39821a216ec4
parent f5fbcaaf
@@ -13,6 +13,7 @@
 from collections import OrderedDict
 from itertools import chain
 import math
 import os
+import sys

 import torch
@@ -174,7 +175,7 @@ class Trainer(object):
         return extra_state

-    def train_step(self, samples, dummy_batch=False):
+    def train_step(self, samples, dummy_batch=False, raise_oom=False):
         """Do forward, backward and parameter update."""
         self._set_seed()
         self.model.train()
@@ -219,7 +220,18 @@
                     sample_sizes.append(sample_size)
             except RuntimeError as e:
                 if 'out of memory' in str(e):
-                    print(('| WARNING: ran out of memory with exception: {};\n Skipping batch').format(str(e)))
+                    msg = (
+                        '| WARNING: ran out of memory with exception: '
+                        + '{};'.format(e)
+                        + '\n Skipping batch'
+                    )
+                    # TODO: print should really go to logger; this print goes
+                    # to stdout, which is buffered and in many cases is not
+                    # flushed if another exception happens
+                    # print(msg)
+                    print(msg, file=sys.stderr)
+                    if raise_oom:
+                        raise ValueError(msg)
                     ooms += 1
                     self.zero_grad()
                 else:
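For reference, the skip-or-raise pattern implemented above, restated as a self-contained sketch; `run_step`, `compute_fn`, and `fake_oom` are illustrative names, not fairseq APIs.

import sys


def run_step(compute_fn, raise_oom=False):
    """Run one step, skipping the batch on OOM unless raise_oom is set.

    `compute_fn` stands in for the per-sample forward/backward work.
    """
    try:
        return compute_fn()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            msg = ('| WARNING: ran out of memory with exception: {};'
                   '\n Skipping batch').format(e)
            # write to stderr so the warning is not lost in stdout buffering
            # if another exception follows
            print(msg, file=sys.stderr)
            if raise_oom:
                raise ValueError(msg)
            return None  # caller treats None as a skipped batch
        raise


def fake_oom():
    # simulate the RuntimeError PyTorch raises when CUDA memory runs out
    raise RuntimeError('CUDA out of memory')


run_step(fake_oom)                    # prints the warning, returns None
# run_step(fake_oom, raise_oom=True)  # would raise ValueError instead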