chenpangpang / transformers

Commit 46ef6460 authored Feb 27, 2019 by lukovnikov

added warning

parent 9bc3773c
Showing 2 changed files with 25 additions and 18 deletions (+25 / -18)

pytorch_pretrained_bert/optimization.py         +14 / -15
pytorch_pretrained_bert/optimization_openai.py  +11 / -3
pytorch_pretrained_bert/optimization.py (view file @ 46ef6460)
@@ -35,17 +35,6 @@ def warmup_constant(x, warmup=0.002):
         return x/warmup
     return 1.0
 
-class Warmup_Linear_with_Warning(object):
-    def __init__(self, **kw):
-        super(Warmup_Linear_with_Warning, self).__init__()
-        self.warned_at_x = -1
-
-    def __call__(self, x, warmup=0.002):
-        if x > 1 and x > self.warned_at_x:
-            logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
-            self.warned_at_x = x
-        return warmup_linear(x, warmup=warmup)
-
 def warmup_linear(x, warmup=0.002):
     """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step.
         After `t_total`-th training step, learning rate is zero. """
@@ -54,9 +43,9 @@ def warmup_linear(x, warmup=0.002):
     return max((x-1.)/(warmup-1.), 0)
 
 SCHEDULES = {
     'warmup_cosine':warmup_cosine,
     'warmup_constant':warmup_constant,
-    'warmup_linear':Warmup_Linear_with_Warning(),  # warmup_linear,
+    'warmup_linear':warmup_linear,
 }
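For context, the warmup_linear schedule that this warning targets is the triangular schedule defined above: a linear ramp up to the peak at x == warmup, a linear decay to zero at x == 1 (i.e. step == t_total), and zero afterwards. A minimal standalone sketch, not part of the commit, evaluated at a few progress values:

def warmup_linear(x, warmup=0.002):
    # linear ramp up to the peak, then linear decay clamped to 0 once x > 1
    if x < warmup:
        return x/warmup
    return max((x-1.)/(warmup-1.), 0)

for x in (0.001, 0.002, 0.5, 1.0, 1.2):
    print(x, warmup_linear(x))
# 0.001 -> 0.5, 0.002 -> 1.0 (peak), 0.5 -> ~0.501, 1.0 -> 0.0, 1.2 -> 0.0 (past t_total)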
@@ -93,6 +82,8 @@ class BertAdam(Optimizer):
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                         max_grad_norm=max_grad_norm)
         super(BertAdam, self).__init__(params, defaults)
+        # warning for t_total exceeded
+        self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf")
 
     def get_lr(self):
         lr = []
@@ -163,7 +154,15 @@ class BertAdam(Optimizer):
                 if group['t_total'] != -1:
                     schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    # warning for exceeding t_total (only active with warmup_linear
+                    progress = state['step']/group['t_total']
+                    if progress > 1. and progress > self._warned_for_t_total_at_progress:
+                        logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. "
+                                       "Please set 't_total' of {} correctly.".format(self.__class__.__name__))
+                        self._warned_for_t_total_at_progress = progress
+                    # end warning
+                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
                 else:
                     lr_scheduled = group['lr']
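As a rough usage sketch of when the new warning fires: the BertAdam call below follows the pytorch_pretrained_bert constructor of this era (lr, warmup, t_total, schedule); the model and the step counts are placeholders, not taken from the commit.

import torch
from pytorch_pretrained_bert.optimization import BertAdam

model = torch.nn.Linear(10, 2)                      # placeholder model
optimizer = BertAdam(model.parameters(), lr=5e-5,
                     warmup=0.1, t_total=100,       # schedule ends after 100 steps
                     schedule='warmup_linear')

for step in range(120):                             # 20 steps beyond t_total
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()                                # past step 100 the scheduled lr is 0
    optimizer.zero_grad()                           # and the warning above is logged

Note that the removed Warmup_Linear_with_Warning wrapper always named BertAdam in its message; the rewritten check uses self.__class__.__name__, so OpenAIAdam (below) reports its own name.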
pytorch_pretrained_bert/optimization_openai.py (view file @ 46ef6460)
@@ -40,8 +40,6 @@ def warmup_linear(x, warmup=0.002):
         After `t_total`-th training step, learning rate is zero. """
     if x < warmup:
         return x/warmup
-    if x > 1:
-        logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. Please set 't_total' of BertAdam correctly.")
     return max((x-1.)/(warmup-1.), 0)
 
 SCHEDULES = {
@@ -73,6 +71,8 @@ class OpenAIAdam(Optimizer):
                         b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
                         max_grad_norm=max_grad_norm)
         super(OpenAIAdam, self).__init__(params, defaults)
+        # warning for t_total exceeded
+        self._warned_for_t_total_at_progress = -1 if schedule == "warmup_linear" else float("inf")
 
     def get_lr(self):
         lr = []
@@ -137,7 +137,15 @@ class OpenAIAdam(Optimizer):
                 if group['t_total'] != -1:
                     schedule_fct = SCHEDULES[group['schedule']]
-                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                    # warning for exceeding t_total (only active with warmup_linear
+                    progress = state['step']/group['t_total']
+                    if progress > 1. and progress > self._warned_for_t_total_at_progress:
+                        logger.warning("Training beyond specified 't_total' steps. Learning rate set to zero. "
+                                       "Please set 't_total' of {} correctly.".format(self.__class__.__name__))
+                        self._warned_for_t_total_at_progress = progress
+                    # end warning
+                    lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup'])
                 else:
                     lr_scheduled = group['lr']
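The sentinel `-1 if schedule == "warmup_linear" else float("inf")` added to both optimizers is a small warn-once-per-progress guard: with any other schedule the threshold starts at +inf, so the warning can never fire, and with warmup_linear the threshold is bumped to the last progress value that was reported, so a given progress warns at most once even though the check runs once per parameter tensor. A standalone sketch of the pattern, not library code:

class WarnOnceAboveOne:
    # hypothetical helper illustrating the sentinel used in the commit
    def __init__(self, schedule):
        self._warned_at = -1 if schedule == "warmup_linear" else float("inf")

    def maybe_warn(self, progress):
        if progress > 1. and progress > self._warned_at:
            print("Training beyond specified 't_total' steps.")
            self._warned_at = progress

w = WarnOnceAboveOne("warmup_linear")
w.maybe_warn(1.01)   # warns
w.maybe_warn(1.01)   # silent: same progress already reported
w.maybe_warn(1.02)   # warns again at a new, larger progress
WarnOnceAboveOne("warmup_cosine").maybe_warn(2.0)   # never warns: threshold is +inf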