Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
29b7b30e
Commit
29b7b30e
authored
Jun 18, 2019
by
thomwolf
Browse files
updating evaluation on a single gpu
parent
7d2001aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
14 deletions
+6
-14
examples/run_classifier.py
examples/run_classifier.py
+6
-14
No files found.
examples/run_classifier.py
View file @
29b7b30e
...
@@ -306,10 +306,10 @@ def main():
...
@@ -306,10 +306,10 @@ def main():
logger
.
info
(
" Num steps = %d"
,
num_train_optimization_steps
)
logger
.
info
(
" Num steps = %d"
,
num_train_optimization_steps
)
model
.
train
()
model
.
train
()
for
_
in
trange
(
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
):
for
_
in
trange
(
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
]
):
tr_loss
=
0
tr_loss
=
0
nb_tr_examples
,
nb_tr_steps
=
0
,
0
nb_tr_examples
,
nb_tr_steps
=
0
,
0
for
step
,
batch
in
enumerate
(
tqdm
(
train_dataloader
,
desc
=
"Iteration"
)):
for
step
,
batch
in
enumerate
(
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
]
)):
batch
=
tuple
(
t
.
to
(
device
)
for
t
in
batch
)
batch
=
tuple
(
t
.
to
(
device
)
for
t
in
batch
)
input_ids
,
input_mask
,
segment_ids
,
label_ids
=
batch
input_ids
,
input_mask
,
segment_ids
,
label_ids
=
batch
...
@@ -367,21 +367,13 @@ def main():
...
@@ -367,21 +367,13 @@ def main():
# Load a trained model and vocabulary that you have fine-tuned
# Load a trained model and vocabulary that you have fine-tuned
model
=
BertForSequenceClassification
.
from_pretrained
(
args
.
output_dir
,
num_labels
=
num_labels
)
model
=
BertForSequenceClassification
.
from_pretrained
(
args
.
output_dir
,
num_labels
=
num_labels
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
output_dir
,
do_lower_case
=
args
.
do_lower_case
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
output_dir
,
do_lower_case
=
args
.
do_lower_case
)
else
:
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
)
# Distributed/fp16/parallel settings (optional)
model
.
to
(
device
)
model
.
to
(
device
)
if
args
.
fp16
:
model
.
half
()
if
args
.
local_rank
!=
-
1
:
model
=
torch
.
nn
.
parallel
.
DistributedDataParallel
(
model
,
device_ids
=
[
args
.
local_rank
],
output_device
=
args
.
local_rank
,
find_unused_parameters
=
True
)
elif
n_gpu
>
1
:
model
=
torch
.
nn
.
DataParallel
(
model
)
### Evaluation
### Evaluation
if
args
.
do_eval
:
if
args
.
do_eval
and
(
args
.
local_rank
==
-
1
or
torch
.
distributed
.
get_rank
()
==
0
)
:
eval_examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
eval_examples
=
processor
.
get_dev_examples
(
args
.
data_dir
)
cached_eval_features_file
=
os
.
path
.
join
(
args
.
data_dir
,
'dev_{0}_{1}_{2}'
.
format
(
cached_eval_features_file
=
os
.
path
.
join
(
args
.
data_dir
,
'dev_{0}_{1}_{2}'
.
format
(
list
(
filter
(
None
,
args
.
bert_model
.
split
(
'/'
))).
pop
(),
list
(
filter
(
None
,
args
.
bert_model
.
split
(
'/'
))).
pop
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment