Commit 160ba680 authored by mohammad

added reload model params for finetuning

parent 43529f78
@@ -76,6 +76,10 @@ class MegatronOptimizer(ABC):
     def step(self):
         pass

+    @abstractmethod
+    def reload_model_params(self):
+        pass
+
     @abstractmethod
     def state_dict(self):
         pass
@@ -243,8 +247,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return found_inf_flag

-    def _copy_master_params_to_model_params(self):
-        # Only needed for the fp16 params.
+    def _get_model_and_master_params_data_fp16(self):
         model_data = []
         master_data = []
         for model_group, master_group in zip(self.fp16_groups,
@@ -252,6 +255,12 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             for model_param, master_param in zip(model_group, master_group):
                 model_data.append(model_param.data)
                 master_data.append(master_param.data)
+        return model_data, master_data
+
+    def _copy_master_params_to_model_params(self):
+        # Only needed for the fp16 params.
+        model_data, master_data = self._get_model_and_master_params_data_fp16()
         self._dummy_overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
         multi_tensor_applier(amp_C.multi_tensor_scale,
@@ -259,6 +268,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                              [master_data, model_data],
                              1.0)

+    def _copy_model_params_to_master_params(self):
+        # Only needed for the fp16 params.
+        model_data, master_data = self._get_model_and_master_params_data_fp16()
+        self._dummy_overflow_buf.fill_(0)
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             self._dummy_overflow_buf,
+                             [model_data, master_data],
+                             1.0)
+
+    def reload_model_params(self):
+        self._copy_model_params_to_master_params()
+
     @torch.no_grad()
     def step(self):
@@ -388,6 +411,10 @@ class FP32Optimizer(MegatronOptimizer):
         return True

+    def reload_model_params(self):
+        pass
+
     def state_dict(self):
         return self.optimizer.state_dict()
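For context: both copy helpers in the FP16 optimizer go through apex's fused multi_tensor_scale kernel, and scaling by 1.0 simply writes each tensor of the first list into the corresponding tensor of the second, i.e. a batched copy. Below is a minimal plain-PyTorch sketch of that copy, assuming the same nested group layout as the zip(self.fp16_groups, ...) loops above; the function and variable names are illustrative only, not part of this commit.

import torch

def copy_params(src_groups, dst_groups):
    # Eager equivalent of multi_tensor_applier(amp_C.multi_tensor_scale,
    # overflow_buf, [src, dst], 1.0): dst = src * 1.0 for every pair.
    for src_group, dst_group in zip(src_groups, dst_groups):
        for src_param, dst_param in zip(src_group, dst_group):
            dst_param.data.copy_(src_param.data)

# Toy fp16 model params and their fp32 master copies.
model_groups = [[torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))]]
master_groups = [[torch.nn.Parameter(torch.ones(4, dtype=torch.float32))]]

# _copy_master_params_to_model_params copies master -> model after a step;
# the new _copy_model_params_to_master_params (behind reload_model_params)
# copies model -> master after fresh weights are loaded into the model.
copy_params(model_groups, master_groups)
assert torch.all(master_groups[0][0].data == 0.0)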
@@ -256,8 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider,
         args.load = original_load
         # This is critical when only model is loaded. We should make sure
         # master parameters are also updated.
-        if args.fp16:
-            optimizer._model_params_to_master_params()
+        optimizer.reload_model_params()
         timers('pretrained checkpoint').stop()

     # Print setup timing.
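The call-site change above is the reason for the new abstract method: during finetuning, the pretrained checkpoint is loaded into the model after the optimizer has already been constructed, so the optimizer's own copies of the parameters must be refreshed, and with reload_model_params() on the base class the caller no longer has to special-case fp16. A toy sketch of the resulting polymorphism, using stand-in class bodies (the real implementations are in the diff above):

from abc import ABC, abstractmethod

class MegatronOptimizer(ABC):
    # Every optimizer must now say how to re-sync itself with the model.
    @abstractmethod
    def reload_model_params(self):
        pass

class FP16OptimizerWithFP16Params(MegatronOptimizer):
    def reload_model_params(self):
        # Real code copies fp16 model params into the fp32 master copies.
        print("refreshing fp32 master params from fp16 model params")

class FP32Optimizer(MegatronOptimizer):
    def reload_model_params(self):
        # Nothing to refresh: the optimizer steps the model params directly.
        pass

# The finetune path no longer needs `if args.fp16:` -- the same call works
# for either optimizer after a pretrained checkpoint is loaded into the model.
for optimizer in (FP16OptimizerWithFP16Params(), FP32Optimizer()):
    optimizer.reload_model_params()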