Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c8f9cf19
Commit
c8f9cf19
authored
Jun 18, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
Jun 18, 2020
Browse files
Support multiple prediction files for SQuAD task.
PiperOrigin-RevId: 317253522
parent
8284ea20
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
80 additions
and
33 deletions
+80
-33
official/nlp/bert/run_squad_helper.py
official/nlp/bert/run_squad_helper.py
+80
-33
No files found.
official/nlp/bert/run_squad_helper.py
View file @
c8f9cf19
...
...
@@ -61,7 +61,11 @@ def define_common_squad_flags():
flags
.
DEFINE_integer
(
'train_batch_size'
,
32
,
'Total batch size for training.'
)
# Predict processing related.
flags
.
DEFINE_string
(
'predict_file'
,
None
,
'Prediction data path with train tfrecords.'
)
'SQuAD prediction json file path. '
'`predict` mode supports multiple files: one can use '
'wildcard to specify multiple files and it can also be '
'multiple file patterns separated by comma. Note that '
'`eval` mode only supports a single predict file.'
)
flags
.
DEFINE_bool
(
'do_lower_case'
,
True
,
'Whether to lower case the input text. Should be True for uncased '
...
...
@@ -159,22 +163,9 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
return
_dataset_fn
def
predict_squad_customized
(
strategy
,
input_meta_data
,
bert_config
,
checkpoint_path
,
predict_tfrecord_path
,
num_steps
):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn
=
get_dataset_fn
(
predict_tfrecord_path
,
input_meta_data
[
'max_seq_length'
],
FLAGS
.
predict_batch_size
,
is_training
=
False
)
predict_iterator
=
iter
(
strategy
.
experimental_distribute_datasets_from_function
(
predict_dataset_fn
))
def
get_squad_model_to_predict
(
strategy
,
bert_config
,
checkpoint_path
,
input_meta_data
):
"""Gets a squad model to make predictions."""
with
strategy
.
scope
():
# Prediction always uses float32, even if training uses mixed precision.
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
'float32'
)
...
...
@@ -188,6 +179,23 @@ def predict_squad_customized(strategy,
logging
.
info
(
'Restoring checkpoints from %s'
,
checkpoint_path
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
squad_model
)
checkpoint
.
restore
(
checkpoint_path
).
expect_partial
()
return
squad_model
def
predict_squad_customized
(
strategy
,
input_meta_data
,
predict_tfrecord_path
,
num_steps
,
squad_model
):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn
=
get_dataset_fn
(
predict_tfrecord_path
,
input_meta_data
[
'max_seq_length'
],
FLAGS
.
predict_batch_size
,
is_training
=
False
)
predict_iterator
=
iter
(
strategy
.
experimental_distribute_datasets_from_function
(
predict_dataset_fn
))
@
tf
.
function
def
predict_step
(
iterator
):
...
...
@@ -287,8 +295,8 @@ def train_squad(strategy,
post_allreduce_callbacks
=
[
clip_by_global_norm_callback
])
def
prediction_output_squad
(
strategy
,
input_meta_data
,
tokenizer
,
bert_config
,
squad_lib
,
checkpoint
):
def
prediction_output_squad
(
strategy
,
input_meta_data
,
tokenizer
,
squad_lib
,
predict_file
,
squad_model
):
"""Makes predictions for a squad dataset."""
doc_stride
=
input_meta_data
[
'doc_stride'
]
max_query_length
=
input_meta_data
[
'max_query_length'
]
...
...
@@ -296,7 +304,7 @@ def prediction_output_squad(
version_2_with_negative
=
input_meta_data
.
get
(
'version_2_with_negative'
,
False
)
eval_examples
=
squad_lib
.
read_squad_examples
(
input_file
=
FLAGS
.
predict_file
,
input_file
=
predict_file
,
is_training
=
False
,
version_2_with_negative
=
version_2_with_negative
)
...
...
@@ -337,8 +345,7 @@ def prediction_output_squad(
num_steps
=
int
(
dataset_size
/
FLAGS
.
predict_batch_size
)
all_results
=
predict_squad_customized
(
strategy
,
input_meta_data
,
bert_config
,
checkpoint
,
eval_writer
.
filename
,
num_steps
)
strategy
,
input_meta_data
,
eval_writer
.
filename
,
num_steps
,
squad_model
)
all_predictions
,
all_nbest_json
,
scores_diff_json
=
(
squad_lib
.
postprocess_output
(
...
...
@@ -356,11 +363,14 @@ def prediction_output_squad(
def
dump_to_files
(
all_predictions
,
all_nbest_json
,
scores_diff_json
,
squad_lib
,
version_2_with_negative
):
squad_lib
,
version_2_with_negative
,
file_prefix
=
''
):
"""Save output to json files."""
output_prediction_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'predictions.json'
)
output_nbest_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'nbest_predictions.json'
)
output_null_log_odds_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'null_odds.json'
)
output_prediction_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'%spredictions.json'
%
file_prefix
)
output_nbest_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'%snbest_predictions.json'
%
file_prefix
)
output_null_log_odds_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
file_prefix
,
'%snull_odds.json'
%
file_prefix
)
logging
.
info
(
'Writing predictions to: %s'
,
(
output_prediction_file
))
logging
.
info
(
'Writing nbest to: %s'
,
(
output_nbest_file
))
...
...
@@ -370,6 +380,22 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib
.
write_to_json_files
(
scores_diff_json
,
output_null_log_odds_file
)
def
_get_matched_files
(
input_path
):
"""Returns all files that matches the input_path."""
input_patterns
=
input_path
.
strip
().
split
(
','
)
all_matched_files
=
[]
for
input_pattern
in
input_patterns
:
input_pattern
=
input_pattern
.
strip
()
if
not
input_pattern
:
continue
matched_files
=
tf
.
io
.
gfile
.
glob
(
input_pattern
)
if
not
matched_files
:
raise
ValueError
(
'%s does not match any files.'
%
input_pattern
)
else
:
all_matched_files
.
extend
(
matched_files
)
return
sorted
(
all_matched_files
)
def
predict_squad
(
strategy
,
input_meta_data
,
tokenizer
,
...
...
@@ -379,11 +405,24 @@ def predict_squad(strategy,
"""Get prediction results and evaluate them to hard drive."""
if
init_checkpoint
is
None
:
init_checkpoint
=
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
)
all_predict_files
=
_get_matched_files
(
FLAGS
.
predict_file
)
squad_model
=
get_squad_model_to_predict
(
strategy
,
bert_config
,
init_checkpoint
,
input_meta_data
)
for
idx
,
predict_file
in
enumerate
(
all_predict_files
):
all_predictions
,
all_nbest_json
,
scores_diff_json
=
prediction_output_squad
(
strategy
,
input_meta_data
,
tokenizer
,
bert_config
,
squad_lib
,
init_checkpoint
)
strategy
,
input_meta_data
,
tokenizer
,
squad_lib
,
predict_file
,
squad_model
)
if
len
(
all_predict_files
)
==
1
:
file_prefix
=
''
else
:
# if predict_file is /path/xquad.ar.json, the `file_prefix` may be
# "xquad.ar-0-"
file_prefix
=
'%s-'
%
os
.
path
.
splitext
(
os
.
path
.
basename
(
all_predict_files
[
idx
]))[
0
]
dump_to_files
(
all_predictions
,
all_nbest_json
,
scores_diff_json
,
squad_lib
,
input_meta_data
.
get
(
'version_2_with_negative'
,
False
))
input_meta_data
.
get
(
'version_2_with_negative'
,
False
),
file_prefix
)
def
eval_squad
(
strategy
,
...
...
@@ -395,9 +434,17 @@ def eval_squad(strategy,
"""Get prediction results and evaluate them against ground truth."""
if
init_checkpoint
is
None
:
init_checkpoint
=
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
)
all_predict_files
=
_get_matched_files
(
FLAGS
.
predict_file
)
if
len
(
all_predict_files
)
!=
1
:
raise
ValueError
(
'`eval_squad` only supports one predict file, '
'but got %s'
%
all_predict_files
)
squad_model
=
get_squad_model_to_predict
(
strategy
,
bert_config
,
init_checkpoint
,
input_meta_data
)
all_predictions
,
all_nbest_json
,
scores_diff_json
=
prediction_output_squad
(
strategy
,
input_meta_data
,
tokenizer
,
bert_config
,
squad_lib
,
init_checkpoint
)
strategy
,
input_meta_data
,
tokenizer
,
squad_lib
,
all_predict_files
[
0
],
squad_model
)
dump_to_files
(
all_predictions
,
all_nbest_json
,
scores_diff_json
,
squad_lib
,
input_meta_data
.
get
(
'version_2_with_negative'
,
False
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment