Commit 315c463d authored by Jerry Ma, committed by Facebook GitHub Bot

Add periodic CUDA cache cleanup (#882)

Summary:
This adds a periodic call to `torch.cuda.empty_cache()` in order to
mitigate memory fragmentation in the PyTorch CUDA caching allocator,
which can cause OOMs on models approaching the GPU memory limit.
By default, this will occur every 64 updates.

Performance considerations:

- I've benchmarked this on a reasonably large model with a 16 GB memory
  footprint; the overhead with the default setting is <0.2%.
  With `--update-freq > 1`, the cost is mitigated even further.
- This behavior can be disabled by setting `--empty-cache-freq` to zero
  (see the sketch below for the exact trigger condition).
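
For reference, here is a minimal standalone sketch of the trigger condition
used in the trainer change below. The helper name `maybe_empty_cache` and its
argument names are illustrative only and are not part of this patch.

import torch

def maybe_empty_cache(num_updates, empty_cache_freq, cpu=False):
    """Clear the CUDA cache once every `empty_cache_freq` updates.

    Mirrors the condition in the patch: with freq F > 0, the check
    (n + F - 1) % F == 0 fires on updates 1, F + 1, 2*F + 1, ...,
    i.e. once every F updates. A freq of 0 disables it entirely.
    """
    if (empty_cache_freq > 0 and
            (num_updates + empty_cache_freq - 1) % empty_cache_freq == 0 and
            torch.cuda.is_available() and
            not cpu):
        torch.cuda.empty_cache()

Guarding on `torch.cuda.is_available()` and the CPU flag keeps the check a
no-op for CPU-only runs, so the only cost on GPU runs is the occasional
allocator flush itself.
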
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/882

Differential Revision: D17742386

Pulled By: jma127

fbshipit-source-id: 68d8f93f798d6818b5efc3d67d43b52dfb8b2865
parent de348d1f
@@ -193,6 +193,8 @@ def get_parser(desc, default_task='translation'):
                         help='threshold FP16 loss scale from below')
     parser.add_argument('--user-dir', default=None,
                         help='path to a python module containing custom extensions (tasks and/or architectures)')
+    parser.add_argument('--empty-cache-freq', default=0, type=int,
+                        help='how often to clear the PyTorch CUDA cache (0 to disable)')
 
     from fairseq.registry import REGISTRIES
     for registry_name, REGISTRY in REGISTRIES.items():
@@ -426,6 +426,14 @@ class Trainer(object):
             if 'nll_loss' in logging_output:
                 self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
+
+            # clear CUDA cache to reduce memory fragmentation
+            if (self.args.empty_cache_freq > 0 and
+                    ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
+                    self.args.empty_cache_freq) == 0 and
+                    torch.cuda.is_available() and
+                    not self.args.cpu):
+                torch.cuda.empty_cache()
         except OverflowError as e:
             print('| WARNING: overflow detected, ' + str(e))
             self.zero_grad()