Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
262a9992
Commit
262a9992
authored
Mar 18, 2019
by
lukovnikov
Browse files
class weights
parent
b6c1cae6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
4 deletions
+32
-4
pytorch_pretrained_bert/optimization.py
pytorch_pretrained_bert/optimization.py
+18
-3
tests/optimization_test.py
tests/optimization_test.py
+14
-1
No files found.
pytorch_pretrained_bert/optimization.py
View file @
262a9992
...
@@ -24,7 +24,8 @@ import logging
...
@@ -24,7 +24,8 @@ import logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
"LRSchedule"
,
"WarmupLinearSchedule"
,
"WarmupConstantSchedule"
,
"WarmupCosineSchedule"
,
"BertAdam"
,
"WarmupCosineWithRestartsSchedule"
]
__all__
=
[
"LRSchedule"
,
"WarmupLinearSchedule"
,
"WarmupConstantSchedule"
,
"WarmupCosineSchedule"
,
"BertAdam"
,
"WarmupMultiCosineSchedule"
,
"WarmupCosineWithRestartsSchedule"
]
class
LRSchedule
(
object
):
class
LRSchedule
(
object
):
...
@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule):
...
@@ -72,10 +73,11 @@ class WarmupCosineSchedule(LRSchedule):
return
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
self
.
cycles
*
2
*
progress
))
return
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
self
.
cycles
*
2
*
progress
))
class
WarmupCosine
WithRestarts
Schedule
(
WarmupCosineSchedule
):
class
Warmup
Multi
CosineSchedule
(
WarmupCosineSchedule
):
warn_t_total
=
True
warn_t_total
=
True
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
1.
,
**
kw
):
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
1.
,
**
kw
):
super
(
WarmupCosineWithRestartsSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
cycles
=
cycles
,
**
kw
)
super
(
WarmupMultiCosineSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
cycles
=
cycles
,
**
kw
)
assert
(
cycles
>=
1.
)
def
get_lr_
(
self
,
progress
):
def
get_lr_
(
self
,
progress
):
if
self
.
t_total
<=
0
:
if
self
.
t_total
<=
0
:
...
@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
...
@@ -88,6 +90,19 @@ class WarmupCosineWithRestartsSchedule(WarmupCosineSchedule):
return
ret
return
ret
class
WarmupCosineWithRestartsSchedule
(
WarmupMultiCosineSchedule
):
def
get_lr_
(
self
,
progress
):
if
self
.
t_total
<=
0.
:
return
1.
progress
=
progress
*
self
.
cycles
%
1.
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
else
:
progress
=
(
progress
-
self
.
warmup
)
/
(
1
-
self
.
warmup
)
# progress after warmup
ret
=
0.5
*
(
1.
+
math
.
cos
(
math
.
pi
*
progress
))
return
ret
class
WarmupConstantSchedule
(
LRSchedule
):
class
WarmupConstantSchedule
(
LRSchedule
):
warn_t_total
=
False
warn_t_total
=
False
def
get_lr_
(
self
,
progress
):
def
get_lr_
(
self
,
progress
):
...
...
tests/optimization_test.py
View file @
262a9992
...
@@ -20,7 +20,9 @@ import unittest
...
@@ -20,7 +20,9 @@ import unittest
import
torch
import
torch
from
pytorch_pretrained_bert
import
BertAdam
from
pytorch_pretrained_bert
import
BertAdam
,
WarmupCosineWithRestartsSchedule
from
matplotlib
import
pyplot
as
plt
import
numpy
as
np
class
OptimizationTest
(
unittest
.
TestCase
):
class
OptimizationTest
(
unittest
.
TestCase
):
...
@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
...
@@ -46,5 +48,16 @@ class OptimizationTest(unittest.TestCase):
self
.
assertListAlmostEqual
(
w
.
tolist
(),
[
0.4
,
0.2
,
-
0.5
],
tol
=
1e-2
)
self
.
assertListAlmostEqual
(
w
.
tolist
(),
[
0.4
,
0.2
,
-
0.5
],
tol
=
1e-2
)
class
WarmupCosineWithRestartsTest
(
unittest
.
TestCase
):
def
test_it
(
self
):
m
=
WarmupCosineWithRestartsSchedule
(
warmup
=
0.2
,
t_total
=
1
,
cycles
=
3
)
x
=
np
.
arange
(
0
,
1000
)
/
1000
y
=
[
m
.
get_lr_
(
xe
)
for
xe
in
x
]
plt
.
plot
(
y
)
plt
.
show
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment