chenpangpang / transformers · Commits · 8ca767f1

Commit 8ca767f1 authored Jul 15, 2019 by thomwolf

clean up optimization

parent 74a24f0f
Showing 1 changed file with 49 additions and 59 deletions

pytorch_transformers/optimization.py  +49 -59

...
@@ -24,18 +24,47 @@ from torch.optim.lr_scheduler import LambdaLR

 logger = logging.getLogger(__name__)


+class ConstantLRSchedule(LambdaLR):
+    """ Constant learning rate schedule.
+    """
+    def __init__(self, optimizer, last_epoch=-1):
+        super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
+
+
-class WarmupCosineSchedule(LambdaLR):
+class WarmupConstantSchedule(LambdaLR):
+    """ Linear warmup and then constant.
+        Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
+        Keeps learning rate schedule equal to 1. after warmup_steps.
+    """
+    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+        def lr_lambda(step):
+            if step < warmup_steps:
+                return float(step) / float(max(1.0, warmup_steps))
+            return 1.
+        super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+class WarmupLinearSchedule(LambdaLR):
+    """ Linear warmup and then linear decay.
+        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+        Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
+    """
-    Linearly increases learning rate from 0 to 1 over `warmup` training steps.
-    Decreases learning rate from 1. to 0. over remaining `t_total - warmup` steps following a cosine curve.
+    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
+        def lr_lambda(step):
+            if step < warmup_steps:
+                return float(step) / float(max(1, warmup_steps))
+            return max(0.0, float(t_total - step) / float(max(1.0, t_total - warmup_steps)))
+        super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+class WarmupCosineSchedule(LambdaLR):
+    """ Linear warmup and then cosine decay.
+        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+        Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
+        If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
-    :param warmup: see LRSchedule
-    :param t_total: see LRSchedule
-    :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
-    :param kw:
     """
-    warn_t_total = True
     def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
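As an aside on how these LambdaLR-based schedules are wired up, here is a minimal training-loop sketch, not part of the diff: the model, batch and step counts are placeholders, while AdamW, WarmupLinearSchedule, warmup_steps and t_total are the names defined in this file. The scheduler is stepped once per optimizer step.

import torch
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

model = torch.nn.Linear(10, 2)                     # placeholder model
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)
t_total = 1000                                     # total number of training steps
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=t_total)

for step in range(t_total):
    loss = model(torch.randn(8, 10)).sum()         # placeholder forward pass / loss
    loss.backward()
    optimizer.step()
    scheduler.step()                               # advance the LambdaLR multiplier
    optimizer.zero_grad()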
...

@@ -45,13 +74,13 @@ class WarmupCosineSchedule(LambdaLR):
                 return float(step) / float(max(1.0, warmup_steps))
             else:
                 progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))  # progress after warmup
-                return 0.5 * (1. + math.cos(math.pi * float(cycles) * 2.0 * progress))
+                return max(0.0, 0.5 * (1. + math.cos(math.pi * float(cycles) * 2.0 * progress)))
         super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)


 class WarmupCosineWithHardRestartsSchedule(LambdaLR):
-    """
-    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
+    """ Linear warmup and then cosine cycles with hard restarts.
+        Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+        If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
+        learning rate (with hard restarts).
     """
...

@@ -64,69 +93,30 @@ class WarmupCosineWithHardRestartsSchedule(LambdaLR):
             progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))  # progress after warmup
             if progress >= 1.0:
                 return 0.0
-            return 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0)))
+            return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0))))
         super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
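Similarly, the hard-restarts variant just shown uses the modulo so that the cosine jumps back to 1. at the start of each cycle. A stand-alone restatement of its lr_lambda (illustrative values only, not part of the diff):

import math

def hard_restarts_multiplier(step, warmup_steps=100, t_total=1100, cycles=2.):
    # Same formula as lr_lambda in WarmupCosineWithHardRestartsSchedule above.
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))
    if progress >= 1.0:
        return 0.0
    return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0))))

for step in (100, 350, 599, 600, 850, 1099):
    print(step, round(hard_restarts_multiplier(step), 3))
# 100 -> 1.0, 599 -> ~0.0, 600 -> back to 1.0 (hard restart), 1099 -> ~0.0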


-class WarmupConstantSchedule(LambdaLR):
-    """
-    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
-    Keeps learning rate equal to 1. after warmup.
-    """
-    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
-        def lr_lambda(step):
-            if step < warmup_steps:
-                return float(step) / float(max(1.0, warmup_steps))
-            return 1.
-        super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
-
-
-class WarmupLinearSchedule(LambdaLR):
-    """
-    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
-    Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
-    """
-    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
-        def lr_lambda(step):
-            if step < warmup_steps:
-                return float(step) / float(max(1, warmup_steps))
-            return float(t_total - step) / float(max(1.0, t_total - warmup_steps))
-        super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
-
-
 class AdamW(Optimizer):
     """ Implements Adam algorithm with weight decay fix.

     Parameters:
-        lr: learning rate
-        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
-        t_total: total number of training steps for the learning
-            rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
-        schedule: schedule to use for the warmup (see above).
-            Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below).
-            If `None` or `'none'`, learning rate is always kept constant.
-            Default : `'warmup_linear'`
-        b1: Adams b1. Default: 0.9
-        b2: Adams b2. Default: 0.999
-        e: Adams epsilon. Default: 1e-6
-        weight_decay: Weight decay. Default: 0.01
-        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
-        correct_bias: can be set to False to avoid correcting bias in Adam (e.g. like in Bert repository)
+        lr (float): learning rate. Default 1e-3.
+        betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
+        eps (float): Adams epsilon. Default: 1e-6
+        weight_decay (float): Weight decay. Default: 0.0
+        correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.
     """
-    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, correct_bias=True):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
         if not 0.0 <= betas[0] < 1.0:
             raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
-            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]
-                             ))
+            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
         if not 0.0 <= eps:
-            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
+            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                         correct_bias=correct_bias)
         super(AdamW, self).__init__(params, defaults)

...
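Since weight_decay now defaults to 0.0, decay has to be requested explicitly; a common pattern is to pass parameter groups so that, for example, biases are exempt. A minimal sketch (the toy model and the no_decay grouping are illustrative and not part of this commit; the AdamW arguments are the ones documented above):

import torch
from pytorch_transformers.optimization import AdamW

model = torch.nn.Linear(10, 2)                     # placeholder model
no_decay = ["bias"]
grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     "weight_decay": 0.0},
]
optimizer = AdamW(grouped_parameters, lr=2e-5, eps=1e-6, correct_bias=True)

Per-group weight_decay overrides the constructor default, which is standard torch.optim.Optimizer behavior.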