Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
33599556
Commit
33599556
authored
Jun 18, 2019
by
thomwolf
Browse files
updating run_classif
parent
29b7b30e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
20 deletions
+3
-20
examples/run_classifier.py
examples/run_classifier.py
+3
-20
No files found.
examples/run_classifier.py
View file @
33599556
...
...
@@ -50,15 +50,6 @@ else:
# Module-level logger keyed to this module's __name__, per the standard
# logging convention, so records are attributable to this file.
logger = logging.getLogger(__name__)
def average_distributed_scalar(scalar, args):
    """Average a scalar over the nodes if we are in distributed training.

    Used during distributed evaluation so every process reports the same
    aggregated metric value.

    Args:
        scalar: the local (per-process) scalar value to be averaged.
        args: parsed arguments; ``args.local_rank`` (-1 means no distributed
            training) and ``args.device`` are read.

    Returns:
        The input unchanged when not distributed, otherwise the mean of the
        scalar across all processes as a Python float.
    """
    # Single-process run: nothing to average, hand the value straight back.
    if args.local_rank == -1:
        return scalar
    world_size = torch.distributed.get_world_size()
    # Pre-divide each process's share so that the SUM reduction below
    # produces the mean across all nodes.
    local_share = torch.tensor(scalar, dtype=torch.float, device=args.device) / world_size
    torch.distributed.all_reduce(local_share, op=torch.distributed.ReduceOp.SUM)
    return local_share.item()
def
main
():
parser
=
argparse
.
ArgumentParser
()
...
...
@@ -368,7 +359,7 @@ def main():
model
=
BertForSequenceClassification
.
from_pretrained
(
args
.
output_dir
,
num_labels
=
num_labels
)
tokenizer
=
BertTokenizer
.
from_pretrained
(
args
.
output_dir
,
do_lower_case
=
args
.
do_lower_case
)
else
:
model
=
BertFor
QuestionAnswering
.
from_pretrained
(
args
.
bert_model
)
model
=
BertFor
SequenceClassification
.
from_pretrained
(
args
.
bert_model
)
model
.
to
(
device
)
...
...
@@ -453,10 +444,6 @@ def main():
preds
=
np
.
squeeze
(
preds
)
result
=
compute_metrics
(
task_name
,
preds
,
out_label_ids
)
if
args
.
local_rank
!=
-
1
:
# Average over distributed nodes if needed
result
=
{
key
:
average_distributed_scalar
(
value
,
args
)
for
key
,
value
in
result
.
items
()}
loss
=
tr_loss
/
global_step
if
args
.
do_train
else
None
result
[
'eval_loss'
]
=
eval_loss
...
...
@@ -530,10 +517,6 @@ def main():
preds
=
np
.
argmax
(
preds
,
axis
=
1
)
result
=
compute_metrics
(
task_name
,
preds
,
out_label_ids
)
if
args
.
local_rank
!=
-
1
:
# Average over distributed nodes if needed
result
=
{
key
:
average_distributed_scalar
(
value
,
args
)
for
key
,
value
in
result
.
items
()}
loss
=
tr_loss
/
global_step
if
args
.
do_train
else
None
result
[
'eval_loss'
]
=
eval_loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment