Unverified Commit ce470ba3 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[checkpoint] sharded optim save/load grad scaler (#1350)

parent 05fae1fd
......@@ -363,7 +363,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self.master_params[p].trans_state(TensorState.HOLD)
def state_dict(self):
    """Return the optimizer state dict, augmented with the gradient scaler state.

    The scaler state is stored under the ``'scaler'`` key so that
    :meth:`load_state_dict` can restore it alongside the optimizer state.
    """
    state = super().state_dict()
    state['scaler'] = self.grad_scaler.state_dict()
    return state
def load_state_dict(self, *args, **kwargs):
    """Load the optimizer state dict, restoring the grad scaler state if present.

    The state dict produced by :meth:`state_dict` carries the scaler state
    under the ``'scaler'`` key; it is popped here before delegating to the
    parent class, which does not know about that key. A warning is logged
    (on rank 0 only) when the key is missing so loading older checkpoints
    stays best-effort rather than failing.
    """
    # BUGFIX: the original read `args[0]` unconditionally, raising IndexError
    # when the state dict was passed by keyword (load_state_dict(state_dict=sd)).
    if args:
        state_dict = args[0]
    else:
        state_dict = kwargs.get('state_dict')
    if state_dict is None or 'scaler' not in state_dict:
        self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
    else:
        # pop() so the parent class never sees the extra 'scaler' key;
        # NOTE(review): this mutates the caller's dict in place, as before.
        self.grad_scaler.load_state_dict(state_dict.pop('scaler'))
    super().load_state_dict(*args, **kwargs)
for group in self.optim.param_groups:
for p in group['params']:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment