chenpangpang / transformers / Commits / 88874f6c

Commit 88874f6c authored Mar 08, 2019 by lukovnikov

BertAdam schedule objects

parent 7cc35c31
Showing 1 changed file with 101 additions and 44 deletions.

pytorch_pretrained_bert/optimization.py  (+101, -44)
@@ -23,29 +23,99 @@ import logging

 logger = logging.getLogger(__name__)

-def warmup_cosine(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    return 0.5 * (1.0 + torch.cos(math.pi * x))
-
-def warmup_constant(x, warmup=0.002):
-    """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
-        Learning rate is 1. afterwards. """
-    if x < warmup:
-        return x/warmup
-    return 1.0
-
-def warmup_linear(x, warmup=0.002):
-    """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
-        After `t_total`-th training step, learning rate is zero. """
-    if x < warmup:
-        return x/warmup
-    return max((x-1.)/(warmup-1.), 0)
-
-SCHEDULES = {
-    'warmup_cosine': warmup_cosine,
-    'warmup_constant': warmup_constant,
-    'warmup_linear': warmup_linear,
-}
+class LRSchedule(object):
+    warn_t_total = False
+    def __init__(self, warmup=0.002, t_total=-1, **kw):
+        super(LRSchedule, self).__init__(**kw)
+        self.warmup, self.t_total = warmup, t_total
+        if t_total <= 0:
+            logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
+        if not 0.0 <= warmup < 1.0 and not warmup == -1:
+            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
+        self.warned_for_t_total_at_progress = -1
+
+    def get_lr(self, step, nowarn=False):
+        progress = step / self.t_total
+        ret = self.get_lr_(progress)
+        # warning for exceeding t_total (only active with warmup_linear
+        if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
+            logger.warning("Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly.".format(ret, self.__class__.__name__))
+            self.warned_for_t_total_at_progress = progress
+        # end warning
+        return ret
+
+    def get_lr_(self, step):
+        return 1.
+        # raise NotImplemented("use subclass")
+
+
+class WarmupCosineSchedule(LRSchedule):
+    warn_t_total = True
+    def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
+        super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
+        self.cycles = cycles
+
+    def get_lr_(self, progress):
+        """ get learning rate multiplier """
+        if self.t_total <= 0:
+            return 1.
+        if progress < self.warmup:
+            return progress / self.warmup
+        else:
+            progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
+            return 0.5 * (1. + torch.cos(math.pi * self.cycles * 2 * progress))
+
+
+class WarmupConstantSchedule(LRSchedule):
+    warn_t_total = False
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        return 1.
+
+
+class WarmupLinearSchedule(LRSchedule):
+    warn_t_total = True
+    def get_lr_(self, progress):
+        if progress < self.warmup:
+            return progress / self.warmup
+        return max((progress - 1.) / (self.warmup - 1.), 0)
+
+
+# def warmup_cosine(x, warmup=0.002):
+#     if x < warmup:
+#         return x/warmup
+#     return 0.5 * (1.0 + torch.cos(math.pi * x))
+#
+# def warmup_constant(x, warmup=0.002):
+#     """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
+#         Learning rate is 1. afterwards. """
+#     if x < warmup:
+#         return x/warmup
+#     return 1.0
+#
+# def warmup_linear(x, warmup=0.002):
+#     """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
+#         After `t_total`-th training step, learning rate is zero. """
+#     if x < warmup:
+#         return x/warmup
+#     return max((x-1.)/(warmup-1.), 0)
+#
+# SCHEDULES = {
+#     'warmup_cosine': warmup_cosine,
+#     'warmup_constant': warmup_constant,
+#     'warmup_linear': warmup_linear,
+# }
+
+SCHEDULES = {
+    None: LRSchedule,
+    "none": LRSchedule,
+    "warmup_cosine": WarmupCosineSchedule,
+    "warmup_constant": WarmupConstantSchedule,
+    "warmup_linear": WarmupLinearSchedule
+}
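A minimal usage sketch of the new schedule objects (assuming this revision of pytorch_pretrained_bert is installed; the warmup fraction and step values are illustrative, not taken from the repository):

# WarmupLinearSchedule ramps the multiplier from 0 to 1 over the first warmup * t_total
# steps, then decays it linearly back to 0 at step t_total.
from pytorch_pretrained_bert.optimization import WarmupLinearSchedule

sched = WarmupLinearSchedule(warmup=0.1, t_total=1000)
for step in (0, 50, 100, 500, 1000):
    # get_lr() converts the absolute step into progress = step / t_total internally
    print(step, sched.get_lr(step))
# expected multipliers: 0.0, 0.5, 1.0, ~0.556, 0.0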
@@ -70,15 +140,16 @@ class BertAdam(Optimizer):

             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
         if schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
-        if not 0.0 <= warmup < 1.0 and not warmup == -1:
-            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         if not 0.0 <= b1 < 1.0:
             raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
         if not 0.0 <= b2 < 1.0:
             raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
-        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
-                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
-                        max_grad_norm=max_grad_norm)
+        # initialize schedule object
+        schedule_type = SCHEDULES[schedule]
+        sched = schedule_type(warmup=warmup, t_total=t_total)
+        defaults = dict(lr=lr, schedule=sched,
+                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
+                        max_grad_norm=max_grad_norm)
         super(BertAdam, self).__init__(params, defaults)
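For context, a sketch of constructing the optimizer after this change (the argument names follow the BertAdam signature referenced above; the model and hyperparameter values are placeholders):

import torch
from pytorch_pretrained_bert.optimization import BertAdam

model = torch.nn.Linear(10, 2)  # stand-in for a BERT model
optimizer = BertAdam(model.parameters(), lr=5e-5,
                     warmup=0.1, t_total=10000, schedule='warmup_linear')
# The constructor now resolves 'warmup_linear' through SCHEDULES and stores a
# WarmupLinearSchedule(warmup=0.1, t_total=10000) instance in the param-group
# defaults, instead of keeping the raw string together with warmup and t_total.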
@@ -90,11 +161,10 @@ class BertAdam(Optimizer):

                 state = self.state[p]
                 if len(state) == 0:
                     return [0]
-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'](state['step'])
                 lr.append(lr_scheduled)
         return lr
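A small worked example of the new lookup path in get_lr(): the reported rate is the base lr multiplied by the schedule object's output (values are illustrative; the call uses the get_lr method defined on LRSchedule above):

from pytorch_pretrained_bert.optimization import WarmupLinearSchedule

sched = WarmupLinearSchedule(warmup=0.1, t_total=1000)
base_lr = 5e-5
step = 50                                     # halfway through the 100-step warmup
lr_scheduled = base_lr * sched.get_lr(step)   # 5e-5 * 0.5 = 2.5e-5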
@@ -109,8 +179,6 @@ class BertAdam(Optimizer):

         if closure is not None:
             loss = closure()

-        warned_for_t_total = False
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
@@ -152,19 +220,8 @@ class BertAdam(Optimizer):

                 if group['weight_decay'] > 0.0:
                     update += group['weight_decay'] * p.data

-                if group['t_total'] != -1:
-                    schedule_fct = SCHEDULES[group['schedule']]
-                    progress = state['step']/group['t_total']
-                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
-                    # warning for exceeding t_total (only active with warmup_linear
-                    if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total:
-                        logger.warning("Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
-                                       "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__))
-                        warned_for_t_total = True
-                    # end warning
-                else:
-                    lr_scheduled = group['lr']
+                lr_scheduled = group['lr']
+                lr_scheduled *= group['schedule'](state['step'])

                 update_with_lr = lr_scheduled * update
                 p.data.add_(-update_with_lr)
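The warning for training beyond t_total, previously emitted inside step() for warmup_linear only, now lives in the schedule object itself (warn_t_total is True for the linear and cosine schedules). A sketch of triggering it directly, assuming the classes above:

import logging
from pytorch_pretrained_bert.optimization import WarmupLinearSchedule

logging.basicConfig(level=logging.WARNING)

sched = WarmupLinearSchedule(warmup=0.1, t_total=100)
print(sched.get_lr(150))               # multiplier clamps to 0 and the warning is logged once
print(sched.get_lr(150, nowarn=True))  # same value, warning suppressed via nowarn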