Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
dbfe34f2
Unverified
Commit
dbfe34f2
authored
Aug 27, 2020
by
Stas Bekman
Committed by
GitHub
Aug 27, 2020
Browse files
[test schedulers] adjust to test the first step's reading (#6429)
* [test schedulers] small improvement * cleanup
parent
e6b811f0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
15 deletions
+10
-15
tests/test_optimization.py
tests/test_optimization.py
+10
-15
No files found.
tests/test_optimization.py
View file @
dbfe34f2
...
...
@@ -40,16 +40,16 @@ if is_torch_available():
def
unwrap_schedule
(
scheduler
,
num_steps
=
10
):
lrs
=
[]
for
_
in
range
(
num_steps
):
lrs
.
append
(
scheduler
.
get_lr
()[
0
])
scheduler
.
step
()
lrs
.
append
(
scheduler
.
get_lr
())
return
lrs
def
unwrap_and_save_reload_schedule
(
scheduler
,
num_steps
=
10
):
lrs
=
[]
for
step
in
range
(
num_steps
):
lrs
.
append
(
scheduler
.
get_lr
()[
0
])
scheduler
.
step
()
lrs
.
append
(
scheduler
.
get_lr
())
if
step
==
num_steps
//
2
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
file_name
=
os
.
path
.
join
(
tmpdirname
,
"schedule.bin"
)
...
...
@@ -127,23 +127,23 @@ class ScheduleInitTest(unittest.TestCase):
get_constant_schedule
:
({},
[
10.0
]
*
self
.
num_steps
),
get_constant_schedule_with_warmup
:
(
{
"num_warmup_steps"
:
4
},
[
2.5
,
5.0
,
7.5
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
],
[
0.0
,
2.5
,
5.0
,
7.5
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
,
10.0
],
),
get_linear_schedule_with_warmup
:
(
{
**
common_kwargs
},
[
5.0
,
10.0
,
8.75
,
7.5
,
6.25
,
5.0
,
3.75
,
2.5
,
1.25
,
0.0
],
[
0.0
,
5.0
,
10.0
,
8.75
,
7.5
,
6.25
,
5.0
,
3.75
,
2.5
,
1.25
],
),
get_cosine_schedule_with_warmup
:
(
{
**
common_kwargs
},
[
5.0
,
10.0
,
9.61
,
8.53
,
6.91
,
5.0
,
3.08
,
1.46
,
0.38
,
0.0
],
[
0.0
,
5.0
,
10.0
,
9.61
,
8.53
,
6.91
,
5.0
,
3.08
,
1.46
,
0.38
],
),
get_cosine_with_hard_restarts_schedule_with_warmup
:
(
{
**
common_kwargs
,
"num_cycles"
:
2
},
[
5.0
,
10.0
,
8.53
,
5.0
,
1.46
,
10.0
,
8.53
,
5.0
,
1.46
,
0.0
],
[
0.0
,
5.0
,
10.0
,
8.53
,
5.0
,
1.46
,
10.0
,
8.53
,
5.0
,
1.46
],
),
get_polynomial_decay_schedule_with_warmup
:
(
{
**
common_kwargs
,
"power"
:
2.0
,
"lr_end"
:
1e-7
},
[
5.0
,
10.0
,
7.656
,
5.625
,
3.906
,
2.5
,
1.406
,
0.625
,
0.156
,
1e-07
],
[
0.0
,
5.0
,
10.0
,
7.656
,
5.625
,
3.906
,
2.5
,
1.406
,
0.625
,
0.156
],
),
}
...
...
@@ -151,17 +151,12 @@ class ScheduleInitTest(unittest.TestCase):
kwargs
,
expected_learning_rates
=
data
scheduler
=
scheduler_func
(
self
.
optimizer
,
**
kwargs
)
self
.
assertEqual
(
len
([
scheduler
.
get_lr
()[
0
]]),
1
)
lrs_1
=
unwrap_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertEqual
(
len
(
lrs_1
[
0
]),
1
)
self
.
assertListAlmostEqual
(
[
l
[
0
]
for
l
in
lrs_1
],
expected_learning_rates
,
tol
=
1e-2
,
msg
=
f
"failed for
{
scheduler_func
}
in normal scheduler"
,
lrs_1
,
expected_learning_rates
,
tol
=
1e-2
,
msg
=
f
"failed for
{
scheduler_func
}
in normal scheduler"
,
)
scheduler
=
scheduler_func
(
self
.
optimizer
,
**
kwargs
)
lrs_2
=
unwrap_and_save_reload_schedule
(
scheduler
,
self
.
num_steps
)
self
.
assertListEqual
(
[
l
[
0
]
for
l
in
lrs_1
],
[
l
[
0
]
for
l
in
lrs_2
],
msg
=
f
"failed for
{
scheduler_func
}
in save and reload"
)
self
.
assertListEqual
(
lrs_1
,
lrs_2
,
msg
=
f
"failed for
{
scheduler_func
}
in save and reload"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment