Commit 7381754e authored by Jared Casper

Merge branch 'optimize_refactor_reload_model_params' into 'optimize_refactor'

Optimize refactor reload model params

See merge request ADLR/megatron-lm!199
parents 43529f78 9eedf896
@@ -45,6 +45,18 @@ def _zero_grad_group_helper(group, set_to_none):
                 param.grad.zero_()


+def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
+    """Use multi-tensor-applier to copy values from one list to another."""
+    if overflow_buf:
+        overflow_buf.fill_(0)
+    else:
+        overflow_buf = torch.cuda.IntTensor([0])
+    # Scaling with factor `1.0` is equivalent to copy.
+    multi_tensor_applier(amp_C.multi_tensor_scale,
+                         overflow_buf,
+                         [this, that],
+                         1.0)
+
+
 class MegatronOptimizer(ABC):
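
As the inline comment notes, scaling by 1.0 with amp_C.multi_tensor_scale is just an element-wise copy: the new helper copies every tensor in `this` into the matching tensor in `that` with a single fused kernel launch over the whole list. A minimal pure-PyTorch sketch of the same semantics, for readers without apex; the function name below is hypothetical and not part of this merge request:

import torch

def copy_this_to_that_reference(this, that):
    # Same observable result as the fused helper: copy each source tensor
    # into the matching destination tensor, element-wise.
    for src, dst in zip(this, that):
        dst.copy_(src)

src_list = [torch.ones(3), torch.full((2,), 2.0)]
dst_list = [torch.zeros(3), torch.zeros(2)]
copy_this_to_that_reference(src_list, dst_list)   # dst_list now mirrors src_list
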
@@ -76,6 +88,10 @@ class MegatronOptimizer(ABC):
     def step(self):
         pass

+    @abstractmethod
+    def reload_model_params(self):
+        pass
+
     @abstractmethod
     def state_dict(self):
         pass
@@ -123,7 +139,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])

         # ======================
-        # master parameter stuff
+        # main parameter stuff
         # ======================

         # Three groups of parameters:
@@ -147,20 +163,20 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
                     if param.type() == 'torch.cuda.HalfTensor':
                         fp16_params_this_group.append(param)
                         # Create a copy
-                        master_param = param.detach().clone().float()
+                        main_param = param.detach().clone().float()
                         # Store grads
-                        master_param.requires_grad = True
+                        main_param.requires_grad = True
                         # Copy tensor model parallel attributes.
-                        mpu.copy_tensor_model_parallel_attributes(master_param,
+                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                   param)
                         if hasattr(param, 'shared'):
-                            master_param.shared = param.shared
+                            main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
-                        param_group['params'][i] = master_param
-                        fp32_from_fp16_params_this_group.append(master_param)
-                        # Reset existing state dict key to the new master param.
+                        param_group['params'][i] = main_param
+                        fp32_from_fp16_params_this_group.append(main_param)
+                        # Reset existing state dict key to the new main param.
                         if param in self.optimizer.state:
-                            self.optimizer.state[master_param] \
+                            self.optimizer.state[main_param] \
                                 = self.optimizer.state.pop(param)

                     # fp32 params.
@@ -196,43 +212,39 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return self.grad_scaler.scale

-    def _copy_model_grads_to_master_grads(self):
+    def _copy_model_grads_to_main_grads(self):
         # This only needs to be done for the fp16 group.
         model_grads = []
-        master_grads = []
-        for model_group, master_group in zip(self.fp16_groups,
+        main_grads = []
+        for model_group, main_group in zip(self.fp16_groups,
                                              self.fp32_from_fp16_groups):
-            for model_param, master_param in zip(model_group, master_group):
+            for model_param, main_param in zip(model_group, main_group):
                 if model_param.grad is not None:
-                    if master_param.grad is None:
-                        master_param.grad = torch.empty_like(master_param)
+                    if main_param.grad is None:
+                        main_param.grad = torch.empty_like(main_param)
                     model_grads.append(model_param.grad.data)
-                    master_grads.append(master_param.grad.data)
-        self._dummy_overflow_buf.fill_(0)
-        # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale,
-                             self._dummy_overflow_buf,
-                             [model_grads, master_grads],
-                             1.0)
+                    main_grads.append(main_param.grad.data)
+        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
+                                        overflow_buf=self._dummy_overflow_buf)

-    def _unscale_master_grads_and_check_for_nan(self):
-        master_grads = []
+    def _unscale_main_grads_and_check_for_nan(self):
+        main_grads = []
         # fp32 params fromm fp16 ones.
-        for master_group in self.fp32_from_fp16_groups:
-            for master_param in master_group:
-                if master_param.grad is not None:
-                    master_grads.append(master_param.grad.data)
+        for main_group in self.fp32_from_fp16_groups:
+            for main_param in main_group:
+                if main_param.grad is not None:
+                    main_grads.append(main_param.grad.data)
         # Append fp32 parameters.
-        for master_group in self.fp32_from_fp32_groups:
-            for master_param in master_group:
-                if master_param.grad is not None:
-                    master_grads.append(master_param.grad.data)
+        for main_group in self.fp32_from_fp32_groups:
+            for main_param in main_group:
+                if main_param.grad is not None:
+                    main_grads.append(main_param.grad.data)
         # Reset found inf.
         self.found_inf.fill_(0.0)
         # Unscale and set found inf/nan
         torch._amp_foreach_non_finite_check_and_unscale_(
-            master_grads, self.found_inf, self.grad_scaler.inv_scale)
+            main_grads, self.found_inf, self.grad_scaler.inv_scale)
         # Update across all model parallel instances.
         torch.distributed.all_reduce(self.found_inf,
                                      op=torch.distributed.ReduceOp.MAX,
@@ -243,21 +255,33 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         return found_inf_flag

-    def _copy_master_params_to_model_params(self):
-        # Only needed for the fp16 params.
+    def _get_model_and_main_params_data_fp16(self):
         model_data = []
-        master_data = []
-        for model_group, master_group in zip(self.fp16_groups,
+        main_data = []
+        for model_group, main_group in zip(self.fp16_groups,
                                              self.fp32_from_fp16_groups):
-            for model_param, master_param in zip(model_group, master_group):
+            for model_param, main_param in zip(model_group, main_group):
                 model_data.append(model_param.data)
-                master_data.append(master_param.data)
-        self._dummy_overflow_buf.fill_(0)
-        # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale,
-                             self._dummy_overflow_buf,
-                             [master_data, model_data],
-                             1.0)
+                main_data.append(main_param.data)
+        return model_data, main_data
+
+
+    def _copy_main_params_to_model_params(self):
+        # Only needed for the fp16 params.
+        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
+                                        overflow_buf=self._dummy_overflow_buf)
+
+
+    def _copy_model_params_to_main_params(self):
+        # Only needed for the fp16 params.
+        model_data, main_data = self._get_model_and_main_params_data_fp16()
+        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
+                                        overflow_buf=self._dummy_overflow_buf)
+
+
+    def reload_model_params(self):
+        self._copy_model_params_to_main_params()

     @torch.no_grad()
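
The two private copies introduced here run in opposite directions: `_copy_main_params_to_model_params` pushes the updated fp32 main weights back into the fp16 model weights after an optimizer step, while `_copy_model_params_to_main_params`, exposed publicly as `reload_model_params`, pulls externally overwritten model weights (for example, a freshly loaded checkpoint) into the fp32 copies. A minimal sketch of why that pull matters, using plain tensors as hypothetical stand-ins for one fp16 param and its fp32 main copy:

import torch

# Hypothetical stand-ins for one fp16 model param and its fp32 main copy.
model_param = torch.zeros(4, dtype=torch.float16)
main_param = model_param.detach().clone().float()

# A checkpoint load rewrites the model param behind the optimizer's back.
model_param.copy_(torch.ones(4))

# Until the values are pulled across, main_param still holds the stale zeros;
# reload_model_params() performs this re-sync for every fp16 param group.
main_param.copy_(model_param.float())
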
@@ -266,17 +290,17 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         timers = get_timers()

         # ==================================================
-        # Copy gradients from model params to master params.
+        # Copy gradients from model params to main params.
         # ==================================================
-        timers('optimizer-copy-to-master-grad').start()
-        self._copy_model_grads_to_master_grads()
-        timers('optimizer-copy-to-master-grad').stop()
+        timers('optimizer-copy-to-main-grad').start()
+        self._copy_model_grads_to_main_grads()
+        timers('optimizer-copy-to-main-grad').stop()

         # ==============================
         # Unscale and check for inf/nan.
         # ==============================
         timers('optimizer-unscale-and-check-inf').start()
-        found_inf_flag = self._unscale_master_grads_and_check_for_nan()
+        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
         timers('optimizer-unscale-and-check-inf').stop()

         # ==================================
@@ -292,11 +316,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
             return False

         # ==========================
-        # Clip the master gradients.
+        # Clip the main gradients.
         # ==========================
-        timers('optimizer-clip-master-grad').start()
+        timers('optimizer-clip-main-grad').start()
         self.clip_grad_norm(self.clip_grad)
-        timers('optimizer-clip-master-grad').stop()
+        timers('optimizer-clip-main-grad').stop()

         # ===================
         # Step the optimizer.
@@ -304,11 +328,11 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         self.optimizer.step()

         # =================================
-        # Update params from master params.
+        # Update params from main params.
         # =================================
-        timers('optimizer-copy-master-to-model-params').start()
-        self._copy_master_params_to_model_params()
-        timers('optimizer-copy-master-to-model-params').stop()
+        timers('optimizer-copy-main-to-model-params').start()
+        self._copy_main_params_to_model_params()
+        timers('optimizer-copy-main-to-model-params').stop()

         # ==================
         # Successful update.
@@ -340,7 +364,7 @@ class FP16OptimizerWithFP16Params(MegatronOptimizer):
         else:
             self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

-        # Copy data for the master params.
+        # Copy data for the main params.
         fp32_from_fp16_params_key = 'fp32_from_fp16_params'
         if fp32_from_fp16_params_key not in state_dict:
             fp32_from_fp16_params_key = 'fp32_from_fp16'
@@ -388,6 +412,10 @@ class FP32Optimizer(MegatronOptimizer):
         return True

+    def reload_model_params(self):
+        pass
+
     def state_dict(self):
         return self.optimizer.state_dict()
...
@@ -677,10 +677,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         add_to_logging('backward-send-forward-recv')
         add_to_logging('backward-params-all-reduce')
         add_to_logging('backward-embedding-all-reduce')
-        add_to_logging('optimizer-copy-to-master-grad')
+        add_to_logging('optimizer-copy-to-main-grad')
         add_to_logging('optimizer-unscale-and-check-inf')
-        add_to_logging('optimizer-clip-master-grad')
-        add_to_logging('optimizer-copy-master-to-model-params')
+        add_to_logging('optimizer-clip-main-grad')
+        add_to_logging('optimizer-copy-main-to-model-params')
         add_to_logging('optimizer')
         add_to_logging('batch-generator')
...
@@ -255,9 +255,8 @@ def finetune(train_valid_datasets_provider, model_provider,
         _ = load_checkpoint(model, None, None)
         args.load = original_load
         # This is critical when only model is loaded. We should make sure
-        # master parameters are also updated.
-        if args.fp16:
-            optimizer._model_params_to_master_params()
+        # main parameters are also updated.
+        optimizer.reload_model_params()
         timers('pretrained checkpoint').stop()

     # Print setup timing.
...
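
The finetune change above is the call site motivating the refactor: the fp16-guarded call into the private `optimizer._model_params_to_master_params()` becomes a call to the public `reload_model_params()` hook, which every `MegatronOptimizer` subclass now implements (FP32Optimizer as a no-op), so the caller no longer needs to check `args.fp16`. A self-contained toy sketch of that interface pattern; the classes below are hypothetical stand-ins, not Megatron's real ones:

from abc import ABC, abstractmethod


class ToyOptimizer(ABC):
    @abstractmethod
    def reload_model_params(self):
        pass


class ToyFP16Optimizer(ToyOptimizer):
    def reload_model_params(self):
        print('re-sync fp32 main params from the reloaded fp16 model params')


class ToyFP32Optimizer(ToyOptimizer):
    def reload_model_params(self):
        pass  # a no-op, as in the diff above: no separate main copy to re-sync


def after_pretrained_checkpoint_load(optimizer):
    # No `if args.fp16:` guard needed: every optimizer implements the hook.
    optimizer.reload_model_params()


after_pretrained_checkpoint_load(ToyFP16Optimizer())
after_pretrained_checkpoint_load(ToyFP32Optimizer())
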