chenpangpang / transformers · Commits · 7b9e5a54

Unverified commit 7b9e5a54, authored Mar 06, 2019 by Thomas Wolf, committed by GitHub on Mar 06, 2019

Merge pull request #327 from lukovnikov/master

Issue#324: warmup linear fixes

Parents: 4784b04f 35410da7
Showing 2 changed files with 49 additions and 12 deletions (+49 -12)

pytorch_pretrained_bert/optimization.py         +22 -5
pytorch_pretrained_bert/optimization_openai.py  +27 -7
pytorch_pretrained_bert/optimization.py
@@ -19,6 +19,9 @@ import torch
 from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
+import logging
+
+logger = logging.getLogger(__name__)
 
 def warmup_cosine(x, warmup=0.002):
     if x < warmup:
@@ -26,19 +29,23 @@ def warmup_cosine(x, warmup=0.002):
     return 0.5 * (1.0 + torch.cos(math.pi * x))
 
 def warmup_constant(x, warmup=0.002):
+    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
+        Learning rate is 1. afterwards. """
     if x < warmup:
         return x/warmup
     return 1.0
 
 def warmup_linear(x, warmup=0.002):
+    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
+        After `t_total`-th training step, learning rate is zero. """
     if x < warmup:
         return x/warmup
-    return 1.0 - x
+    return max((x - 1.)/(warmup - 1.), 0)
 
 SCHEDULES = {
     'warmup_cosine':warmup_cosine,
     'warmup_constant':warmup_constant,
     'warmup_linear':warmup_linear,
 }
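The heart of the fix is the `warmup_linear` change above: the old `return 1.0 - x` keeps decreasing past zero once training runs beyond `t_total` (progress x > 1), silently turning the learning rate negative (Issue #324), while the new `max((x - 1.)/(warmup - 1.), 0)` form peaks at 1.0 at the end of warmup and clamps at zero afterwards. A minimal standalone sketch (illustration only, not part of the diff) contrasting the two:

```python
# Standalone restatement of the old and new warmup_linear from the diff.
def old_warmup_linear(x, warmup=0.002):
    # pre-commit behaviour
    if x < warmup:
        return x / warmup
    return 1.0 - x                              # negative once x > 1.0

def new_warmup_linear(x, warmup=0.002):
    # post-commit behaviour
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0)     # clamped at 0 past t_total

for x in (0.001, 0.5, 1.0, 1.2):
    print(x, old_warmup_linear(x), new_warmup_linear(x))
# at x = 1.2: old -> -0.2 (learning rate flips sign), new -> 0.0
```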
@@ -102,6 +109,8 @@ class BertAdam(Optimizer):
         if closure is not None:
             loss = closure()
 
+        warned_for_t_total = False
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -145,7 +154,15 @@ class BertAdam(Optimizer):
                 if group['t_total'] != -1:
                     schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    progress = state['step']/group['t_total']
+                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
+                    # warning for exceeding t_total (only active with warmup_linear
+                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
+                        logger.warning(
+                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
+                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
+                        warned_for_t_total = True
+                    # end warning
                 else:
                     lr_scheduled = group['lr']
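For context, a hedged usage sketch of how `BertAdam` is typically constructed so that `progress = step / t_total` stays meaningful; argument names follow the pytorch_pretrained_bert API as of this commit and should be checked against the installed version:

```python
import torch
from pytorch_pretrained_bert.optimization import BertAdam

model = torch.nn.Linear(10, 2)         # stand-in model for illustration
num_train_steps = 1000                 # assumed total number of optimizer steps

optimizer = BertAdam(model.parameters(),
                     lr=5e-5,
                     warmup=0.1,               # ramp up over the first 10% of steps
                     t_total=num_train_steps)  # default schedule is 'warmup_linear'

# If training runs past t_total, progress exceeds 1.0: with 'warmup_linear' the
# learning rate is now clamped to 0 and the warning added in this commit is logged
# (at most once per step() call), instead of the rate silently going negative.
```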
pytorch_pretrained_bert/optimization_openai.py
@@ -19,18 +19,28 @@ import torch
 from torch.optim import Optimizer
 from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_
+import logging
+
+logger = logging.getLogger(__name__)
 
 def warmup_cosine(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x)))
+    if x < warmup:
+        return x/warmup
+    return 0.5 * (1.0 + torch.cos(math.pi * x))
 
 def warmup_constant(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return s*(x/warmup) + (1-s)*1
+    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps.
+        Learning rate is 1. afterwards. """
+    if x < warmup:
+        return x/warmup
+    return 1.0
 
 def warmup_linear(x, warmup=0.002):
-    s = 1 if x <= warmup else 0
-    return (s*(x/warmup) + (1-s))*(1-x)
+    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step.
+        After `t_total`-th training step, learning rate is zero. """
+    if x < warmup:
+        return x/warmup
+    return max((x - 1.)/(warmup - 1.), 0)
 
 SCHEDULES = {
     'warmup_cosine':warmup_cosine,
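The OpenAI optimizer's schedules are rewritten from the `s = 1 if x <= warmup else 0` gating form to the same explicit branches used in optimization.py. Besides the clamping past `t_total`, this also changes `warmup_linear`: the old ramp was scaled by `(1 - x)` and so never quite reached 1.0 at the end of warmup. A standalone sketch (illustration only; a larger warmup fraction is chosen to make the difference visible):

```python
def old_openai_warmup_linear(x, warmup=0.002):
    # pre-commit behaviour: warmup ramp is multiplied by (1 - x)
    s = 1 if x <= warmup else 0
    return (s * (x / warmup) + (1 - s)) * (1 - x)

def new_openai_warmup_linear(x, warmup=0.002):
    # post-commit behaviour: true triangle, clamped at 0 past t_total
    if x < warmup:
        return x / warmup
    return max((x - 1.) / (warmup - 1.), 0)

warmup = 0.1
for x in (0.05, 0.1, 0.5, 1.0, 1.2):
    print(x, old_openai_warmup_linear(x, warmup), new_openai_warmup_linear(x, warmup))
# at x = 0.1 (end of warmup): old peaks at 0.9, new peaks at 1.0
# at x = 1.2 (past t_total):  old gives -0.2, new gives 0.0
```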
@@ -88,6 +98,8 @@ class OpenAIAdam(Optimizer):
         if closure is not None:
             loss = closure()
 
+        warned_for_t_total = False
+
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -125,7 +137,15 @@ class OpenAIAdam(Optimizer):
                 if group['t_total'] != -1:
                     schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    progress = state['step']/group['t_total']
+                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
+                    # warning for exceeding t_total (only active with warmup_linear
+                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
+                        logger.warning(
+                            "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
+                            "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
+                        warned_for_t_total = True
+                    # end warning
                 else:
                     lr_scheduled = group['lr']
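Note that the new warning in both optimizers goes through a module-level `logger`, so it only appears if Python logging is configured in the training script. A minimal, hedged sketch using only the standard library (nothing repository-specific):

```python
import logging

# Surface WARNING-level messages (including the new 't_total' warning emitted by
# pytorch_pretrained_bert.optimization and .optimization_openai) on the console.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    level=logging.WARNING,
)
```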