Commit 160ba680 authored by mohammad

added reload model params for finetuning

parent 43529f78
@@ -76,6 +76,10 @@ class MegatronOptimizer(ABC):
     def step(self):
         pass
 
+    @abstractmethod
+    def reload_model_params(self):
+        pass
+
     @abstractmethod
     def state_dict(self):
         pass
@@ -243,8 +247,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return found_inf_flag
 
-    def _copy_master_params_to_model_params(self):
-        # Only needed for the fp16 params.
+    def _get_model_and_master_params_data_fp16(self):
         model_data = []
         master_data = []
         for model_group, master_group in zip(self.fp16_groups,
@@ -252,6 +255,12 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             for model_param, master_param in zip(model_group, master_group):
                 model_data.append(model_param.data)
                 master_data.append(master_param.data)
+        return model_data, master_data
+
+    def _copy_master_params_to_model_params(self):
+        # Only needed for the fp16 params.
+        model_data, master_data = self._get_model_and_master_params_data_fp16()
         self._dummy_overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
         multi_tensor_applier(amp_C.multi_tensor_scale,
@@ -259,6 +268,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                              [master_data, model_data],
                              1.0)
 
+    def _copy_model_params_to_master_params(self):
+        # Only needed for the fp16 params.
+        model_data, master_data = self._get_model_and_master_params_data_fp16()
+        self._dummy_overflow_buf.fill_(0)
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             self._dummy_overflow_buf,
+                             [model_data, master_data],
+                             1.0)
+
+    def reload_model_params(self):
+        self._copy_model_params_to_master_params()
+
     @torch.no_grad()
     def step(self):
@@ -388,6 +411,10 @@ class FP32Optimizer(MegatronOptimizer):
         return True
 
+    def reload_model_params(self):
+        pass
+
     def state_dict(self):
         return self.optimizer.state_dict()
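The FP16 optimizer above steps on fp32 master copies of the fp16 model parameters, so weights loaded straight into the model (e.g. a pretrained checkpoint before finetuning) must be copied back into the masters, or the next step() would overwrite them. Below is a minimal sketch of that invariant in plain PyTorch, using a simple per-parameter copy in place of apex's multi_tensor_applier/amp_C.multi_tensor_scale; the names model, master_params, and reload_model_params here are illustrative stand-ins, not Megatron's classes.

import torch

# fp16 "model" parameters and their fp32 master copies (illustrative setup).
model = torch.nn.Linear(4, 4).half()
master_params = [p.detach().clone().float() for p in model.parameters()]

def reload_model_params():
    # Refresh the fp32 masters from the (possibly just-loaded) fp16 params.
    # A scale by factor 1.0, as in the diff above, has the same effect as this copy.
    with torch.no_grad():
        for model_p, master_p in zip(model.parameters(), master_params):
            master_p.copy_(model_p)

# After loading pretrained weights into the model, e.g.
#     model.load_state_dict(pretrained_state_dict)
# the masters must be brought back in sync before the next optimizer step:
reload_model_params()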
@@ -256,8 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider,
         args.load = original_load
 
         # This is critical when only model is loaded. We should make sure
         # master parameters are also updated.
-        if args.fp16:
-            optimizer._model_params_to_master_params()
+        optimizer.reload_model_params()
 
     timers('pretrained checkpoint').stop()
 
     # Print setup timing.
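The hunk above drops the fp16-only call optimizer._model_params_to_master_params() in favor of the new reload_model_params(), which both FP16OptimizerWithFP16Params and FP32Optimizer now implement, so the finetuning path no longer branches on precision. A rough sketch of that dispatch with hypothetical stand-in wrappers (not Megatron's actual classes):

import torch

class Fp32WrapperSketch:
    # fp32 training optimizes the model parameters directly.
    def reload_model_params(self):
        pass  # nothing to refresh

class Fp16WrapperSketch:
    def __init__(self, model_params):
        self.model_params = list(model_params)
        self.master_params = [p.detach().clone().float() for p in self.model_params]

    def reload_model_params(self):
        # Copy the (just-loaded) fp16 model params into the fp32 masters.
        with torch.no_grad():
            for model_p, master_p in zip(self.model_params, self.master_params):
                master_p.copy_(model_p)

def after_pretrained_checkpoint_load(optimizer):
    # One call covers both precisions; no `if args.fp16` branch needed.
    optimizer.reload_model_params()

# Either wrapper can be passed without the caller caring which one it is:
after_pretrained_checkpoint_load(Fp32WrapperSketch())
after_pretrained_checkpoint_load(Fp16WrapperSketch(torch.nn.Linear(4, 4).half().parameters()))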