Commit 91384a5a authored by Jared Casper

Merge branch 'fix_grad_scalar_warning' into 'main'

Fixed the grad scaler warning so it only prints for fp16

See merge request ADLR/megatron-lm!441
parents d2073912 27bc1338
@@ -679,7 +679,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         self.optimizer.load_state_dict(state_dict[optimizer_key])
 
         # Grad scaler.
-        if 'grad_scaler' not in state_dict:
+        if self.fp16 and 'grad_scaler' not in state_dict:
             print_rank_0('***WARNING*** found an old checkpoint, will not '
                          'load grad scaler ...')
         else:
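The guard matters because only fp16 training uses a grad scaler; bf16 runs train without one, so their checkpoints legitimately lack a 'grad_scaler' entry and the warning was spurious. Below is a minimal, self-contained sketch of the patched logic. The function name and the plain `print` are illustrative stand-ins, not the actual Megatron-LM method or its `print_rank_0` helper:

```python
# Minimal sketch of the patched checkpoint-loading logic; names are
# illustrative, not the real Megatron-LM API.
def maybe_load_grad_scaler(state_dict, fp16, grad_scaler=None):
    if fp16 and 'grad_scaler' not in state_dict:
        # Only fp16 training uses a grad scaler, so only fp16 should
        # warn when an old checkpoint is missing its state.
        print('***WARNING*** found an old checkpoint, will not '
              'load grad scaler ...')
    elif grad_scaler is not None and 'grad_scaler' in state_dict:
        grad_scaler.load_state_dict(state_dict['grad_scaler'])
```

With `fp16=False` (e.g. bf16 training) and a checkpoint that has no 'grad_scaler' key, the load now stays silent instead of printing the warning.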