Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
262a9992
Commit
262a9992
authored
Mar 18, 2019
by
lukovnikov
Browse files
class weights
parent
b6c1cae6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
4 deletions
+32
-4
pytorch_pretrained_bert/optimization.py
pytorch_pretrained_bert/optimization.py
+18
-3
tests/optimization_test.py
tests/optimization_test.py
+14
-1
No files found.
pytorch_pretrained_bert/optimization.py
View file @
262a9992
...
@@ -24,7 +24,8 @@ import logging
...
@@ -24,7 +24,8 @@ import logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"LRSchedule"
,
"WarmupLinearSchedule"
,
"WarmupConstantSchedule"
,
"WarmupCosineSchedule"
,
"BertAdam"
,
"WarmupCosineWithRestartsSchedule"
]
__all__
=
[
"LRSchedule"
,
"WarmupLinearSchedule"
,
"WarmupConstantSchedule"
,
"WarmupCosineSchedule"
,
"BertAdam"
,
"WarmupMultiCosineSchedule"
,
"WarmupCosineWithRestartsSchedule"
]
class
LRSchedule
(
object
):
class
LRSchedule
(
object
):
...
@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule):
...
@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule):
return
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
self
.
cycles
*
2
*
progress
))
return
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
self
.
cycles
*
2
*
progress
))
class
WarmupCosine
WithRestarts
Schedule
(
WarmupCosineSchedule
):
class
Warmup
Multi
CosineSchedule
(
WarmupCosineSchedule
):
warn_t_total
=
True
warn_t_total
=
True
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
1.
,
**
kw
):
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
1.
,
**
kw
):
super
(
WarmupCosineWithRestartsSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
cycles
=
cycles
,
**
kw
)
super
(
WarmupMultiCosineSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
cycles
=
cycles
,
**
kw
)
assert
(
cycles
>=
1.
)
def
get_lr_
(
self
,
progress
):
def
get_lr_
(
self
,
progress
):
if
self
.
t_total
<=
0
:
if
self
.
t_total
<=
0
:
...
@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
...
@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
return
ret
return
ret
class
WarmupCosineWithRestartsSchedule
(
WarmupMultiCosineSchedule
):
def
get_lr_
(
self
,
progress
):
if
self
.
t_total
<=
0.
:
return
1.
progress
=
progress
*
self
.
cycles
%
1.
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
else
:
progress
=
(
progress
-
self
.
warmup
)
/
(
1
-
self
.
warmup
)
# progress after warmup
ret
=
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
progress
))
return
ret
class
WarmupConstantSchedule
(
LRSchedule
):
class
WarmupConstantSchedule
(
LRSchedule
):
warn_t_total
=
False
warn_t_total
=
False
def
get_lr_
(
self
,
progress
):
def
get_lr_
(
self
,
progress
):
...
...
tests/optimization_test.py
View file @
262a9992
...
@@ -20,7 +20,9 @@ import unittest
...
@@ -20,7 +20,9 @@ import unittest
import
torch
import
torch
from
pytorch_pretrained_bert
import
BertAdam
from
pytorch_pretrained_bert
import
BertAdam
,
WarmupCosineWithRestartsSchedule
from
matplotlib
import
pyplot
as
plt
import
numpy
as
np
class
OptimizationTest
(
unittest
.
TestCase
):
class
OptimizationTest
(
unittest
.
TestCase
):
...
@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
...
@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
self
.
assertListAlmostEqual
(
w
.
tolist
(),
[
0.4
,
0.2
,
-
0.5
],
tol
=
1e-2
)
self
.
assertListAlmostEqual
(
w
.
tolist
(),
[
0.4
,
0.2
,
-
0.5
],
tol
=
1e-2
)
class
WarmupCosineWithRestartsTest
(
unittest
.
TestCase
):
def
test_it
(
self
):
m
=
WarmupCosineWithRestartsSchedule
(
warmup
=
0.2
,
t_total
=
1
,
cycles
=
3
)
x
=
np
.
arange
(
0
,
1000
)
/
1000
y
=
[
m
.
get_lr_
(
xe
)
for
xe
in
x
]
plt
.
plot
(
y
)
plt
.
show
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment