chenpangpang/transformers, commit 90a41dbe
Authored Mar 09, 2019 by lukovnikov

BertAdam schedule objects

Parent: 88874f6c
Showing 2 changed files with 17 additions and 33 deletions:

pytorch_pretrained_bert/__init__.py      (+1, -1)
pytorch_pretrained_bert/optimization.py  (+16, -32)
pytorch_pretrained_bert/__init__.py

@@ -18,7 +18,7 @@ from .modeling_gpt2 import (GPT2Config, GPT2Model,
                             GPT2LMHeadModel, GPT2DoubleHeadsModel,
                             load_tf_weights_in_gpt2)
-from .optimization import BertAdam
+from .optimization import *
 from .optimization_openai import OpenAIAdam
 from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
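The switch to a wildcard import is safe because optimization.py now defines __all__ (see the next file), so import * re-exports exactly the schedule classes plus BertAdam. A minimal sketch of what this enables at package level:

# After this change, the schedule classes listed in optimization.__all__
# can be imported straight from the package, alongside BertAdam.
from pytorch_pretrained_bert import BertAdam, WarmupLinearSchedule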
pytorch_pretrained_bert/optimization.py

@@ -24,6 +24,9 @@ import logging
 logger = logging.getLogger(__name__)
 
+__all__ = ["LRSchedule", "WarmupLinearSchedule", "WarmupConstantSchedule", "WarmupCosineSchedule", "BertAdam"]
+
+
 class LRSchedule(object):
     warn_t_total = False
     def __init__(self, warmup=0.002, t_total=-1, **kw):
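A quick sketch of the object API this hunk establishes. The behavior of get_lr is assumed from how BertAdam invokes it in the hunks further down: get_lr(step) takes the current step count and returns a multiplier on the base learning rate.

# hypothetical usage: build a schedule object directly
from pytorch_pretrained_bert.optimization import WarmupLinearSchedule

sched = WarmupLinearSchedule(warmup=0.1, t_total=1000)  # warm up over the first 100 of 1000 steps
print(sched.get_lr(50))   # during warmup: multiplier ramps from 0 toward 1
print(sched.get_lr(500))  # after warmup: multiplier decays linearly toward 0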
@@ -83,32 +86,7 @@ class WarmupLinearSchedule(LRSchedule):
         if progress < self.warmup:
             return progress / self.warmup
         return max((progress - 1.) / (self.warmup - 1.), 0)
 
-#
-#
-# def warmup_cosine(x, warmup=0.002):
-#     if x < warmup:
-#         return x/warmup
-#     return 0.5 * (1.0 + torch.cos(math.pi * x))
-#
-# def warmup_constant(x, warmup=0.002):
-#     """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps.
-#         Learning rate is 1. afterwards. """
-#     if x < warmup:
-#         return x/warmup
-#     return 1.0
-#
-# def warmup_linear(x, warmup=0.002):
-#     """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
-#         After `t_total`-th training step, learning rate is zero. """
-#     if x < warmup:
-#         return x/warmup
-#     return max((x-1.)/(warmup-1.), 0)
-#
-# SCHEDULES = {
-#     'warmup_cosine': warmup_cosine,
-#     'warmup_constant': warmup_constant,
-#     'warmup_linear': warmup_linear,
-# }
 SCHEDULES = {
     None: LRSchedule,
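The deleted module-level warmup_* helpers live on as the schedule classes, and SCHEDULES now maps schedule names (plus None, for a constant schedule) to classes rather than to bare functions. A sketch of the lookup that BertAdam's constructor performs below, with the 'warmup_linear' -> WarmupLinearSchedule mapping assumed from the class names:

# resolve a schedule name to its class, then instantiate it
schedule_type = SCHEDULES['warmup_linear']          # assumed: WarmupLinearSchedule
sched = schedule_type(warmup=0.002, t_total=10000)  # defaults mirror the old functions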
@@ -126,7 +104,9 @@ class BertAdam(Optimizer):
         warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
         t_total: total number of training steps for the learning
             rate schedule, -1 means constant learning rate. Default: -1
-        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
+        schedule: schedule to use for the warmup (see above).
+            Can be 'warmup_linear', 'warmup_constant', 'warmup_cosine', or a LRSchedule object.
+            Default: 'warmup_linear'
         b1: Adams b1. Default: 0.9
         b2: Adams b2. Default: 0.999
         e: Adams epsilon. Default: 1e-6
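For callers, passing a schedule by name keeps working exactly as before. A minimal sketch, with a dummy parameter standing in for model.parameters():

import torch
from pytorch_pretrained_bert import BertAdam

params = [torch.nn.Parameter(torch.zeros(2, 2))]  # stand-in for model.parameters()
optimizer = BertAdam(params, lr=5e-5, warmup=0.1, t_total=10000,
                     schedule='warmup_cosine')    # resolved via SCHEDULES internally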
@@ -147,9 +127,13 @@ class BertAdam(Optimizer):
         if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
         # initialize schedule object
-        schedule_type = SCHEDULES[schedule]
-        sched = schedule_type(warmup=warmup, t_total=t_total)
-        defaults = dict(lr=lr, schedule=sched,
+        if not isinstance(schedule, LRSchedule):
+            schedule_type = SCHEDULES[schedule]
+            schedule = schedule_type(warmup=warmup, t_total=t_total)
+        else:
+            if warmup != -1 or t_total != -1:
+                logger.warning("Non-default warmup and t_total are ineffective when LRSchedule object is provided.")
+        defaults = dict(lr=lr, schedule=schedule,
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                         max_grad_norm=max_grad_norm)
         super(BertAdam, self).__init__(params, defaults)
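The new else-branch is what makes ready-made schedule objects first-class: warmup and t_total are configured on the schedule itself, and setting them on the optimizer at the same time only triggers the warning added above. A sketch under the same assumptions as before:

import torch
from pytorch_pretrained_bert.optimization import BertAdam, WarmupCosineSchedule

params = [torch.nn.Parameter(torch.zeros(2, 2))]  # stand-in for model.parameters()
sched = WarmupCosineSchedule(warmup=0.1, t_total=10000)
optimizer = BertAdam(params, lr=5e-5, schedule=sched)

# This variant would log "Non-default warmup and t_total are ineffective ...",
# since both values are ignored once a schedule object is supplied:
# optimizer = BertAdam(params, lr=5e-5, warmup=0.1, t_total=10000, schedule=sched)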
@@ -163,7 +147,7 @@ class BertAdam(Optimizer):
                     return [0]
                 lr_scheduled = group['lr']
-                lr_scheduled *= group['schedule'](state['step'])
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
                 lr.append(lr_scheduled)
         return lr
@@ -221,7 +205,7 @@ class BertAdam(Optimizer):
                 update += group['weight_decay'] * p.data
                 lr_scheduled = group['lr']
-                lr_scheduled *= group['schedule'](state['step'])
+                lr_scheduled *= group['schedule'].get_lr(state['step'])
                 update_with_lr = lr_scheduled * update
                 p.data.add_(-update_with_lr)
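Because the optimizer now only type-checks against LRSchedule and calls schedule.get_lr(step), a custom schedule is just a subclass with its own get_lr. A hedged sketch: it assumes the base __init__ stores t_total on the instance, and that get_lr(step) should return the lr multiplier, as the two hunks above use it.

from pytorch_pretrained_bert.optimization import LRSchedule

class FlatThenDropSchedule(LRSchedule):  # hypothetical custom schedule
    """Full learning rate for the first half of training, then 10% of it."""
    def get_lr(self, step):
        # assumes self.t_total is set by the base __init__, per its signature
        if self.t_total > 0 and step > self.t_total / 2:
            return 0.1
        return 1.0

# plugs into BertAdam like any built-in schedule:
# optimizer = BertAdam(params, lr=5e-5, schedule=FlatThenDropSchedule(t_total=10000))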