Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
88874f6c
Commit
88874f6c
authored
Mar 08, 2019
by
lukovnikov
Browse files
BertAdam schedule objects
parent
7cc35c31
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
101 additions
and
44 deletions
+101
-44
pytorch_pretrained_bert/optimization.py
pytorch_pretrained_bert/optimization.py
+101
-44
No files found.
pytorch_pretrained_bert/optimization.py
View file @
88874f6c
...
@@ -23,29 +23,99 @@ import logging
...
@@ -23,29 +23,99 @@ import logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
def
warmup_cosine
(
x
,
warmup
=
0.002
):
if
x
<
warmup
:
class
LRSchedule
(
object
):
return
x
/
warmup
warn_t_total
=
False
return
0.5
*
(
1.0
+
torch
.
cos
(
math
.
pi
*
x
))
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
**
kw
):
super
(
LRSchedule
,
self
).
__init__
(
**
kw
)
def
warmup_constant
(
x
,
warmup
=
0.002
):
self
.
warmup
,
self
.
t_total
=
warmup
,
t_total
""" Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
if
t_total
<=
0
:
Learning rate is 1. afterwards. """
logger
.
warning
(
"t_total value of {} results in schedule not being applied"
.
format
(
t_total
))
if
x
<
warmup
:
if
not
0.0
<=
warmup
<
1.0
and
not
warmup
==
-
1
:
return
x
/
warmup
raise
ValueError
(
"Invalid warmup: {} - should be in [0.0, 1.0[ or -1"
.
format
(
warmup
))
return
1.0
self
.
warned_for_t_total_at_progress
=
-
1
def
warmup_linear
(
x
,
warmup
=
0.002
):
def
get_lr
(
self
,
step
,
nowarn
=
False
):
""" Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
progress
=
step
/
self
.
t_total
After `t_total`-th training step, learning rate is zero. """
ret
=
self
.
get_lr_
(
progress
)
if
x
<
warmup
:
# warning for exceeding t_total (only active with warmup_linear
return
x
/
warmup
if
not
nowarn
and
self
.
warn_t_total
and
progress
>
1.
and
progress
>
self
.
warned_for_t_total_at_progress
:
return
max
((
x
-
1.
)
/
(
warmup
-
1.
),
0
)
logger
.
warning
(
"Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
.
format
(
ret
,
self
.
__class__
.
__name__
))
self
.
warned_for_t_total_at_progress
=
progress
# end warning
return
ret
def
get_lr_
(
self
,
step
):
return
1.
# raise NotImplemented("use subclass")
class
WarmupCosineSchedule
(
LRSchedule
):
warn_t_total
=
True
def
__init__
(
self
,
warmup
=
0.002
,
t_total
=-
1
,
cycles
=
.
5
,
**
kw
):
super
(
WarmupCosineSchedule
,
self
).
__init__
(
warmup
=
warmup
,
t_total
=
t_total
,
**
kw
)
self
.
cycles
=
cycles
def
get_lr_
(
self
,
progress
):
""" get learning rate multiplier """
if
self
.
t_total
<=
0
:
return
1.
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
else
:
progress
=
(
progress
-
self
.
warmup
)
/
(
1
-
self
.
warmup
)
# progress after warmup
return
0.5
*
(
1.
+
torch
.
cos
(
math
.
pi
*
self
.
cycles
*
2
*
progress
))
class
WarmupConstantSchedule
(
LRSchedule
):
warn_t_total
=
False
def
get_lr_
(
self
,
progress
):
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
return
1.
class
WarmupLinearSchedule
(
LRSchedule
):
warn_t_total
=
True
def
get_lr_
(
self
,
progress
):
if
progress
<
self
.
warmup
:
return
progress
/
self
.
warmup
return
max
((
progress
-
1.
)
/
(
self
.
warmup
-
1.
),
0
)
#
#
# def warmup_cosine(x, warmup=0.002):
# if x < warmup:
# return x/warmup
# return 0.5 * (1.0 + torch.cos(math.pi * x))
#
# def warmup_constant(x, warmup=0.002):
# """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
# Learning rate is 1. afterwards. """
# if x < warmup:
# return x/warmup
# return 1.0
#
# def warmup_linear(x, warmup=0.002):
# """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
# After `t_total`-th training step, learning rate is zero. """
# if x < warmup:
# return x/warmup
# return max((x-1.)/(warmup-1.), 0)
#
# SCHEDULES = {
# 'warmup_cosine': warmup_cosine,
# 'warmup_constant': warmup_constant,
# 'warmup_linear': warmup_linear,
# }
SCHEDULES
=
{
SCHEDULES
=
{
'warmup_cosine'
:
warmup_cosine
,
None
:
LRSchedule
,
'warmup_constant'
:
warmup_constant
,
"none"
:
LRSchedule
,
'warmup_linear'
:
warmup_linear
,
"warmup_cosine"
:
WarmupCosineSchedule
,
"warmup_constant"
:
WarmupConstantSchedule
,
"warmup_linear"
:
WarmupLinearSchedule
}
}
...
@@ -70,15 +140,16 @@ class BertAdam(Optimizer):
...
@@ -70,15 +140,16 @@ class BertAdam(Optimizer):
raise
ValueError
(
"Invalid learning rate: {} - should be >= 0.0"
.
format
(
lr
))
raise
ValueError
(
"Invalid learning rate: {} - should be >= 0.0"
.
format
(
lr
))
if
schedule
not
in
SCHEDULES
:
if
schedule
not
in
SCHEDULES
:
raise
ValueError
(
"Invalid schedule parameter: {}"
.
format
(
schedule
))
raise
ValueError
(
"Invalid schedule parameter: {}"
.
format
(
schedule
))
if
not
0.0
<=
warmup
<
1.0
and
not
warmup
==
-
1
:
raise
ValueError
(
"Invalid warmup: {} - should be in [0.0, 1.0[ or -1"
.
format
(
warmup
))
if
not
0.0
<=
b1
<
1.0
:
if
not
0.0
<=
b1
<
1.0
:
raise
ValueError
(
"Invalid b1 parameter: {} - should be in [0.0, 1.0["
.
format
(
b1
))
raise
ValueError
(
"Invalid b1 parameter: {} - should be in [0.0, 1.0["
.
format
(
b1
))
if
not
0.0
<=
b2
<
1.0
:
if
not
0.0
<=
b2
<
1.0
:
raise
ValueError
(
"Invalid b2 parameter: {} - should be in [0.0, 1.0["
.
format
(
b2
))
raise
ValueError
(
"Invalid b2 parameter: {} - should be in [0.0, 1.0["
.
format
(
b2
))
if
not
e
>=
0.0
:
if
not
e
>=
0.0
:
raise
ValueError
(
"Invalid epsilon value: {} - should be >= 0.0"
.
format
(
e
))
raise
ValueError
(
"Invalid epsilon value: {} - should be >= 0.0"
.
format
(
e
))
defaults
=
dict
(
lr
=
lr
,
schedule
=
schedule
,
warmup
=
warmup
,
t_total
=
t_total
,
# initialize schedule object
schedule_type
=
SCHEDULES
[
schedule
]
sched
=
schedule_type
(
warmup
=
warmup
,
t_total
=
t_total
)
defaults
=
dict
(
lr
=
lr
,
schedule
=
sched
,
b1
=
b1
,
b2
=
b2
,
e
=
e
,
weight_decay
=
weight_decay
,
b1
=
b1
,
b2
=
b2
,
e
=
e
,
weight_decay
=
weight_decay
,
max_grad_norm
=
max_grad_norm
)
max_grad_norm
=
max_grad_norm
)
super
(
BertAdam
,
self
).
__init__
(
params
,
defaults
)
super
(
BertAdam
,
self
).
__init__
(
params
,
defaults
)
...
@@ -90,11 +161,10 @@ class BertAdam(Optimizer):
...
@@ -90,11 +161,10 @@ class BertAdam(Optimizer):
state
=
self
.
state
[
p
]
state
=
self
.
state
[
p
]
if
len
(
state
)
==
0
:
if
len
(
state
)
==
0
:
return
[
0
]
return
[
0
]
if
group
[
't_total'
]
!=
-
1
:
schedule_fct
=
SCHEDULES
[
group
[
'schedule'
]]
lr_scheduled
=
group
[
'lr'
]
lr_scheduled
=
group
[
'lr'
]
*
schedule_fct
(
state
[
'step'
]
/
group
[
't_total'
],
group
[
'warmup'
])
lr_scheduled
*=
group
[
'schedule'
](
state
[
'step'
])
else
:
lr_scheduled
=
group
[
'lr'
]
lr
.
append
(
lr_scheduled
)
lr
.
append
(
lr_scheduled
)
return
lr
return
lr
...
@@ -109,8 +179,6 @@ class BertAdam(Optimizer):
...
@@ -109,8 +179,6 @@ class BertAdam(Optimizer):
if
closure
is
not
None
:
if
closure
is
not
None
:
loss
=
closure
()
loss
=
closure
()
warned_for_t_total
=
False
for
group
in
self
.
param_groups
:
for
group
in
self
.
param_groups
:
for
p
in
group
[
'params'
]:
for
p
in
group
[
'params'
]:
if
p
.
grad
is
None
:
if
p
.
grad
is
None
:
...
@@ -152,19 +220,8 @@ class BertAdam(Optimizer):
...
@@ -152,19 +220,8 @@ class BertAdam(Optimizer):
if
group
[
'weight_decay'
]
>
0.0
:
if
group
[
'weight_decay'
]
>
0.0
:
update
+=
group
[
'weight_decay'
]
*
p
.
data
update
+=
group
[
'weight_decay'
]
*
p
.
data
if
group
[
't_total'
]
!=
-
1
:
lr_scheduled
=
group
[
'lr'
]
schedule_fct
=
SCHEDULES
[
group
[
'schedule'
]]
lr_scheduled
*=
group
[
'schedule'
](
state
[
'step'
])
progress
=
state
[
'step'
]
/
group
[
't_total'
]
lr_scheduled
=
group
[
'lr'
]
*
schedule_fct
(
progress
,
group
[
'warmup'
])
# warning for exceeding t_total (only active with warmup_linear
if
group
[
'schedule'
]
==
"warmup_linear"
and
progress
>
1.
and
not
warned_for_t_total
:
logger
.
warning
(
"Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. "
"Please set 't_total' of {} correctly."
.
format
(
group
[
'schedule'
],
lr_scheduled
,
self
.
__class__
.
__name__
))
warned_for_t_total
=
True
# end warning
else
:
lr_scheduled
=
group
[
'lr'
]
update_with_lr
=
lr_scheduled
*
update
update_with_lr
=
lr_scheduled
*
update
p
.
data
.
add_
(
-
update_with_lr
)
p
.
data
.
add_
(
-
update_with_lr
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment