Commit 97ba5c0e authored by mohammad

load and save state dicts added

parent 0888a3e1
@@ -25,7 +25,6 @@ def _get_params_for_weight_decay_optimization(module):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and biases will have no weight decay but the rest will.
     """
     args = get_args()
     LayerNorm = import_layernorm(args.fp32_residual_connection)
@@ -48,7 +47,6 @@ def _get_params_for_weight_decay_optimization(module):
 def get_megatron_optimizer(model):
     args = get_args()
 
     # Base optimizer.
@@ -77,4 +75,4 @@ def get_megatron_optimizer(model):
                                            args.clip_grad)
 
     # FP32.
-    return FP32Optimizer(optimizer, model, args.clip_grad)
+    return FP32Optimizer(optimizer, args.clip_grad)
@@ -40,7 +40,6 @@ class MegatronGradScaler(ABC):
     def update(self, found_inf):
         pass
 
-    '''
     @abstractmethod
     def state_dict(self):
         pass
@@ -48,7 +47,7 @@ class MegatronGradScaler(ABC):
     @abstractmethod
     def load_state_dict(self, state_dict):
         pass
-    '''
+
 
 class ConstantGradScaler(MegatronGradScaler):
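With the surrounding ''' markers removed, state_dict and load_state_dict join update as abstract methods on MegatronGradScaler, so every concrete scaler now has to implement them. Below is a minimal sketch of what a custom subclass would need to provide; the FixedGradScaler class is hypothetical, not part of this commit, and assumes the base class keeps its loss scale in self._scale on the current CUDA device, as the hunks in this diff suggest:

import torch
from megatron.optimizer.grad_scaler import MegatronGradScaler


class FixedGradScaler(MegatronGradScaler):
    """Hypothetical scaler that never changes the scale it was given."""

    def update(self, found_inf):
        # A fixed scale ignores overflow information.
        pass

    def state_dict(self):
        # Same pattern as the scalers in this commit: a plain dict holding
        # the scale tensor.
        return {'scale': self._scale}

    def load_state_dict(self, state_dict):
        # Move the restored scale onto the current device, mirroring the
        # DynamicGradScaler hunk below.
        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())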
@@ -56,6 +55,13 @@ class ConstantGradScaler(MegatronGradScaler):
     def update(self, found_inf):
         pass
 
+    def state_dict(self):
+        return dict()
+
+    def load_state_dict(self, state_dict):
+        pass
+
 
 class DynamicGradScaler(MegatronGradScaler):
@@ -111,3 +117,17 @@ class DynamicGradScaler(MegatronGradScaler):
                 self._hysteresis_tracker = self.hysteresis
                 # and scale up the loss scale.
                 self._scale = self._scale * self.growth_factor
+
+    def state_dict(self):
+        state_dict = {}
+        state_dict['scale'] = self._scale
+        state_dict['growth_tracker'] = self._growth_tracker
+        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
+        self._growth_tracker = state_dict['growth_tracker']
+        self._hysteresis_tracker = state_dict['hysteresis_tracker']
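A short round-trip sketch, not part of the commit, showing how the new DynamicGradScaler hooks could be used. The constructor arguments below are assumed values in the style of Megatron's usual defaults, and a CUDA device is required because load_state_dict moves the saved scale back onto the current GPU:

import torch
from megatron.optimizer.grad_scaler import DynamicGradScaler

# Constructor arguments are assumptions for illustration, not values taken
# from the commit.
scaler = DynamicGradScaler(initial_scale=2**32, min_scale=1.0,
                           growth_factor=2.0, backoff_factor=0.5,
                           growth_interval=1000, hysteresis=2)

# A few overflow-free updates so the growth tracker moves.
for _ in range(10):
    scaler.update(found_inf=False)

# state_dict() is a plain dict: the scale tensor plus two Python counters,
# so torch.save can serialize it as part of a larger checkpoint.
torch.save(scaler.state_dict(), 'grad_scaler.pt')

# load_state_dict() copies the counters back and puts the scale tensor on
# the current CUDA device.
scaler.load_state_dict(torch.load('grad_scaler.pt'))
print(scaler.state_dict())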
@@ -145,7 +145,6 @@ class MegatronOptimizer(ABC):
     def step(self):
         pass
 
-    '''
     @abstractmethod
     def state_dict(self):
         pass
@@ -153,7 +152,6 @@ class MegatronOptimizer(ABC):
     @abstractmethod
     def load_state_dict(self, state_dict):
         pass
-    '''
 
     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
@@ -180,7 +178,6 @@ class MegatronOptimizer(ABC):
 
 class FP16OptimizerWithFP16Params(MegatronOptimizer):
 
     def __init__(self, optimizer, grad_scaler, clip_grad):
         super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
@@ -369,12 +366,32 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return True
 
+    def state_dict(self):
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        # Defer to the class to load.
+        self.optimizer.load_state_dict(state_dict['optimizer'])
+        self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+        # Copy data for the master params.
+        for current_group, saved_group in zip(
+                self.fp32_from_fp16_groups,
+                state_dict['fp32_from_fp16_params']):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)
+
 
 class FP32Optimizer(MegatronOptimizer):
 
-    def __init__(self, optimizer, model, clip_grad):
+    def __init__(self, optimizer, clip_grad):
         super(FP32Optimizer, self).__init__(optimizer)
-        self.model = model
         self.clip_grad = clip_grad
         self._scale = torch.cuda.FloatTensor([1.0])
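These optimizer-level hooks are presumably consumed by Megatron's checkpointing code, which this diff does not show. A hedged sketch of how a caller might wire them up; the function names and checkpoint layout below are illustrative only, not taken from the commit:

import torch

def save_checkpoint_sketch(path, model, optimizer):
    # FP16OptimizerWithFP16Params.state_dict() nests the inner optimizer
    # state, the grad-scaler state, and the fp32 master copies of the fp16
    # parameter groups, so a single torch.save captures all of it.
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}, path)

def load_checkpoint_sketch(path, model, optimizer):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    # load_state_dict defers to the wrapped optimizer and grad scaler, then
    # copies the saved fp32 master params back into fp32_from_fp16_groups.
    optimizer.load_state_dict(checkpoint['optimizer'])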