Unverified Commit c63e2701 authored by Shijie Wu, committed by GitHub

refactor decay_parameters production into its own function (#26152)

parent 77ed9fa1
@@ -951,6 +951,17 @@ class Trainer:
         optimizer = self.optimizer
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)

+    def get_decay_parameter_names(self, model) -> List[str]:
+        """
+        Get all parameter names that weight decay will be applied to.
+
+        Note that some models implement their own layernorm instead of calling nn.LayerNorm, and weight decay could
+        still apply to those modules since this function only filters out instances of nn.LayerNorm.
+        """
+        decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
+        decay_parameters = [name for name in decay_parameters if "bias" not in name]
+        return decay_parameters
+
     def create_optimizer(self):
         """
         Setup the optimizer.
@@ -961,8 +972,7 @@ class Trainer:
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

         if self.optimizer is None:
-            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
-            decay_parameters = [name for name in decay_parameters if "bias" not in name]
+            decay_parameters = self.get_decay_parameter_names(opt_model)
             optimizer_grouped_parameters = [
                 {
                     "params": [
......
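
To make the extracted helper concrete, here is a minimal sketch that reproduces its filtering logic on a toy module. The `TinyModel` class is hypothetical; `get_parameter_names` and `ALL_LAYERNORM_LAYERS` are the same utilities `trainer.py` already imports in the diff above.

```python
import torch.nn as nn

from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names


class TinyModel(nn.Module):
    """Hypothetical two-layer module, used only to show which names survive the filter."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)  # linear.weight decays; linear.bias does not
        self.norm = nn.LayerNorm(4)    # excluded entirely: instance of nn.LayerNorm


model = TinyModel()
# Same two steps as get_decay_parameter_names in the diff above. As the docstring
# warns, a hand-rolled layernorm that is not an nn.LayerNorm instance would not
# be excluded here.
decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
print(decay_parameters)  # -> ['linear.weight']
```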
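A practical upshot of the refactor: because `create_optimizer` now calls `self.get_decay_parameter_names`, a `Trainer` subclass can adjust the decay set in one place. The following is a hedged sketch of that pattern; the `CustomTrainer` name and the embedding exemption are illustrative assumptions, not part of the commit.

```python
from typing import List

from transformers import Trainer


class CustomTrainer(Trainer):
    """Hypothetical subclass that additionally exempts embedding weights from decay."""

    def get_decay_parameter_names(self, model) -> List[str]:
        names = super().get_decay_parameter_names(model)
        # Drop any parameter whose name suggests an embedding table (assumed heuristic).
        return [name for name in names if "embed" not in name]
```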