Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
81e1e248
".jenkins/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "c3d4bfe8224967baef120be1ce6a8a0cc82c12ea"
Commit
81e1e248
authored
Dec 10, 2018
by
Li Li
Browse files
Fix optimizer to work with horovod
parent
a2b6918a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
2 deletions
+3
-2
pytorch_pretrained_bert/optimization.py
pytorch_pretrained_bert/optimization.py
+3
-2
No files found.
pytorch_pretrained_bert/optimization.py
View file @
81e1e248
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
math
import
math
import
torch
import
torch
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
torch.optim.optimizer
import
required
from
torch.nn.utils
import
clip_grad_norm_
from
torch.nn.utils
import
clip_grad_norm_
def
warmup_cosine
(
x
,
warmup
=
0.002
):
def
warmup_cosine
(
x
,
warmup
=
0.002
):
...
@@ -55,10 +56,10 @@ class BertAdam(Optimizer):
...
@@ -55,10 +56,10 @@ class BertAdam(Optimizer):
weight_decay_rate: Weight decay. Default: 0.01
weight_decay_rate: Weight decay. Default: 0.01
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
"""
"""
def
__init__
(
self
,
params
,
lr
,
warmup
=-
1
,
t_total
=-
1
,
schedule
=
'warmup_linear'
,
def
__init__
(
self
,
params
,
lr
=
required
,
warmup
=-
1
,
t_total
=-
1
,
schedule
=
'warmup_linear'
,
b1
=
0.9
,
b2
=
0.999
,
e
=
1e-6
,
weight_decay_rate
=
0.01
,
b1
=
0.9
,
b2
=
0.999
,
e
=
1e-6
,
weight_decay_rate
=
0.01
,
max_grad_norm
=
1.0
):
max_grad_norm
=
1.0
):
if
not
lr
>=
0.0
:
if
lr
is
not
required
and
lr
<
0.0
:
raise
ValueError
(
"Invalid learning rate: {} - should be >= 0.0"
.
format
(
lr
))
raise
ValueError
(
"Invalid learning rate: {} - should be >= 0.0"
.
format
(
lr
))
if
schedule
not
in
SCHEDULES
:
if
schedule
not
in
SCHEDULES
:
raise
ValueError
(
"Invalid schedule parameter: {}"
.
format
(
schedule
))
raise
ValueError
(
"Invalid schedule parameter: {}"
.
format
(
schedule
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment