Commit bd315c35 authored by Jared Casper

Merge branch 'lmcafee/empty-cache' into 'main'

Flag to call empty_cache() each iteration, to reduce fragmentation

See merge request ADLR/megatron-lm!306
parents 5ca20cdd 52b2296b
megatron/arguments.py
@@ -601,6 +601,11 @@ def _add_distributed_args(parser):
     group.add_argument('--use-cpu-initialization', action='store_true',
                        default=None, help='If set, affine parallel weights '
                        'initialization uses CPU' )
+    group.add_argument('--empty-unused-memory-level', default=0, type=int,
+                       choices=[0, 1, 2],
+                       help='Call torch.cuda.empty_cache() each iteration '
+                       '(training and eval), to reduce fragmentation. '
+                       '0=off, 1=moderate, 2=aggressive.')
     return parser
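The three levels gate where torch.cuda.empty_cache() gets called: level 0 never calls it, level 1 calls it after each training and eval step, and level 2 additionally calls it after the optimizer step. A minimal standalone sketch of the same gating, outside Megatron-LM (the maybe_empty_cache helper is hypothetical, for illustration only):

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--empty-unused-memory-level', default=0, type=int,
                    choices=[0, 1, 2])
args = parser.parse_args()

def maybe_empty_cache(required_level):
    # empty_cache() returns cached, unoccupied blocks to the CUDA driver;
    # it never frees tensors that are still referenced, so it is safe to
    # call anywhere, at the cost of re-allocating from the driver later.
    if args.empty_unused_memory_level >= required_level:
        torch.cuda.empty_cache()

# level >= 1: after each training/eval step
maybe_empty_cache(1)
# level >= 2: also after the optimizer step
maybe_empty_cache(2)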
megatron/training.py
@@ -362,6 +362,10 @@ def train_step(forward_step_func, data_iterator,
         forward_step_func, data_iterator, model,
         optimizer, timers, forward_only=False)
 
+    # Empty unused memory
+    if args.empty_unused_memory_level >= 1:
+        torch.cuda.empty_cache()
+
     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('backward-params-all-reduce').start()
@@ -408,6 +412,10 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1
 
+    # Empty unused memory
+    if args.empty_unused_memory_level >= 2:
+        torch.cuda.empty_cache()
+
     if mpu.is_pipeline_last_stage(ignore_virtual=True):
         # Average loss across microbatches.
         loss_reduced = {}
@@ -716,6 +724,10 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 forward_step_func, data_iterator, model, optimizer=None,
                 timers=None, forward_only=True)
 
+            # Empty unused memory
+            if args.empty_unused_memory_level >= 1:
+                torch.cuda.empty_cache()
+
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 # Reduce across processes.
                 for loss_dict in loss_dicts:
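Why this reduces fragmentation: PyTorch's caching allocator keeps freed blocks reserved for reuse, and over many iterations with varying tensor shapes those cached blocks can fragment until a large allocation fails even though enough total memory is nominally free. A quick standalone demonstration (requires a CUDA device; sizes are arbitrary):

import torch

# Allocate then free ~1 GiB of tensors (4 x 256 MiB of float32).
blocks = [torch.empty(64 * 1024 * 1024, device='cuda') for _ in range(4)]
del blocks

# Allocated memory drops to ~0, but the blocks stay cached in the allocator.
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

# empty_cache() hands the cached, unoccupied blocks back to the driver.
torch.cuda.empty_cache()
print(torch.cuda.memory_reserved())

The trade-off is that empty_cache() is not free: subsequent allocations have to go back to the driver, which is slower than reusing cached blocks. That is why the flag defaults to 0 and the per-optimizer-step call is gated behind the aggressive level 2.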