chenpangpang/transformers · Commits
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0d1dad6d5323cf627cb8d7ddd428856ab8475f6b"
Commit 7388c83b, authored Jun 18, 2019 by thomwolf

    update run_classifier for distributed eval

Parent: 97277232
Showing 1 changed file with 41 additions and 6 deletions.
examples/run_classifier.py (+41, -6)
```diff
@@ -50,6 +50,15 @@ else:
 logger = logging.getLogger(__name__)
 
 
+def average_distributed_scalar(scalar, args):
+    """ Average a scalar over the nodes if we are in distributed training. We use this for distributed evaluation. """
+    if args.local_rank == -1:
+        return scalar
+    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
+    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
+    return scalar_t.item()
+
+
 def main():
     parser = argparse.ArgumentParser()
```
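The helper relies on a small trick: each rank contributes its scalar pre-divided by the world size, so a SUM all_reduce yields the mean across ranks. A minimal runnable sketch of the same idea (not part of the commit; it assumes a single machine, the CPU-only `gloo` backend, and two processes):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def average_distributed_scalar(scalar, world_size):
    # Same idea as the helper above, minus the args plumbing:
    # SUM over ranks of (value / world_size) == mean over ranks.
    scalar_t = torch.tensor(scalar, dtype=torch.float) / world_size
    dist.all_reduce(scalar_t, op=dist.ReduceOp.SUM)
    return scalar_t.item()

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Pretend each rank measured a metric on its own shard of the dev set.
    local_accuracy = 0.80 if rank == 0 else 0.90
    print(f"rank {rank}: {average_distributed_scalar(local_accuracy, world_size)}")  # ~0.85 on both ranks
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```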
```diff
@@ -158,6 +167,7 @@ def main():
         n_gpu = 1
         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
         torch.distributed.init_process_group(backend='nccl')
+    args.device = device
 
     logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                         datefmt='%m/%d/%Y %H:%M:%S',
```
```diff
@@ -337,6 +347,8 @@ def main():
                 tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
                 tb_writer.add_scalar('loss', loss.item(), global_step)
 
+    ### Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()
+    ### Example:
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Save a trained model, configuration and tokenizer
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
```
```diff
@@ -352,11 +364,21 @@ def main():
         # Load a trained model and vocabulary that you have fine-tuned
         model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
         tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
     else:
         model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
-    model.to(device)
 
-    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+    # Distributed/fp16/parallel settings (optional)
+    model.to(device)
+    if args.fp16:
+        model.half()
+    if args.local_rank != -1:
+        model = torch.nn.parallel.DistributedDataParallel(model,
+                                                          device_ids=[args.local_rank],
+                                                          output_device=args.local_rank,
+                                                          find_unused_parameters=True)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
+    ### Evaluation
+    if args.do_eval:
         eval_examples = processor.get_dev_examples(args.data_dir)
         eval_features = convert_examples_to_features(
             eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
```
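With the rank-0-only guard removed, every process now wraps the model and evaluates its own shard. A stripped-down sketch of that wrap-and-evaluate pattern (again an illustration, not the commit's code: CPU `gloo` backend, two processes, a toy linear layer standing in for BERT):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = torch.nn.Linear(4, 2)
    # DDP broadcasts rank 0's parameters at construction, so every rank
    # evaluates identical weights. On GPU you would pass device_ids as in
    # the hunk above; on CPU, DistributedDataParallel takes no device ids.
    model = torch.nn.parallel.DistributedDataParallel(model)
    model.eval()
    with torch.no_grad():
        logits = model(torch.ones(1, 4))
    print(f"rank {rank} evaluated a batch, logits shape {tuple(logits.shape)}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```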
```diff
@@ -374,7 +396,10 @@ def main():
         eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
         # Run prediction for full data
-        eval_sampler = SequentialSampler(eval_data)
+        if args.local_rank == -1:
+            eval_sampler = SequentialSampler(eval_data)
+        else:
+            eval_sampler = DistributedSampler(eval_data)  # Note that this sampler samples randomly
         eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
 
         model.eval()
```
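The inline comment is the important caveat: unlike SequentialSampler, DistributedSampler shuffles and splits the dataset so each rank sees a disjoint, equal-sized shard, padding with repeated examples when the dataset does not divide evenly. A quick single-process illustration (a toy 10-example dataset; num_replicas and rank are passed explicitly so no process group is needed):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

data = TensorDataset(torch.arange(10))
for rank in range(2):
    sampler = DistributedSampler(data, num_replicas=2, rank=rank)
    # Disjoint halves of 0..9; the iteration order itself is shuffled.
    print(rank, sorted(sampler))
```

Equal shard sizes are what make the per-rank averaging above well-defined, though the padding can skew metrics slightly on small dev sets.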
```diff
@@ -398,7 +423,7 @@ def main():
             elif output_mode == "regression":
                 loss_fct = MSELoss()
                 tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
 
             eval_loss += tmp_eval_loss.mean().item()
             nb_eval_steps += 1
             if len(preds) == 0:
```
```diff
@@ -414,6 +439,11 @@ def main():
             elif output_mode == "regression":
                 preds = np.squeeze(preds)
             result = compute_metrics(task_name, preds, all_label_ids.numpy())
+
+            if args.local_rank != -1:
+                # Average over distributed nodes if needed
+                result = {key: average_distributed_scalar(value, args) for key, value in result.items()}
+
             loss = tr_loss / global_step if args.do_train else None
             result['eval_loss'] = eval_loss
```
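Locally, the added dict comprehension is equivalent to a plain mean over the per-rank results (toy numbers below, simulated without a process group). Note that result['eval_loss'] is assigned after the averaging, so the reported eval loss stays rank-local:

```python
# Each rank ends up with its own metrics dict after evaluating its shard;
# the all_reduce in average_distributed_scalar computes exactly this mean.
per_rank_results = [{"acc": 0.80, "f1": 0.70}, {"acc": 0.90, "f1": 0.74}]
world_size = len(per_rank_results)
averaged = {key: sum(r[key] for r in per_rank_results) / world_size
            for key in per_rank_results[0]}
print(averaged)  # {'acc': 0.85..., 'f1': 0.72}
```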
```diff
@@ -482,6 +512,11 @@ def main():
             preds = preds[0]
             preds = np.argmax(preds, axis=1)
             result = compute_metrics(task_name, preds, all_label_ids.numpy())
+
+            if args.local_rank != -1:
+                # Average over distributed nodes if needed
+                result = {key: average_distributed_scalar(value, args) for key, value in result.items()}
+
             loss = tr_loss / global_step if args.do_train else None
             result['eval_loss'] = eval_loss
```