"vscode:/vscode.git/clone" did not exist on "5b972fbd6a6c50cf1afdf1ba34c34d84fc67861c"
Commit d63c2541 authored by Jared Casper

Merge branch 'fix_grad_scalar_warning' into 'main'

fixed grad scalar warning for bf16

See merge request ADLR/megatron-lm!442
parents 91384a5a aaa5715a
@@ -449,6 +449,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Grad scaler.
         if 'grad_scaler' not in state_dict:
-            print_rank_0('***WARNING*** found an old checkpoint, will not '
-                         'load grad scaler ...')
+            if self.fp16:
+                print_rank_0('***WARNING*** found an old checkpoint, will not '
+                             'load grad scaler ...')
         else:
...
@@ -679,7 +679,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         self.optimizer.load_state_dict(state_dict[optimizer_key])
         # Grad scaler.
-        if self.fp16 and 'grad_scaler' not in state_dict:
-            print_rank_0('***WARNING*** found an old checkpoint, will not '
-                         'load grad scaler ...')
+        if 'grad_scaler' not in state_dict:
+            if self.fp16:
+                print_rank_0('***WARNING*** found an old checkpoint, will not '
+                             'load grad scaler ...')
         else:
...
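The net effect of both hunks is the same: the "found an old checkpoint" warning is only emitted when a grad scaler is actually expected, i.e. under fp16 loss scaling. A bf16 run has no grad scaler, so a checkpoint without a 'grad_scaler' entry is normal and should load silently. Below is a minimal standalone sketch of that behaviour; the helper name and the plain print() are illustrative stand-ins, not the Megatron-LM load_state_dict itself.

    # Hypothetical helper sketching the fixed branch: warn only when fp16.
    def maybe_load_grad_scaler(state_dict, fp16):
        if 'grad_scaler' not in state_dict:
            # bf16 training uses no loss scaling, so a missing 'grad_scaler'
            # entry is expected and should not trigger a warning.
            if fp16:
                print('***WARNING*** found an old checkpoint, will not '
                      'load grad scaler ...')
            return None
        return state_dict['grad_scaler']

    # Example: a bf16 checkpoint without a grad scaler now loads silently.
    assert maybe_load_grad_scaler({}, fp16=False) is None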