chenpangpang / transformers

Commit 0740e63e, authored Jul 23, 2019 by thomwolf (parent 268c6cc1)

updating schedules for state_dict saving
Showing 2 changed files with 71 additions and 31 deletions:

  pytorch_transformers/optimization.py              +36  -30
  pytorch_transformers/tests/optimization_test.py   +35  -1
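The change in this commit moves each schedule's `lr_lambda` from a closure defined inside `__init__` to a bound method, storing `warmup_steps`, `t_total` and `cycles` on the scheduler itself, so that saving and restoring a schedule through `state_dict()` works. A minimal sketch of the round trip the new tests exercise, assuming a toy `Linear` model, an `SGD` optimizer and an arbitrary file path (none of these come from the commit):

    import torch
    from pytorch_transformers import WarmupLinearSchedule

    # Illustrative setup: any optimizer works; the model and lr values are arbitrary.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=2, t_total=10)

    for _ in range(5):
        optimizer.step()
        scheduler.step()

    # Save the schedule state and restore it, e.g. when resuming training.
    torch.save(scheduler.state_dict(), "schedule.bin")   # hypothetical path
    scheduler.load_state_dict(torch.load("schedule.bin"))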
pytorch_transformers/optimization.py
@@ -36,13 +36,13 @@ class WarmupConstantSchedule(LambdaLR):
         Keeps learning rate schedule equal to 1. after warmup_steps.
     """
     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
-        def lr_lambda(step):
-            if step < warmup_steps:
-                return float(step) / float(max(1.0, warmup_steps))
-            return 1.
-        super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
+        self.warmup_steps = warmup_steps
+        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+    def lr_lambda(self, step):
+        if step < self.warmup_steps:
+            return float(step) / float(max(1.0, self.warmup_steps))
+        return 1.


 class WarmupLinearSchedule(LambdaLR):
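Why the move from a closure to a bound method matters: at this point the schedulers' `state_dict()` still contains the `lr_lambdas` themselves, and `torch.save` has to pickle them. A function defined inside `__init__` cannot be pickled, while a bound method of a module-level class can (in Python 3). A standalone sketch of just that difference, with two toy classes that are illustrative only, not the repository's code:

    import pickle

    class ClosureSchedule(object):
        """Old pattern: the lambda is a local closure created inside __init__."""
        def __init__(self, warmup_steps):
            def lr_lambda(step):
                return min(1.0, step / max(1.0, warmup_steps))
            self.lr_lambdas = [lr_lambda]

    class MethodSchedule(object):
        """New pattern: the lambda is a bound method; its state lives on self."""
        def __init__(self, warmup_steps):
            self.warmup_steps = warmup_steps
            self.lr_lambdas = [self.lr_lambda]

        def lr_lambda(self, step):
            return min(1.0, step / max(1.0, self.warmup_steps))

    pickle.dumps(MethodSchedule(4).lr_lambdas[0])        # fine: resolved via the class and the instance
    try:
        pickle.dumps(ClosureSchedule(4).lr_lambdas[0])
    except (AttributeError, pickle.PicklingError) as err:
        print(err)  # e.g. "Can't pickle local object 'ClosureSchedule.__init__.<locals>.lr_lambda'"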
@@ -51,13 +51,14 @@ class WarmupLinearSchedule(LambdaLR):
         Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
     """
     def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
-        def lr_lambda(step):
-            if step < warmup_steps:
-                return float(step) / float(max(1, warmup_steps))
-            return max(0.0, float(t_total - step) / float(max(1.0, t_total - warmup_steps)))
-        super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
+        self.warmup_steps = warmup_steps
+        self.t_total = t_total
+        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+    def lr_lambda(self, step):
+        if step < self.warmup_steps:
+            return float(step) / float(max(1, self.warmup_steps))
+        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))


 class WarmupCosineSchedule(LambdaLR):
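As a quick sanity check of the linear schedule above, the multiplier can be evaluated directly with the `warmup_steps=2, t_total=10` used by the updated test; the helper below re-implements the formula for illustration and is not the repository class:

    def linear_multiplier(step, warmup_steps=2, t_total=10):
        # Same formula as WarmupLinearSchedule.lr_lambda above.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        return max(0.0, float(t_total - step) / float(max(1.0, t_total - warmup_steps)))

    print([round(linear_multiplier(s), 3) for s in range(11)])
    # [0.0, 0.5, 1.0, 0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125, 0.0]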
@@ -66,17 +67,19 @@ class WarmupCosineSchedule(LambdaLR):
         Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
         If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
     """
     warn_t_total = True
     def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
-        def lr_lambda(step):
-            if step < warmup_steps:
-                return float(step) / float(max(1.0, warmup_steps))
-            else:
-                progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))   # progress after warmup
-                return max(0.0, 0.5 * (1. + math.cos(math.pi * float(cycles) * 2.0 * progress)))
-        super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
+        self.warmup_steps = warmup_steps
+        self.t_total = t_total
+        self.cycles = cycles
+        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+    def lr_lambda(self, step):
+        if step < self.warmup_steps:
+            return float(step) / float(max(1.0, self.warmup_steps))
+        # progress after warmup
+        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))


 class WarmupCosineWithHardRestartsSchedule(LambdaLR):
     """ Linear warmup and then cosine cycles with hard restarts.
@@ -85,17 +88,20 @@ class WarmupCosineWithHardRestartsSchedule(LambdaLR):
         learning rate (with hard restarts).
     """
     def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
-        def lr_lambda(step):
-            if step < warmup_steps:
-                return float(step) / float(max(1, warmup_steps))
-            else:
-                progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))   # progress after warmup
-                if progress >= 1.0:
-                    return 0.0
-                return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0))))
-        super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)
+        self.warmup_steps = warmup_steps
+        self.t_total = t_total
+        self.cycles = cycles
+        super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+    def lr_lambda(self, step):
+        if step < self.warmup_steps:
+            return float(step) / float(max(1, self.warmup_steps))
+        # progress after warmup
+        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+        if progress >= 1.0:
+            return 0.0
+        return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))


 class AdamW(Optimizer):
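For the hard-restarts variant, the `warmup_steps=2, cycles=2, t_total=10` configuration from the test makes the restart visible: the multiplier climbs back to 1.0 halfway through the decay. Standalone helper again, not the repository class:

    import math

    def hard_restarts_multiplier(step, warmup_steps=2, t_total=10, cycles=2.):
        # Same formula as WarmupCosineWithHardRestartsSchedule.lr_lambda above.
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(cycles) * progress) % 1.0))))

    print([round(hard_restarts_multiplier(s), 3) for s in range(11)])
    # [0.0, 0.5, 1.0, 0.854, 0.5, 0.146, 1.0, 0.854, 0.5, 0.146, 0.0]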
pytorch_transformers/tests/optimization_test.py
@@ -17,13 +17,14 @@ from __future__ import division
 from __future__ import print_function

 import unittest
+import os

 import torch

 from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
                                   WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
 import numpy as np

+from .tokenization_tests_commons import TemporaryDirectory


 def unwrap_schedule(scheduler, num_steps=10):
@@ -33,6 +34,20 @@ def unwrap_schedule(scheduler, num_steps=10):
         lrs.append(scheduler.get_lr())
     return lrs

+def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
+    lrs = []
+    for step in range(num_steps):
+        scheduler.step()
+        lrs.append(scheduler.get_lr())
+        if step == num_steps // 2:
+            with TemporaryDirectory() as tmpdirname:
+                file_name = os.path.join(tmpdirname, 'schedule.bin')
+                torch.save(scheduler.state_dict(), file_name)
+
+                state_dict = torch.load(file_name)
+                scheduler.load_state_dict(state_dict)
+    return lrs
+

 class OptimizationTest(unittest.TestCase):

     def assertListAlmostEqual(self, list1, list2, tol):
@@ -72,6 +87,10 @@ class ScheduleInitTest(unittest.TestCase):
         self.assertEqual(len(lrs[0]), 1)
         self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

+        scheduler = ConstantLRSchedule(self.optimizer)
+        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
+        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
+
     def test_warmup_constant_scheduler(self):
         scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
         lrs = unwrap_schedule(scheduler, self.num_steps)
@@ -79,6 +98,10 @@ class ScheduleInitTest(unittest.TestCase):
         self.assertEqual(len(lrs[0]), 1)
         self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

+        scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
+        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
+        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
+
     def test_warmup_linear_scheduler(self):
         scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
         lrs = unwrap_schedule(scheduler, self.num_steps)
@@ -86,6 +109,10 @@ class ScheduleInitTest(unittest.TestCase):
         self.assertEqual(len(lrs[0]), 1)
         self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

+        scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
+        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
+
     def test_warmup_cosine_scheduler(self):
         scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
         lrs = unwrap_schedule(scheduler, self.num_steps)
@@ -93,6 +120,10 @@ class ScheduleInitTest(unittest.TestCase):
         self.assertEqual(len(lrs[0]), 1)
         self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)

+        scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
+        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
+
     def test_warmup_cosine_hard_restart_scheduler(self):
         scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
         lrs = unwrap_schedule(scheduler, self.num_steps)
@@ -100,6 +131,9 @@ class ScheduleInitTest(unittest.TestCase):
         self.assertEqual(len(lrs[0]), 1)
         self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)

+        scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
+        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
+        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])

 if __name__ == "__main__":
     unittest.main()
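These round-trip checks run with the rest of the suite, for example via `python -m pytest pytorch_transformers/tests/optimization_test.py`, or through `unittest` since the module defines a `__main__` entry point.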