Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0cf88ff0
Commit
0cf88ff0
authored
Dec 13, 2018
by
thomwolf
Browse files
make examples work without apex
parent
52c53f39
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
14 deletions
+22
-14
examples/run_classifier.py
examples/run_classifier.py
+11
-7
examples/run_squad.py
examples/run_squad.py
+11
-7
No files found.
examples/run_classifier.py
View file @
0cf88ff0
...
...
@@ -36,13 +36,6 @@ from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from
pytorch_pretrained_bert.optimization
import
BertAdam
from
pytorch_pretrained_bert.file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
try
:
from
apex.optimizers
import
FP16_Optimizer
from
apex.optimizers
import
FusedAdam
from
apex.parallel
import
DistributedDataParallel
as
DDP
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to run this."
)
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
level
=
logging
.
INFO
)
...
...
@@ -467,6 +460,11 @@ def main():
model
.
half
()
model
.
to
(
device
)
if
args
.
local_rank
!=
-
1
:
try
:
from
apex.parallel
import
DistributedDataParallel
as
DDP
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
)
model
=
DDP
(
model
)
elif
n_gpu
>
1
:
model
=
torch
.
nn
.
DataParallel
(
model
)
...
...
@@ -482,6 +480,12 @@ def main():
if
args
.
local_rank
!=
-
1
:
t_total
=
t_total
//
torch
.
distributed
.
get_world_size
()
if
args
.
fp16
:
try
:
from
apex.optimizers
import
FP16_Optimizer
from
apex.optimizers
import
FusedAdam
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
)
optimizer
=
FusedAdam
(
optimizer_grouped_parameters
,
lr
=
args
.
learning_rate
,
bias_correction
=
False
,
...
...
examples/run_squad.py
View file @
0cf88ff0
...
...
@@ -39,13 +39,6 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from
pytorch_pretrained_bert.optimization
import
BertAdam
from
pytorch_pretrained_bert.file_utils
import
PYTORCH_PRETRAINED_BERT_CACHE
try
:
from
apex.optimizers
import
FP16_Optimizer
from
apex.optimizers
import
FusedAdam
from
apex.parallel
import
DistributedDataParallel
as
DDP
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to run this."
)
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(name)s - %(message)s'
,
datefmt
=
'%m/%d/%Y %H:%M:%S'
,
level
=
logging
.
INFO
)
...
...
@@ -813,6 +806,11 @@ def main():
model
.
half
()
model
.
to
(
device
)
if
args
.
local_rank
!=
-
1
:
try
:
from
apex.parallel
import
DistributedDataParallel
as
DDP
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
)
model
=
DDP
(
model
)
elif
n_gpu
>
1
:
model
=
torch
.
nn
.
DataParallel
(
model
)
...
...
@@ -834,6 +832,12 @@ def main():
if
args
.
local_rank
!=
-
1
:
t_total
=
t_total
//
torch
.
distributed
.
get_world_size
()
if
args
.
fp16
:
try
:
from
apex.optimizers
import
FP16_Optimizer
from
apex.optimizers
import
FusedAdam
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
)
optimizer
=
FusedAdam
(
optimizer_grouped_parameters
,
lr
=
args
.
learning_rate
,
bias_correction
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment