Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
88874f6c
Commit
88874f6c
authored
Mar 08, 2019
by
lukovnikov
Browse files
BertAdam schedule objects
parent
7cc35c31
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
101 additions
and
44 deletions
+101
-44
pytorch_pretrained_bert/optimization.py
pytorch_pretrained_bert/optimization.py
+101
-44
No files found.
pytorch_pretrained_bert/optimization.py
View file @
88874f6c
...
...
@@ -23,29 +23,99 @@ import logging
logger
=
logging
.
getLogger
(
__name__
)
def
warmup_cosine
(
x
,
warmup
=
0.002
):
if
x
<
warmup
:
return
x
/
warmup
return
0.5
*
(
1.0
+
torch
.
cos
(
math
.
pi
*
x
))
def
warmup_constant
(
x
,
warmup
=
0.002
):
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
Learning rate is 1. afterwards. """
if
x
<
warmup
:
return
x
/
warmup
return
1.0
def
warmup_linear
(
x
,
warmup
=
0.002
):
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
After `t_total`-th training step, learning rate is zero. """
if
x
<
warmup
:
return
x
/
warmup
return
max
((
x
-
1.
)
/
(
warmup
-
1.
),
0
)
class
LRSchedule
(
object
):
warn_t_total
=
False
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
**
kw
):
super
(
LRSchedule
,
self
).
__init__
(
**
kw
)
self
.
warmup
,
self
.
t_total
=
warmup
,
t_total
if
t_total
<=
0
:
logger
.
warning
(
"t_total value of {} results in schedule not being applied"
.
format
(
t_total
))
if
not
0.0
<=
warmup
<
1.0
and
not
warmup
==
-
1
:
raise
ValueError
(
"Invalid warmup: {} - should be in [0.0, 1.0[ or -1"
.
format
(
warmup
))
self
.
warned_for_t_total_at_progress
=
-
1
def
get_lr
(
self
,
step
,
nowarn
=
False
):
progress
=
step
/
self
.
t_total
ret
=
self
.
get_lr_
(
progress
)
# warning for exceeding t_total (only active with warmup_linear
if
not
nowarn
and
self
.
warn_t_total
and
progress
>
1.
and
progress
>
self
.
warned_for_t_total_at_progress
:
logger
.
warning
(
"Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
.
format
(
ret
,
self
.
__class__
.
__name__
))
self
.
warned_for_t_total_at_progress
=
progress
# end warning
return
ret
def
get_lr_
(
self
,
step
):
return
1.
# raise NotImplemented("use subclass")
class
WarmupCosineSchedule
(
LRSchedule
):
warn_t_total
=
True
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
.
5
,
**
kw
):
super
(
WarmupCosineSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
**
kw
)
self
.
cycles
=
cycles
def
get_lr_
(
self
,
progress
):
""" get learning rate multiplier """
if
self
.
t_total
<=
0
:
return
1.
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
else
:
progress
=
(
progress
-
self
.
warmup
)
/
(
1
-
self
.
warmup
)
# progress after warmup
return
0.5
*
(
1.
+
torch
.
cos
(
math
.
pi
*
self
.
cycles
*
2
*
progress
))
class
WarmupConstantSchedule
(
LRSchedule
):
warn_t_total
=
False
def
get_lr_
(
self
,
progress
):
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
return
1.
class
WarmupLinearSchedule
(
LRSchedule
):
warn_t_total
=
True
def
get_lr_
(
self
,
progress
):
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
return
max
((
progress
-
1.
)
/
(
self
.
warmup
-
1.
),
0
)
#
#
# def warmup_cosine(x, warmup=0.002):
# if x < warmup:
# return x/warmup
# return 0.5 * (1.0 + torch.cos(math.pi * x))
#
# def warmup_constant(x, warmup=0.002):
# """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
# Learning rate is 1. afterwards. """
# if x < warmup:
# return x/warmup
# return 1.0
#
# def warmup_linear(x, warmup=0.002):
# """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
# After `t_total`-th training step, learning rate is zero. """
# if x < warmup:
# return x/warmup
# return max((x-1.)/(warmup-1.), 0)
#
# SCHEDULES = {
# 'warmup_cosine': warmup_cosine,
# 'warmup_constant': warmup_constant,
# 'warmup_linear': warmup_linear,
# }
SCHEDULES
=
{
'warmup_cosine'
:
warmup_cosine
,
'warmup_constant'
:
warmup_constant
,
'warmup_linear'
:
warmup_linear
,
None
:
LRSchedule
,
"none"
:
LRSchedule
,
"warmup_cosine"
:
WarmupCosineSchedule
,
"warmup_constant"
:
WarmupConstantSchedule
,
"warmup_linear"
:
WarmupLinearSchedule
}
...
...
@@ -70,15 +140,16 @@ class BertAdam(Optimizer):
raise
ValueError
(
"Invalid learning rate: {} - should be >= 0.0"
.
format
(
lr
))
if
schedule
not
in
SCHEDULES
:
raise
ValueError
(
"Invalid schedule parameter: {}"
.
format
(
schedule
))
if
not
0.0
<=
warmup
<
1.0
and
not
warmup
==
-
1
:
raise
ValueError
(
"Invalid warmup: {} - should be in [0.0, 1.0[ or -1"
.
format
(
warmup
))
if
not
0.0
<=
b1
<
1.0
:
raise
ValueError
(
"Invalid b1 parameter: {} - should be in [0.0, 1.0["
.
format
(
b1
))
if
not
0.0
<=
b2
<
1.0
:
raise
ValueError
(
"Invalid b2 parameter: {} - should be in [0.0, 1.0["
.
format
(
b2
))
if
not
e
>=
0.0
:
raise
ValueError
(
"Invalid epsilon value: {} - should be >= 0.0"
.
format
(
e
))
defaults
=
dict
(
lr
=
lr
,
schedule
=
schedule
,
warmup
=
warmup
,
t_total
=
t_total
,
# initialize schedule object
schedule_type
=
SCHEDULES
[
schedule
]
sched
=
schedule_type
(
warmup
=
warmup
,
t_total
=
t_total
)
defaults
=
dict
(
lr
=
lr
,
schedule
=
sched
,
b1
=
b1
,
b2
=
b2
,
e
=
e
,
weight_decay
=
weight_decay
,
max_grad_norm
=
max_grad_norm
)
super
(
BertAdam
,
self
).
__init__
(
params
,
defaults
)
...
...
@@ -90,11 +161,10 @@ class BertAdam(Optimizer):
state
=
self
.
state
[
p
]
if
len
(
state
)
==
0
:
return
[
0
]
if
group
[
't_total'
]
!=
-
1
:
schedule_fct
=
SCHEDULES
[
group
[
'schedule'
]]
lr_scheduled
=
group
[
'lr'
]
*
schedule_fct
(
state
[
'step'
]
/
group
[
't_total'
],
group
[
'warmup'
])
else
:
lr_scheduled
=
group
[
'lr'
]
lr_scheduled
*=
group
[
'schedule'
](
state
[
'step'
])
lr
.
append
(
lr_scheduled
)
return
lr
...
...
@@ -109,8 +179,6 @@ class BertAdam(Optimizer):
if
closure
is
not
None
:
loss
=
closure
()
warned_for_t_total
=
False
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
if
p
.
grad
is
None
:
...
...
@@ -152,19 +220,8 @@ class BertAdam(Optimizer):
if
group
[
'weight_decay'
]
>
0.0
:
update
+=
group
[
'weight_decay'
]
*
p
.
data
if
group
[
't_total'
]
!=
-
1
:
schedule_fct
=
SCHEDULES
[
group
[
'schedule'
]]
progress
=
state
[
'step'
]
/
group
[
't_total'
]
lr_scheduled
=
group
[
'lr'
]
*
schedule_fct
(
progress
,
group
[
'warmup'
])
# warning for exceeding t_total (only active with warmup_linear
if
group
[
'schedule'
]
==
"warmup_linear"
and
progress
>
1.
and
not
warned_for_t_total
:
logger
.
warning
(
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
"Please set 't_total' of {} correctly."
.
format
(
group
[
'schedule'
],
lr_scheduled
,
self
.
__class__
.
__name__
))
warned_for_t_total
=
True
# end warning
else
:
lr_scheduled
=
group
[
'lr'
]
lr_scheduled
*=
group
[
'schedule'
](
state
[
'step'
])
update_with_lr
=
lr_scheduled
*
update
p
.
data
.
add_
(
-
update_with_lr
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment