Commit 63b6b3f4 authored by Jerry Ma, committed by Facebook Github Bot

Add printing of PyTorch memory summary on OOM (#885)

Summary:
PyTorch now has more comprehensive memory instrumentation, added in https://github.com/pytorch/pytorch/pull/27361 . This PR makes fairseq print a summary table of the memory state when an OOM occurs.
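
For reference, below is a minimal, self-contained sketch of the same guarded pattern outside of fairseq. The `dump_cuda_memory_stats` helper and the deliberately oversized allocation are illustrative only, not part of this change, and the script assumes a CUDA-capable PyTorch build:

```python
import sys

import torch


def dump_cuda_memory_stats():
    """Print a per-device summary of the CUDA caching allocator to stderr.

    The hasattr() guard keeps this a no-op on PyTorch builds that predate
    torch.cuda.memory_summary() (added in pytorch/pytorch#27361).
    """
    if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
        for device_idx in range(torch.cuda.device_count()):
            print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
        sys.stderr.flush()


if __name__ == "__main__":
    try:
        # Deliberately oversized allocation to provoke a CUDA OOM (illustrative only).
        x = torch.empty(1 << 40, device="cuda")
    except RuntimeError as e:
        if "out of memory" in str(e):
            print("| WARNING: ran out of memory: {}".format(e), file=sys.stderr)
            dump_cuda_memory_stats()
        else:
            raise
```

Keeping the feature behind the `hasattr` check lets the handler degrade gracefully on older PyTorch releases that lack the instrumentation.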
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/885

Differential Revision: D17820445

Pulled By: jma127

fbshipit-source-id: 1887417c7648d703f78e1cff9f2a5b89901f49d0
parent 34e79c58
@@ -313,10 +313,16 @@ class Trainer(object):
                     + '\n Skipping batch'
                 )
                 # TODO: print should really go to logger, this print goes
-                # to stdout, which is buffered, which in many case is not
-                # printed out if another exception happens
-                # print(msg)
+                # to stderr, which is buffered, which in many cases is not
+                # printed out if another exception happens.
+                # NB(jerry): added a flush to mitigate this
                 print(msg, file=sys.stderr)
+                if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
+                    for device_idx in range(torch.cuda.device_count()):
+                        print(torch.cuda.memory_summary(device=torch.cuda.device(device_idx)),
+                              file=sys.stderr)
+                sys.stderr.flush()
                 if raise_oom:
                     raise ValueError(msg)
                 ooms += 1