"tests/vscode:/vscode.git/clone" did not exist on "e8284281c1c505d84d39dc6ffd2115d03d50e6f3"
Commit c8f9cf19 authored by Chen Chen, committed by A. Unique TensorFlower
Browse files

Support multiple prediction files for SQuAD task.

PiperOrigin-RevId: 317253522
parent 8284ea20
......@@ -61,7 +61,11 @@ def define_common_squad_flags():
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
# Predict processing related.
flags.DEFINE_string('predict_file', None,
'Prediction data path with train tfrecords.')
'SQuAD prediction json file path. '
'`predict` mode supports multiple files: one can use '
'wildcard to specify multiple files and it can also be '
'multiple file patterns separated by comma. Note that '
'`eval` mode only supports a single predict file.')
flags.DEFINE_bool(
'do_lower_case', True,
'Whether to lower case the input text. Should be True for uncased '
......@@ -159,22 +163,9 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
return _dataset_fn
def predict_squad_customized(strategy,
input_meta_data,
bert_config,
checkpoint_path,
predict_tfrecord_path,
num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
input_meta_data):
"""Gets a squad model to make predictions."""
with strategy.scope():
# Prediction always uses float32, even if training uses mixed precision.
tf.keras.mixed_precision.experimental.set_policy('float32')
......@@ -188,6 +179,23 @@ def predict_squad_customized(strategy,
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
checkpoint.restore(checkpoint_path).expect_partial()
return squad_model
def predict_squad_customized(strategy,
input_meta_data,
predict_tfrecord_path,
num_steps,
squad_model):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
input_meta_data['max_seq_length'],
FLAGS.predict_batch_size,
is_training=False)
predict_iterator = iter(
strategy.experimental_distribute_datasets_from_function(
predict_dataset_fn))
@tf.function
def predict_step(iterator):
......@@ -287,8 +295,8 @@ def train_squad(strategy,
post_allreduce_callbacks=[clip_by_global_norm_callback])
def prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib, checkpoint):
def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
predict_file, squad_model):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
......@@ -296,7 +304,7 @@ def prediction_output_squad(
version_2_with_negative = input_meta_data.get('version_2_with_negative',
False)
eval_examples = squad_lib.read_squad_examples(
input_file=FLAGS.predict_file,
input_file=predict_file,
is_training=False,
version_2_with_negative=version_2_with_negative)
......@@ -337,8 +345,7 @@ def prediction_output_squad(
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(
strategy, input_meta_data, bert_config,
checkpoint, eval_writer.filename, num_steps)
strategy, input_meta_data, eval_writer.filename, num_steps, squad_model)
all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output(
......@@ -356,11 +363,14 @@ def prediction_output_squad(
def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib, version_2_with_negative):
squad_lib, version_2_with_negative, file_prefix=''):
"""Save output to json files."""
output_prediction_file = os.path.join(FLAGS.model_dir, 'predictions.json')
output_nbest_file = os.path.join(FLAGS.model_dir, 'nbest_predictions.json')
output_null_log_odds_file = os.path.join(FLAGS.model_dir, 'null_odds.json')
output_prediction_file = os.path.join(FLAGS.model_dir,
'%spredictions.json' % file_prefix)
output_nbest_file = os.path.join(FLAGS.model_dir,
'%snbest_predictions.json' % file_prefix)
output_null_log_odds_file = os.path.join(FLAGS.model_dir, file_prefix,
'%snull_odds.json' % file_prefix)
logging.info('Writing predictions to: %s', (output_prediction_file))
logging.info('Writing nbest to: %s', (output_nbest_file))
......@@ -370,6 +380,22 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def _get_matched_files(input_path):
  """Returns the sorted, de-duplicated files matched by `input_path`.

  Args:
    input_path: A string holding one or more file patterns separated by
      commas. Each pattern is expanded with `tf.io.gfile.glob`; whitespace
      around a pattern and empty entries (e.g. a trailing comma) are ignored.

  Returns:
    A sorted list of matched file paths. A file matched by more than one
    pattern is listed only once.

  Raises:
    ValueError: If a non-empty pattern does not match any files.
  """
  # A set is used because patterns may overlap (or be repeated); callers
  # that derive output names from each file's basename would otherwise
  # process the same file twice and overwrite their own results.
  all_matched_files = set()
  for input_pattern in input_path.split(','):
    input_pattern = input_pattern.strip()
    if not input_pattern:
      continue
    matched_files = tf.io.gfile.glob(input_pattern)
    if not matched_files:
      raise ValueError('%s does not match any files.' % input_pattern)
    all_matched_files.update(matched_files)
  return sorted(all_matched_files)
def predict_squad(strategy,
input_meta_data,
tokenizer,
......@@ -379,11 +405,24 @@ def predict_squad(strategy,
"""Get prediction results and evaluate them to hard drive."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
all_predict_files = _get_matched_files(FLAGS.predict_file)
squad_model = get_squad_model_to_predict(strategy, bert_config,
init_checkpoint, input_meta_data)
for idx, predict_file in enumerate(all_predict_files):
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, squad_lib, predict_file,
squad_model)
if len(all_predict_files) == 1:
file_prefix = ''
else:
# If predict_file is /path/xquad.ar.json, the `file_prefix` is
# "xquad.ar-" (the base name without its extension, followed by "-").
file_prefix = '%s-' % os.path.splitext(
os.path.basename(all_predict_files[idx]))[0]
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False),
file_prefix)
def eval_squad(strategy,
......@@ -395,9 +434,17 @@ def eval_squad(strategy,
"""Get prediction results and evaluate them against ground truth."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predict_files = _get_matched_files(FLAGS.predict_file)
if len(all_predict_files) != 1:
raise ValueError('`eval_squad` only supports one predict file, '
'but got %s' % all_predict_files)
squad_model = get_squad_model_to_predict(strategy, bert_config,
init_checkpoint, input_meta_data)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
strategy, input_meta_data, tokenizer, squad_lib, all_predict_files[0],
squad_model)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment