chenpangpang / transformers · Commits

Commit 91a073f8
authored Apr 03, 2019 by lukovnikov

schedule fix

parent b64cc63a
Showing 2 changed files with 10 additions and 13 deletions:

pytorch_pretrained_bert/optimization.py   +7 -10
tests/optimization_test.py                +3 -3
pytorch_pretrained_bert/optimization.py
@@ -38,11 +38,12 @@ class LRSchedule(object):
         :param kw:
         """
         super(LRSchedule, self).__init__(**kw)
-        self.warmup, self.t_total = warmup, t_total
         if t_total <= 0:
             logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
         if not 0.0 <= warmup < 1.0 and not warmup == -1:
             raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
+        warmup = max(warmup, 0)
+        self.warmup, self.t_total = warmup, t_total
         self.warned_for_t_total_at_progress = -1
 
     def get_lr(self, step, nowarn=False):
@@ -51,6 +52,8 @@ class LRSchedule(object):
         :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
         :return: learning rate multiplier for current update
         """
+        if self.t_total < 0:
+            return 1.
         progress = step / self.t_total
         ret = self.get_lr_(progress)
         # warning for exceeding t_total (only active with warmup_linear
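Note: together with the warmup = max(warmup, 0) line added in the previous hunk, the new early return above is what makes the documented sentinels actually work: a negative t_total now short-circuits get_lr() to a constant multiplier of 1., and warmup=-1 is clamped to 0 before being stored. A dependency-free toy sketch of that behaviour (values are made up, not taken from the repository):

    # toy values for illustration only
    warmup, t_total, step = -1, -1, 1234

    warmup = max(warmup, 0)    # warmup=-1 ("no warmup") is clamped to 0 before being stored
    if t_total < 0:
        lr_mult = 1.           # schedule disabled: constant multiplier regardless of step
    else:
        lr_mult = None         # real code: progress = step / t_total, then self.get_lr_(progress)
    print(warmup, lr_mult)     # prints: 0 1.0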
@@ -87,9 +90,6 @@ class WarmupCosineSchedule(LRSchedule):
         self.cycles = cycles
 
     def get_lr_(self, progress):
-        """ get learning rate multiplier """
-        if self.t_total <= 0:
-            return 1.
         if progress < self.warmup:
             return progress / self.warmup
         else:
@@ -106,8 +106,6 @@ class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
         assert(cycles >= 1.)
 
     def get_lr_(self, progress):
-        if self.t_total <= 0:
-            return 1.
         if progress < self.warmup:
             return progress / self.warmup
         else:
@@ -124,11 +122,10 @@ class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
     """
     def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
         assert(warmup * cycles < 1.)
-        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup * cycles, t_total=t_total, cycles=cycles, **kw)
+        warmup = warmup * cycles if warmup >= 0 else warmup
+        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
 
     def get_lr_(self, progress):
-        if self.t_total <= 0.:
-            return 1.
         progress = progress * self.cycles % 1.
         if progress < self.warmup:
             return progress / self.warmup
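Note: the changed constructor line above only rescales warmup when it is non-negative, so the -1 sentinel survives to the parent class (which accepts it and disables warmup). A standalone snippet mirroring just that expression, with hypothetical inputs:

    # hypothetical inputs, mirroring the added constructor line
    cycles = 5
    for warmup in (0.05, -1):
        effective = warmup * cycles if warmup >= 0 else warmup
        print(warmup, "->", effective)
    # 0.05 -> 0.25  (each of the 5 cycles warms up over its first 5% of total training)
    # -1   -> -1    (sentinel preserved; the old code would have passed -5, failing the [0.0, 1.0[ / -1 check)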
@@ -174,7 +171,7 @@ class BertAdam(Optimizer):
         lr: learning rate
         warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
         t_total: total number of training steps for the learning
-            rate schedule, -1 means constant learning rate. Default: -1
+            rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
         schedule: schedule to use for the warmup (see above).
             Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
             Default: 'warmup_linear'
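Note: with the corrected docstring in mind, a minimal BertAdam usage sketch; the model and step count are placeholders, and only the keyword names listed in the docstring above are relied on:

    import torch
    from pytorch_pretrained_bert.optimization import BertAdam

    model = torch.nn.Linear(10, 2)   # placeholder for a real BERT model
    num_train_steps = 1000           # placeholder; -1 would mean a constant learning rate of 1.

    optimizer = BertAdam(model.parameters(),
                         lr=5e-5,
                         warmup=0.1,                 # first 10% of t_total is linear warmup
                         t_total=num_train_steps,
                         schedule='warmup_linear')   # default; an LRSchedule object is also accepted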
tests/optimization_test.py
@@ -51,9 +51,9 @@ class OptimizationTest(unittest.TestCase):
 
 class WarmupCosineWithRestartsTest(unittest.TestCase):
     def test_it(self):
-        m = WarmupCosineWithWarmupRestartsSchedule(warmup=0.05, t_total=1, cycles=5)
-        x = np.arange(0, 1000) / 1000
-        y = [m.get_lr_(xe) for xe in x]
+        m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5)
+        x = np.arange(0, 1000)
+        y = [m.get_lr(xe) for xe in x]
         plt.plot(y)
         plt.show(block=False)
         y = np.asarray(y)
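Note: the updated test feeds raw step counts to get_lr() (the x / 1000 normalisation and the direct get_lr_() call are gone) and uses warmup=-1 to exercise the newly fixed no-warmup path. A plot-free version of the same check, assuming the package and numpy are installed:

    import numpy as np
    from pytorch_pretrained_bert.optimization import WarmupCosineWithWarmupRestartsSchedule

    m = WarmupCosineWithWarmupRestartsSchedule(warmup=-1, t_total=500, cycles=5)
    y = np.asarray([m.get_lr(step) for step in np.arange(0, 1000)])
    print(y.min(), y.max())    # learning-rate multipliers over (and beyond) the scheduled t_total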