Unverified Commit 60de910e authored by Cola, committed by GitHub

Add `power` argument for TF PolynomialDecay (#5732)

* 🚩 Add `power` argument for TF PolynomialDecay

* 🚩 Create default optimizer with power

* 🚩 Add argument to training args

* 🚨 Clean code format

* 🚨 Fix black warning

* 🚨 Fix code format
parent 41c3a3b9
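In short, `create_optimizer` now exposes the `power` of the underlying `PolynomialDecay` schedule (1.0, the default, keeps the previous linear decay). A minimal usage sketch, assuming the helper returns the optimizer together with the schedule as the `TFTrainer` unpacking below suggests; the hyperparameter values are illustrative:

```python
from transformers import create_optimizer

# Quadratic (power=2.0) instead of the default linear (power=1.0) decay.
optimizer, lr_schedule = create_optimizer(
    init_lr=3e-5,             # peak learning rate reached after warmup
    num_train_steps=10_000,   # total number of update steps
    num_warmup_steps=500,     # linear warmup before the decay starts
    weight_decay_rate=0.01,
    power=2.0,                # the new argument, forwarded to PolynomialDecay
)
```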
@@ -88,6 +88,7 @@ def create_optimizer(
     adam_beta2: float = 0.999,
     adam_epsilon: float = 1e-8,
     weight_decay_rate: float = 0.0,
+    power: float = 1.0,
     include_in_weight_decay: Optional[List[str]] = None,
 ):
     """
@@ -110,6 +111,8 @@ def create_optimizer(
             The epsilon to use in Adam.
         weight_decay_rate (:obj:`float`, `optional`, defaults to 0):
             The weight decay to use.
+        power (:obj:`float`, `optional`, defaults to 1.0):
+            The power to use for PolynomialDecay.
         include_in_weight_decay (:obj:`List[str]`, `optional`):
             List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
             applied to all parameters except bias and layer norm parameters.
@@ -119,6 +122,7 @@ def create_optimizer(
         initial_learning_rate=init_lr,
         decay_steps=num_train_steps - num_warmup_steps,
         end_learning_rate=init_lr * min_lr_ratio,
+        power=power,
     )
     if num_warmup_steps:
         lr_schedule = WarmUp(
...
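The forwarded `power` controls the curvature of `tf.keras.optimizers.schedules.PolynomialDecay`: the rate follows `(init_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr`. A standalone sketch (not part of the patch) comparing the default linear decay with a quadratic one, using made-up step counts:

```python
import tensorflow as tf

linear = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=3e-5, decay_steps=9_500, end_learning_rate=0.0, power=1.0
)
quadratic = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=3e-5, decay_steps=9_500, end_learning_rate=0.0, power=2.0
)

# Halfway through the decay, power=1.0 is still at half of the initial rate,
# while power=2.0 has already dropped to a quarter of it.
print(float(linear(4_750)), float(quadratic(4_750)))   # ~1.5e-05 vs ~7.5e-06
```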
@@ -227,6 +227,7 @@ class TFTrainer:
                 adam_beta2=self.args.adam_beta2,
                 adam_epsilon=self.args.adam_epsilon,
                 weight_decay_rate=self.args.weight_decay,
+                power=self.args.poly_power,
             )

     def setup_wandb(self):
...
@@ -112,6 +112,11 @@ class TFTrainingArguments(TrainingArguments):
         metadata={"help": "Name of TPU"},
     )
+    poly_power: float = field(
+        default=1.0,
+        metadata={"help": "Power for the Polynomial decay LR scheduler."},
+    )
+
     xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})

     @cached_property
...
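The new `poly_power` field can be set directly on `TFTrainingArguments` or, in the `HfArgumentParser`-based example scripts, via the flag the field name maps to (`--poly_power`). An illustrative sketch; every value other than `poly_power` is an assumption made for the example:

```python
from transformers import TFTrainingArguments

training_args = TFTrainingArguments(
    output_dir="./out",
    learning_rate=3e-5,
    max_steps=10_000,
    warmup_steps=500,
    poly_power=2.0,   # picked up by TFTrainer and passed to create_optimizer(power=...)
)

# Command-line equivalent in the example scripts: --poly_power 2.0
```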