Commit c8f9cf19 authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Support multiple prediction files for SQuAD task.

PiperOrigin-RevId: 317253522
parent 8284ea20
...@@ -61,7 +61,11 @@ def define_common_squad_flags(): ...@@ -61,7 +61,11 @@ def define_common_squad_flags():
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.') flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related. # Predict processing related.
flags.DEFINE_string('predict_file', None, flags.DEFINE_string('predict_file', None,
'Prediction data path with train tfrecords.') 'SQuAD prediction json file path. '
'`predict` mode supports multiple files: one can use '
'wildcard to specify multiple files and it can also be '
'multiple file patterns separated by comma. Note that '
'`eval` mode only supports a single predict file.')
flags.DEFINE_bool( flags.DEFINE_bool(
'do_lower_case', True, 'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased ' 'Whether to lower case the input text. Should be True for uncased '
...@@ -159,22 +163,9 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size, ...@@ -159,22 +163,9 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
return _dataset_fn return _dataset_fn
def predict_squad_customized(strategy, def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
input_meta_data, input_meta_data):
bert_config, """Gets a squad model to make predictions."""
checkpoint_path,
predict_tfrecord_path,
num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
with strategy.scope(): with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision. # Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32') tf.keras.mixed_precision.experimental.set_policy('float32')
...@@ -188,6 +179,23 @@ def predict_squad_customized(strategy, ...@@ -188,6 +179,23 @@ def predict_squad_customized(strategy,
logging.info('Restoring checkpoints from %s', checkpoint_path) logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model) checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial() checkpoint.restore(checkpoint_path).expect_partial()
return squad_model
def predict_squad_customized(strategy,
input_meta_data,
predict_tfrecord_path,
num_steps,
squad_model):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
@tf.function @tf.function
def predict_step(iterator): def predict_step(iterator):
...@@ -287,8 +295,8 @@ def train_squad(strategy, ...@@ -287,8 +295,8 @@ def train_squad(strategy,
post_allreduce_callbacks=[clip_by_global_norm_callback]) post_allreduce_callbacks=[clip_by_global_norm_callback])
def prediction_output_squad( def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
strategy, input_meta_data, tokenizer, bert_config, squad_lib, checkpoint): predict_file, squad_model):
"""Makes predictions for a squad dataset.""" """Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride'] doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length'] max_query_length = input_meta_data['max_query_length']
...@@ -296,7 +304,7 @@ def prediction_output_squad( ...@@ -296,7 +304,7 @@ def prediction_output_squad(
version_2_with_negative = input_meta_data.get('version_2_with_negative', version_2_with_negative = input_meta_data.get('version_2_with_negative',
False) False)
eval_examples = squad_lib.read_squad_examples( eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file, input_file=predict_file,
is_training=False, is_training=False,
version_2_with_negative=version_2_with_negative) version_2_with_negative=version_2_with_negative)
...@@ -337,8 +345,7 @@ def prediction_output_squad( ...@@ -337,8 +345,7 @@ def prediction_output_squad(
num_steps = int(dataset_size / FLAGS.predict_batch_size) num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized( all_results = predict_squad_customized(
strategy, input_meta_data, bert_config, strategy, input_meta_data, eval_writer.filename, num_steps, squad_model)
checkpoint, eval_writer.filename, num_steps)
all_predictions, all_nbest_json, scores_diff_json = ( all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output( squad_lib.postprocess_output(
...@@ -356,11 +363,14 @@ def prediction_output_squad( ...@@ -356,11 +363,14 @@ def prediction_output_squad(
def dump_to_files(all_predictions, all_nbest_json, scores_diff_json, def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib, version_2_with_negative): squad_lib, version_2_with_negative, file_prefix=''):
"""Save output to json files.""" """Save output to json files."""
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json') output_prediction_file = os.path.join(FLAGS.model_dir,
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json') '%spredictions.json' % file_prefix)
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json') output_nbest_file = os.path.join(FLAGS.model_dir,
'%snbest_predictions.json' % file_prefix)
output_null_log_odds_file = os.path.join(FLAGS.model_dir, file_prefix,
'%snull_odds.json' % file_prefix)
logging.info('Writing predictions to: %s', (output_prediction_file)) logging.info('Writing predictions to: %s', (output_prediction_file))
logging.info('Writing nbest to: %s', (output_nbest_file)) logging.info('Writing nbest to: %s', (output_nbest_file))
...@@ -370,6 +380,22 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json, ...@@ -370,6 +380,22 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file) squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def _get_matched_files(input_path):
  """Returns the sorted, de-duplicated files that match `input_path`.

  Args:
    input_path: A string holding one or more file patterns separated by
      commas. Each pattern is expanded with `tf.io.gfile.glob`. Whitespace
      around a pattern is stripped and empty entries are ignored.

  Returns:
    A sorted list of the distinct matched file paths. The list is empty when
    `input_path` contains no non-empty pattern.

  Raises:
    ValueError: If a non-empty pattern does not match any file.
  """
  # A set is used so that overlapping patterns (e.g. 'a.json,a*.json') do not
  # yield the same file more than once.
  matched_files = set()
  for input_pattern in input_path.split(','):
    input_pattern = input_pattern.strip()
    if not input_pattern:
      continue
    files_for_pattern = tf.io.gfile.glob(input_pattern)
    if not files_for_pattern:
      raise ValueError('%s does not match any files.' % input_pattern)
    matched_files.update(files_for_pattern)
  return sorted(matched_files)
def predict_squad(strategy, def predict_squad(strategy,
input_meta_data, input_meta_data,
tokenizer, tokenizer,
...@@ -379,11 +405,24 @@ def predict_squad(strategy, ...@@ -379,11 +405,24 @@ def predict_squad(strategy,
"""Get prediction results and evaluate them to hard drive.""" """Get prediction results and evaluate them to hard drive."""
if init_checkpoint is None: if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, all_predict_files = _get_matched_files(FLAGS.predict_file)
bert_config, squad_lib, init_checkpoint) squad_model = get_squad_model_to_predict(strategy, bert_config,
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib, init_checkpoint, input_meta_data)
input_meta_data.get('version_2_with_negative', False)) for idx, predict_file in enumerate(all_predict_files):
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, squad_lib, predict_file,
squad_model)
if len(all_predict_files) == 1:
file_prefix = ''
else:
# if predict_file is /path/xquad.ar.json, the `file_prefix` is
# "xquad.ar-"
file_prefix = '%s-' % os.path.splitext(
os.path.basename(all_predict_files[idx]))[0]
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False),
file_prefix)
def eval_squad(strategy, def eval_squad(strategy,
...@@ -395,9 +434,17 @@ def eval_squad(strategy, ...@@ -395,9 +434,17 @@ def eval_squad(strategy,
"""Get prediction results and evaluate them against ground truth.""" """Get prediction results and evaluate them against ground truth."""
if init_checkpoint is None: if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir) init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predict_files = _get_matched_files(FLAGS.predict_file)
if len(all_predict_files) != 1:
raise ValueError('`eval_squad` only supports one predict file, '
'but got %s' % all_predict_files)
squad_model = get_squad_model_to_predict(strategy, bert_config,
init_checkpoint, input_meta_data)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad( all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, strategy, input_meta_data, tokenizer, squad_lib, all_predict_files[0],
bert_config, squad_lib, init_checkpoint) squad_model)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib, dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False)) input_meta_data.get('version_2_with_negative', False))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment