Commit 9236fd88 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #9070 from NivekNey:patch-1

PiperOrigin-RevId: 327072144
parents 27bda1fc fbf3470f
......@@ -263,6 +263,7 @@ def run_keras_compile_fit(model_dir,
def get_predictions_and_labels(strategy,
trained_model,
eval_input_fn,
is_regression=False,
return_probs=False):
"""Obtains predictions of trained model on evaluation data.
......@@ -273,6 +274,7 @@ def get_predictions_and_labels(strategy,
strategy: Distribution strategy.
trained_model: Trained model with preloaded weights.
eval_input_fn: Input function for evaluation data.
is_regression: Whether it is a regression task.
return_probs: Whether to return probabilities of classes.
Returns:
......@@ -288,8 +290,11 @@ def get_predictions_and_labels(strategy,
"""Replicated predictions."""
inputs, labels = inputs
logits = trained_model(inputs, training=False)
if not is_regression:
probabilities = tf.nn.softmax(logits)
return probabilities, labels
else:
return logits, labels
outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
# outputs: current batch logits as a tuple of shard logits
......@@ -447,9 +452,10 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
include_sample_weights=include_sample_weights)
if FLAGS.mode == 'predict':
num_labels = input_meta_data.get('num_labels', 1)
with strategy.scope():
classifier_model = bert_models.classifier_model(
bert_config, input_meta_data['num_labels'])[0]
bert_config, num_labels)[0]
checkpoint = tf.train.Checkpoint(model=classifier_model)
latest_checkpoint_file = (
FLAGS.predict_checkpoint_path or
......@@ -460,7 +466,11 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched()
preds, _ = get_predictions_and_labels(
strategy, classifier_model, eval_input_fn, return_probs=True)
strategy,
classifier_model,
eval_input_fn,
is_regression=(num_labels == 1),
return_probs=True)
output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
logging.info('***** Predict results *****')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment