Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
164c794e
Commit
164c794e
authored
Jan 08, 2020
by
Lysandre
Committed by
Lysandre Debut
Jan 10, 2020
Browse files
New SQuAD API for distillation script
parent
801f2ac8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
59 additions
and
68 deletions
+59
-68
examples/distillation/run_squad_w_distillation.py
examples/distillation/run_squad_w_distillation.py
+59
-68
No files found.
examples/distillation/run_squad_w_distillation.py
View file @
164c794e
...
...
@@ -15,7 +15,6 @@
# limitations under the License.
""" This is the exact same script as `examples/run_squad.py` (as of 2019, October 4th) with an additional and optional step of distillation."""
import
argparse
import
glob
import
logging
...
...
@@ -26,7 +25,7 @@ import numpy as np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.utils.data
import
DataLoader
,
RandomSampler
,
SequentialSampler
,
TensorDataset
from
torch.utils.data
import
DataLoader
,
RandomSampler
,
SequentialSampler
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm
import
tqdm
,
trange
...
...
@@ -46,22 +45,14 @@ from transformers import (
XLNetForQuestionAnswering
,
XLNetTokenizer
,
get_linear_schedule_with_warmup
,
squad_convert_examples_to_features
,
)
from
..utils_squad
import
(
RawResult
,
RawResultExtended
,
convert_examples_to_features
,
read_squad_examples
,
write_predictions
,
write_predictions_extended
,
from
transformers.data.metrics.squad_metrics
import
(
compute_predictions_log_probs
,
compute_predictions_logits
,
squad_evaluate
,
)
# The following import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library
# We've added it here for automated tests (see examples/test_examples.py file)
from
..utils_squad_evaluate
import
EVAL_OPTS
from
..utils_squad_evaluate
import
main
as
evaluate_on_squad
from
transformers.data.processors.squad
import
SquadResult
,
SquadV1Processor
,
SquadV2Processor
try
:
...
...
@@ -69,7 +60,6 @@ try:
except
ImportError
:
from
tensorboardX
import
SummaryWriter
logger
=
logging
.
getLogger
(
__name__
)
ALL_MODELS
=
sum
(
...
...
@@ -294,20 +284,31 @@ def evaluate(args, model, tokenizer, prefix=""):
for
i
,
example_index
in
enumerate
(
example_indices
):
eval_feature
=
features
[
example_index
.
item
()]
unique_id
=
int
(
eval_feature
.
unique_id
)
if
args
.
model_type
in
[
"xlnet"
,
"xlm"
]:
# XLNet uses a more complex post-processing procedure
result
=
RawResultExtended
(
unique_id
=
unique_id
,
start_top_log_probs
=
to_list
(
outputs
[
0
][
i
]),
start_top_index
=
to_list
(
outputs
[
1
][
i
]),
end_top_log_probs
=
to_list
(
outputs
[
2
][
i
]),
end_top_index
=
to_list
(
outputs
[
3
][
i
]),
cls_logits
=
to_list
(
outputs
[
4
][
i
]),
output
=
[
to_list
(
output
[
i
])
for
output
in
outputs
]
# Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
# models only use two.
if
len
(
output
)
>=
5
:
start_logits
=
output
[
0
]
start_top_index
=
output
[
1
]
end_logits
=
output
[
2
]
end_top_index
=
output
[
3
]
cls_logits
=
output
[
4
]
result
=
SquadResult
(
unique_id
,
start_logits
,
end_logits
,
start_top_index
=
start_top_index
,
end_top_index
=
end_top_index
,
cls_logits
=
cls_logits
,
)
else
:
result
=
RawResult
(
unique_id
=
unique_id
,
start_logits
=
to_list
(
outputs
[
0
][
i
]),
end_logits
=
to_list
(
outputs
[
1
][
i
]
)
)
start_logits
,
end_logits
=
output
result
=
SquadResult
(
unique_id
,
start_logits
,
end_logits
)
all_results
.
append
(
result
)
# Compute predictions
...
...
@@ -320,7 +321,7 @@ def evaluate(args, model, tokenizer, prefix=""):
if
args
.
model_type
in
[
"xlnet"
,
"xlm"
]:
# XLNet uses a more complex post-processing procedure
wri
te_predictions_
extended
(
predictions
=
compu
te_predictions_
log_probs
(
examples
,
features
,
all_results
,
...
...
@@ -337,7 +338,7 @@ def evaluate(args, model, tokenizer, prefix=""):
args
.
verbose_logging
,
)
else
:
wri
te_predictions
(
predictions
=
compu
te_predictions
_logits
(
examples
,
features
,
all_results
,
...
...
@@ -350,13 +351,11 @@ def evaluate(args, model, tokenizer, prefix=""):
args
.
verbose_logging
,
args
.
version_2_with_negative
,
args
.
null_score_diff_threshold
,
tokenizer
,
)
# Evaluate with the official SQuAD script
evaluate_options
=
EVAL_OPTS
(
data_file
=
args
.
predict_file
,
pred_file
=
output_prediction_file
,
na_prob_file
=
output_null_log_odds_file
)
results
=
evaluate_on_squad
(
evaluate_options
)
# Compute the F1 and exact scores.
results
=
squad_evaluate
(
examples
,
predictions
)
return
results
...
...
@@ -368,59 +367,51 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
input_file
=
args
.
predict_file
if
evaluate
else
args
.
train_file
cached_features_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
input_file
),
"cached_{}_{}_{}"
.
format
(
"cached_
distillation_
{}_{}_{}"
.
format
(
"dev"
if
evaluate
else
"train"
,
list
(
filter
(
None
,
args
.
model_name_or_path
.
split
(
"/"
))).
pop
(),
str
(
args
.
max_seq_length
),
),
)
if
os
.
path
.
exists
(
cached_features_file
)
and
not
args
.
overwrite_cache
and
not
output_examples
:
if
os
.
path
.
exists
(
cached_features_file
)
and
not
args
.
overwrite_cache
:
logger
.
info
(
"Loading features from cached file %s"
,
cached_features_file
)
features
=
torch
.
load
(
cached_features_file
)
features_and_dataset
=
torch
.
load
(
cached_features_file
)
try
:
features
,
dataset
,
examples
=
(
features_and_dataset
[
"features"
],
features_and_dataset
[
"dataset"
],
features_and_dataset
[
"examples"
],
)
except
KeyError
:
raise
DeprecationWarning
(
"You seem to be loading features from an older version of this script please delete the "
"file %s in order for it to be created again"
%
cached_features_file
)
else
:
logger
.
info
(
"Creating features from dataset file at %s"
,
input_file
)
examples
=
read_squad_examples
(
input_file
=
input_file
,
is_training
=
not
evaluate
,
version_2_with_negative
=
args
.
version_2_with_negative
)
features
=
convert_examples_to_features
(
processor
=
SquadV2Processor
()
if
args
.
version_2_with_negative
else
SquadV1Processor
()
if
evaluate
:
examples
=
processor
.
get_dev_examples
(
None
,
filename
=
args
.
predict_file
)
else
:
examples
=
processor
.
get_train_examples
(
None
,
filename
=
args
.
train_file
)
features
,
dataset
=
squad_convert_examples_to_features
(
examples
=
examples
,
tokenizer
=
tokenizer
,
max_seq_length
=
args
.
max_seq_length
,
doc_stride
=
args
.
doc_stride
,
max_query_length
=
args
.
max_query_length
,
is_training
=
not
evaluate
,
return_dataset
=
"pt"
,
)
if
args
.
local_rank
in
[
-
1
,
0
]:
logger
.
info
(
"Saving features into cached file %s"
,
cached_features_file
)
torch
.
save
(
features
,
cached_features_file
)
torch
.
save
(
{
"
features
"
:
features
,
"dataset"
:
dataset
,
"examples"
:
examples
}
,
cached_features_file
)
if
args
.
local_rank
==
0
and
not
evaluate
:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training processes the dataset, and the others will use the cache
# Convert to Tensors and build dataset
all_input_ids
=
torch
.
tensor
([
f
.
input_ids
for
f
in
features
],
dtype
=
torch
.
long
)
all_input_mask
=
torch
.
tensor
([
f
.
input_mask
for
f
in
features
],
dtype
=
torch
.
long
)
all_segment_ids
=
torch
.
tensor
([
f
.
segment_ids
for
f
in
features
],
dtype
=
torch
.
long
)
all_cls_index
=
torch
.
tensor
([
f
.
cls_index
for
f
in
features
],
dtype
=
torch
.
long
)
all_p_mask
=
torch
.
tensor
([
f
.
p_mask
for
f
in
features
],
dtype
=
torch
.
float
)
if
evaluate
:
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
dataset
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_example_index
,
all_cls_index
,
all_p_mask
)
else
:
all_start_positions
=
torch
.
tensor
([
f
.
start_position
for
f
in
features
],
dtype
=
torch
.
long
)
all_end_positions
=
torch
.
tensor
([
f
.
end_position
for
f
in
features
],
dtype
=
torch
.
long
)
dataset
=
TensorDataset
(
all_input_ids
,
all_input_mask
,
all_segment_ids
,
all_start_positions
,
all_end_positions
,
all_cls_index
,
all_p_mask
,
)
if
output_examples
:
return
dataset
,
examples
,
features
return
dataset
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment