Commit aaa5715a authored by Mohammad Shoeybi's avatar Mohammad Shoeybi Committed by Jared Casper
Browse files

fixed grad scaler warning for bf16

parent 91384a5a
...@@ -449,8 +449,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer): ...@@ -449,8 +449,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Grad scaler. # Grad scaler.
if 'grad_scaler' not in state_dict: if 'grad_scaler' not in state_dict:
print_rank_0('***WARNING*** found an old checkpoint, will not ' if self.fp16:
'load grad scaler ...') print_rank_0('***WARNING*** found an old checkpoint, will not '
'load grad scaler ...')
else: else:
if self.grad_scaler: if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict['grad_scaler']) self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
......
...@@ -679,9 +679,10 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer): ...@@ -679,9 +679,10 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
self.optimizer.load_state_dict(state_dict[optimizer_key]) self.optimizer.load_state_dict(state_dict[optimizer_key])
# Grad scaler. # Grad scaler.
if self.fp16 and 'grad_scaler' not in state_dict: if 'grad_scaler' not in state_dict:
print_rank_0('***WARNING*** found an old checkpoint, will not ' if self.fp16:
'load grad scaler ...') print_rank_0('***WARNING*** found an old checkpoint, will not '
'load grad scaler ...')
else: else:
if self.grad_scaler: if self.grad_scaler:
self.grad_scaler.load_state_dict(state_dict['grad_scaler']) self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment