Commit bafeed46 authored by Jerry Ma, committed by Facebook Github Bot

log more OOM sites (#893)

Summary:
- Adds memory summary logging to validation and optimization steps.
- Clarifies in the logging that optimization OOMs are not recoverable (the sketch before the diff walks through this split).
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/893

Differential Revision: D18110763

Pulled By: jma127

fbshipit-source-id: 49340e611169c606ab9c991265167a79f51846e6
parent e23e5eaa
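
For orientation before the diff: the summary distinguishes OOMs the trainer can recover from (forward/backward and validation) from OOMs during optimization, which are now logged and re-raised. Below is a minimal, self-contained sketch of that split in a bare-bones PyTorch training step. It is not fairseq code; `compute_loss`, `model`, `optimizer`, and `batch` are hypothetical placeholders, and only `log_oom` mirrors the `_log_oom` helper added in this commit.

import sys
import torch


def log_oom(exc):
    # Same pattern as the Trainer._log_oom helper added in the diff below:
    # report the exception, then dump a per-device CUDA allocator summary.
    print('| OOM: Ran out of memory with exception: {}'.format(exc), file=sys.stderr)
    if torch.cuda.is_available() and hasattr(torch.cuda, 'memory_summary'):
        for device_idx in range(torch.cuda.device_count()):
            print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
    sys.stderr.flush()


def train_one_batch(model, optimizer, batch, compute_loss):
    """Hypothetical training step showing recoverable vs. irrecoverable OOMs."""
    try:
        loss = compute_loss(model, batch)
        loss.backward()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            # Forward/backward OOM: log it, drop this batch, keep training.
            log_oom(e)
            print('| WARNING: attempting to recover from OOM in forward/backward pass',
                  file=sys.stderr)
            optimizer.zero_grad()
            return None
        raise

    try:
        optimizer.step()
    except RuntimeError as e:
        if 'out of memory' in str(e):
            # OOM during optimization: optimizer/parameter state may be
            # inconsistent, so log and re-raise instead of trying to continue.
            log_oom(e)
            print('| ERROR: OOM during optimization, irrecoverable', file=sys.stderr)
        raise

    optimizer.zero_grad()
    return loss.item()
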
@@ -318,24 +318,11 @@ class Trainer(object):
                     self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
             except RuntimeError as e:
                 if 'out of memory' in str(e):
-                    msg = (
-                        '| WARNING: ran out of memory with exception: '
-                        + '{};'.format(e)
-                        + '\n Skipping batch'
-                    )
-                    # TODO: print should really go to logger, this print goes
-                    # to stderr, which is buffered, which in many cases is not
-                    # printed out if another exception happens.
-                    # NB(jerry): added a flush to mitigate this
-                    print(msg, file=sys.stderr)
-                    if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
-                        for device_idx in range(torch.cuda.device_count()):
-                            print(torch.cuda.memory_summary(device=device_idx),
-                                  file=sys.stderr)
-                    sys.stderr.flush()
+                    self._log_oom(e)
                     if raise_oom:
-                        raise ValueError(msg)
+                        raise e
+                    print("| WARNING: attempting to recover from OOM in forward/backward pass",
+                          file=sys.stderr)
                     ooms += 1
                     self.zero_grad()
                 else:
@@ -455,6 +442,11 @@ class Trainer(object):
             print('| WARNING: overflow detected, ' + str(e))
             self.zero_grad()
             logging_output = None
+        except RuntimeError as e:
+            if 'out of memory' in str(e):
+                self._log_oom(e)
+                print('| ERROR: OOM during optimization, irrecoverable')
+                raise e
 
         if self.args.fp16:
             self.meters['loss_scale'].reset()
@@ -483,15 +475,16 @@ class Trainer(object):
                     sample, self.model, self.criterion
                 )
             except RuntimeError as e:
-                if 'out of memory' in str(e) and not raise_oom:
-                    print('| WARNING: ran out of memory, retrying batch')
-                    for p in self.model.parameters():
-                        if p.grad is not None:
-                            p.grad = None  # free some memory
-                    if self.cuda:
-                        torch.cuda.empty_cache()
-                    return self.valid_step(sample, raise_oom=True)
-                else:
-                    raise e
+                if 'out of memory' in str(e):
+                    self._log_oom(e)
+                    if not raise_oom:
+                        print('| WARNING: ran out of memory in validation step, retrying batch')
+                        for p in self.model.parameters():
+                            if p.grad is not None:
+                                p.grad = None  # free some memory
+                        if self.cuda:
+                            torch.cuda.empty_cache()
+                        return self.valid_step(sample, raise_oom=True)
+                raise e
 
             if ignore_results:
@@ -621,3 +614,16 @@ class Trainer(object):
                 )
             )
         )
+
+    def _log_oom(self, exc):
+        msg = '| OOM: Ran out of memory with exception: {}'.format(exc)
+        # TODO: print should really go to logger, this print goes
+        # to stderr, which is buffered, which in many cases is not
+        # printed out if another exception happens.
+        # NB(jerry): added a flush to mitigate this
+        print(msg, file=sys.stderr)
+        if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
+            for device_idx in range(torch.cuda.device_count()):
+                print(torch.cuda.memory_summary(device=device_idx),
+                      file=sys.stderr)
+        sys.stderr.flush()
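
One note on the new helper: `torch.cuda.memory_summary` only exists in newer PyTorch builds, which is why `_log_oom` guards the call with `hasattr` instead of calling it unconditionally. A minimal, illustrative version of that guarded dump (standalone, not tied to the Trainer class) looks like this:

import sys
import torch

# Illustrative only: print allocator statistics for every visible GPU, falling
# back gracefully when this PyTorch build has no torch.cuda.memory_summary.
if torch.cuda.is_available() and hasattr(torch.cuda, 'memory_summary'):
    for device_idx in range(torch.cuda.device_count()):
        print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
else:
    print('| memory_summary not available in this PyTorch build; skipping dump',
          file=sys.stderr)
sys.stderr.flush()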