Unverified commit b21993b3, authored Jul 27, 2020 by Gong Linyuan; committed by GitHub on Jul 27, 2020.
Allow to set Adam beta1, beta2 in TrainingArgs (#5592)
* Add Adam beta1, beta2 to trainer
* Make style consistent
Parent: 7969e96f
Showing 4 changed files with 21 additions and 4 deletions:

src/transformers/optimization_tf.py  +11 -3
src/transformers/trainer.py          +6  -1
src/transformers/trainer_tf.py       +2  -0
src/transformers/training_args.py    +2  -0
src/transformers/optimization_tf.py
@@ -84,6 +84,8 @@ def create_optimizer(
     num_train_steps: int,
     num_warmup_steps: int,
     min_lr_ratio: float = 0.0,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
     adam_epsilon: float = 1e-8,
     weight_decay_rate: float = 0.0,
     include_in_weight_decay: Optional[List[str]] = None,
@@ -100,6 +102,10 @@ def create_optimizer(
             The number of warmup steps.
         min_lr_ratio (:obj:`float`, `optional`, defaults to 0):
             The final learning rate at the end of the linear decay will be :obj:`init_lr * min_lr_ratio`.
+        adam_beta1 (:obj:`float`, `optional`, defaults to 0.9):
+            The beta1 to use in Adam.
+        adam_beta2 (:obj:`float`, `optional`, defaults to 0.999):
+            The beta2 to use in Adam.
         adam_epsilon (:obj:`float`, `optional`, defaults to 1e-8):
             The epsilon to use in Adam.
         weight_decay_rate (:obj:`float`, `optional`, defaults to 0):
@@ -122,14 +128,16 @@ def create_optimizer(
         optimizer = AdamWeightDecay(
             learning_rate=lr_schedule,
             weight_decay_rate=weight_decay_rate,
-            beta_1=0.9,
-            beta_2=0.999,
+            beta_1=adam_beta1,
+            beta_2=adam_beta2,
             epsilon=adam_epsilon,
             exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
             include_in_weight_decay=include_in_weight_decay,
         )
     else:
-        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=adam_epsilon)
+        optimizer = tf.keras.optimizers.Adam(
+            learning_rate=lr_schedule, beta_1=adam_beta1, beta_2=adam_beta2, epsilon=adam_epsilon
+        )
     # We return the optimizer and the LR scheduler in order to better track the
     # evolution of the LR independently of the optimizer.
     return optimizer, lr_schedule
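For context, a minimal sketch of how the TF helper might be called once the new arguments exist; the concrete values below are illustrative and not taken from the commit:

```python
# Illustrative sketch only: calling create_optimizer with the new beta arguments.
# Assumes TensorFlow and a transformers version containing this commit.
from transformers.optimization_tf import create_optimizer

optimizer, lr_schedule = create_optimizer(
    init_lr=5e-5,
    num_train_steps=10000,
    num_warmup_steps=500,
    adam_beta1=0.9,     # new keyword added by this commit
    adam_beta2=0.98,    # a non-default value, now configurable
    adam_epsilon=1e-8,
    weight_decay_rate=0.01,
)
```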
src/transformers/trainer.py
@@ -343,7 +343,12 @@ class Trainer:
                 "weight_decay": 0.0,
             },
         ]
-        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
+        optimizer = AdamW(
+            optimizer_grouped_parameters,
+            lr=self.args.learning_rate,
+            betas=(self.args.adam_beta1, self.args.adam_beta2),
+            eps=self.args.adam_epsilon,
+        )
         scheduler = get_linear_schedule_with_warmup(
             optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
         )
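A minimal sketch of the pattern this hunk introduces on the PyTorch side: the betas are read from TrainingArguments and forwarded to AdamW. The tiny Linear model below is only a stand-in so the snippet is self-contained:

```python
# Illustrative sketch: forwarding the new TrainingArguments fields into AdamW,
# mirroring what Trainer now does internally. The Linear model is a stand-in.
import torch
from transformers import AdamW, TrainingArguments

args = TrainingArguments(output_dir="out", adam_beta1=0.9, adam_beta2=0.999)
model = torch.nn.Linear(4, 2)
optimizer = AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    eps=args.adam_epsilon,
)
```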
src/transformers/trainer_tf.py
@@ -171,6 +171,8 @@ class TFTrainer:
                 self.args.learning_rate,
                 num_training_steps,
                 self.args.warmup_steps,
+                adam_beta1=self.args.adam_beta1,
+                adam_beta2=self.args.adam_beta2,
                 adam_epsilon=self.args.adam_epsilon,
                 weight_decay_rate=self.args.weight_decay,
             )
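The TF trainer picks the same two fields up from the arguments object it was constructed with; a short usage sketch, assuming TensorFlow is installed, might look like:

```python
# Illustrative sketch: configuring the betas via TFTrainingArguments so that
# TFTrainer forwards them to create_optimizer as shown in the diff above.
from transformers import TFTrainingArguments

tf_args = TFTrainingArguments(output_dir="out", adam_beta1=0.9, adam_beta2=0.98)
# A TFTrainer constructed with args=tf_args would now build its Adam optimizer
# with beta_1=0.9 and beta_2=0.98 instead of the previously hard-coded defaults.
print(tf_args.adam_beta1, tf_args.adam_beta2)
```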
src/transformers/training_args.py
@@ -160,6 +160,8 @@ class TrainingArguments:
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."})
+    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for Adam optimizer"})
+    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for Adam optimizer"})
     adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."})
     max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
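Because TrainingArguments is a dataclass consumed by HfArgumentParser, the two new fields also surface as command-line flags in scripts that parse it; a hedged sketch (the script name below is hypothetical):

```python
# Illustrative sketch: the new dataclass fields become CLI flags, e.g.
#   python train.py --output_dir out --adam_beta1 0.9 --adam_beta2 0.98
# (train.py is a hypothetical script name).
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()
print(training_args.adam_beta1, training_args.adam_beta2)
```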