"backends/v2/src/client/mod.rs" did not exist on "9af454142a34536ab1f3c149cc8764b7ab460c0d"
Commit 2ff0a3fe authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 305332859
parent ce8dd972
......@@ -47,11 +47,13 @@ FLAGS = flags.FLAGS
def train_squad(strategy,
input_meta_data,
custom_callbacks=None,
run_eagerly=False):
run_eagerly=False,
init_checkpoint=None):
"""Run bert squad training."""
bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
custom_callbacks, run_eagerly)
custom_callbacks, run_eagerly, init_checkpoint)
def predict_squad(strategy, input_meta_data):
......
......@@ -159,8 +159,12 @@ def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
return _dataset_fn
def predict_squad_customized(strategy, input_meta_data, bert_config,
predict_tfrecord_path, num_steps):
def predict_squad_customized(strategy,
input_meta_data,
bert_config,
checkpoint_path,
predict_tfrecord_path,
num_steps):
"""Make predictions using a Bert-based squad model."""
predict_dataset_fn = get_dataset_fn(
predict_tfrecord_path,
......@@ -179,6 +183,7 @@ def predict_squad_customized(strategy, input_meta_data, bert_config,
input_meta_data['max_seq_length'],
hub_module_url=FLAGS.hub_module_url)
if checkpoint_path is None:
checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
logging.info('Restoring checkpoints from %s', checkpoint_path)
checkpoint = tf.train.Checkpoint(model=squad_model)
......@@ -215,7 +220,8 @@ def train_squad(strategy,
input_meta_data,
bert_config,
custom_callbacks=None,
run_eagerly=False):
run_eagerly=False,
init_checkpoint=None):
"""Run bert squad training."""
if strategy:
logging.info('Training using customized training loop with distribution'
......@@ -271,7 +277,7 @@ def train_squad(strategy,
steps_per_loop=FLAGS.steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
run_eagerly=run_eagerly,
custom_callbacks=custom_callbacks,
explicit_allreduce=False,
......@@ -279,7 +285,7 @@ def train_squad(strategy,
def prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib):
strategy, input_meta_data, tokenizer, bert_config, squad_lib, checkpoint):
"""Makes predictions for a squad dataset."""
doc_stride = input_meta_data['doc_stride']
max_query_length = input_meta_data['max_query_length']
......@@ -327,8 +333,9 @@ def prediction_output_squad(
logging.info(' Batch size = %d', FLAGS.predict_batch_size)
num_steps = int(dataset_size / FLAGS.predict_batch_size)
all_results = predict_squad_customized(strategy, input_meta_data, bert_config,
eval_writer.filename, num_steps)
all_results = predict_squad_customized(
strategy, input_meta_data, bert_config,
checkpoint, eval_writer.filename, num_steps)
all_predictions, all_nbest_json, scores_diff_json = (
squad_lib.postprocess_output(
......@@ -360,18 +367,34 @@ def dump_to_files(all_predictions, all_nbest_json, scores_diff_json,
squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
def predict_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
def predict_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them to hard drive."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib)
strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
def eval_squad(strategy, input_meta_data, tokenizer, bert_config, squad_lib):
def eval_squad(strategy,
input_meta_data,
tokenizer,
bert_config,
squad_lib,
init_checkpoint=None):
"""Get prediction results and evaluate them against ground truth."""
if init_checkpoint is None:
init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
strategy, input_meta_data, tokenizer, bert_config, squad_lib)
strategy, input_meta_data, tokenizer,
bert_config, squad_lib, init_checkpoint)
dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
input_meta_data.get('version_2_with_negative', False))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment