Commit 97ba5c0e authored by mohammad

load and save state dicts added

parent 0888a3e1
...@@ -25,7 +25,6 @@ def _get_params_for_weight_decay_optimization(module):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """
    args = get_args()
    LayerNorm = import_layernorm(args.fp32_residual_connection)
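
For context, the grouping rule described in the docstring above can be sketched as below. This is not the Megatron helper itself: split_weight_decay_groups is a hypothetical name, torch.nn.LayerNorm stands in for whatever class import_layernorm returns in this codebase, and the default decay value is only a placeholder.

import torch

def split_weight_decay_groups(module, weight_decay=0.01):
    # Hypothetical sketch: biases and LayerNorm parameters get no weight
    # decay; every other parameter goes into the weight-decay group.
    decay, no_decay = [], []
    for submodule in module.modules():
        for name, param in submodule.named_parameters(recurse=False):
            if isinstance(submodule, torch.nn.LayerNorm) or name == 'bias':
                no_decay.append(param)
            else:
                decay.append(param)
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.0}]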
...@@ -48,7 +47,6 @@ def _get_params_for_weight_decay_optimization(module):

def get_megatron_optimizer(model):
    args = get_args()
    # Base optimizer.
...@@ -77,4 +75,4 @@ def get_megatron_optimizer(model):
                                           args.clip_grad)
    # FP32.
-    return FP32Optimizer(optimizer, model, args.clip_grad)
+    return FP32Optimizer(optimizer, args.clip_grad)
...@@ -40,7 +40,6 @@ class MegatronGradScaler(ABC):
    def update(self, found_inf):
        pass

-    '''
    @abstractmethod
    def state_dict(self):
        pass
...@@ -48,7 +47,7 @@ class MegatronGradScaler(ABC):
    @abstractmethod
    def load_state_dict(self, state_dict):
        pass
-    '''

class ConstantGradScaler(MegatronGradScaler):
...@@ -56,6 +55,13 @@ class ConstantGradScaler(MegatronGradScaler):
    def update(self, found_inf):
        pass

+    def state_dict(self):
+        return dict()
+
+    def load_state_dict(self, state_dict):
+        pass

class DynamicGradScaler(MegatronGradScaler):
...@@ -111,3 +117,17 @@ class DynamicGradScaler(MegatronGradScaler):
                self._hysteresis_tracker = self.hysteresis
                # and scale up the loss scale.
                self._scale = self._scale * self.growth_factor

+    def state_dict(self):
+        state_dict = {}
+        state_dict['scale'] = self._scale
+        state_dict['growth_tracker'] = self._growth_tracker
+        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
+        self._growth_tracker = state_dict['growth_tracker']
+        self._hysteresis_tracker = state_dict['hysteresis_tracker']
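
As a usage note rather than part of this commit: the new scaler methods are presumably meant to round-trip through a checkpoint dict. A minimal sketch, assuming an existing DynamicGradScaler instance named grad_scaler, a CUDA device, and a hypothetical checkpoint path:

import torch

# Save: the scaler state is a small dict (scale tensor plus two counters).
torch.save({'grad_scaler': grad_scaler.state_dict()}, 'checkpoint.pt')

# Restore on resume: load_state_dict() above moves the saved scale tensor
# back onto the current device via .cuda(torch.cuda.current_device()).
checkpoint = torch.load('checkpoint.pt')
grad_scaler.load_state_dict(checkpoint['grad_scaler'])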
...@@ -145,7 +145,6 @@ class MegatronOptimizer(ABC):
    def step(self):
        pass

-    '''
    @abstractmethod
    def state_dict(self):
        pass
...@@ -153,7 +152,6 @@ class MegatronOptimizer(ABC):
    @abstractmethod
    def load_state_dict(self, state_dict):
        pass
-    '''
    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
...@@ -180,7 +178,6 @@ class MegatronOptimizer(ABC):

class FP16OptimizerWithFP16Params(MegatronOptimizer):

    def __init__(self, optimizer, grad_scaler, clip_grad):
        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)
...@@ -369,12 +366,32 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
        return True

+    def state_dict(self):
+        state_dict = {}
+        state_dict['optimizer'] = self.optimizer.state_dict()
+        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        # Defer to the class to load.
+        self.optimizer.load_state_dict(state_dict['optimizer'])
+        self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+        # Copy data for the master params.
+        for current_group, saved_group in zip(
+                self.fp32_from_fp16_groups,
+                state_dict['fp32_from_fp16_params']):
+            for current_param, saved_param in zip(current_group, saved_group):
+                current_param.data.copy_(saved_param.data)


class FP32Optimizer(MegatronOptimizer):

-    def __init__(self, optimizer, model, clip_grad):
+    def __init__(self, optimizer, clip_grad):
        super(FP32Optimizer, self).__init__(optimizer)
-        self.model = model
        self.clip_grad = clip_grad
        self._scale = torch.cuda.FloatTensor([1.0])
...
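
For context, a sketch of how the new state_dict()/load_state_dict() pair on the optimizer wrappers might be wired into checkpointing; model, optimizer, and the file path here are assumptions for illustration, not part of this commit.

import torch

def save_checkpoint(model, optimizer, path='model_optim.pt'):
    # Persist model weights alongside the optimizer state introduced above.
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}, path)

def load_checkpoint(model, optimizer, path='model_optim.pt'):
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    # For FP16OptimizerWithFP16Params this restores the inner optimizer and
    # grad scaler, then copies the saved fp32 master weights in place.
    optimizer.load_state_dict(state['optimizer'])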