Unverified Commit fbf3470f authored by Kevin's avatar Kevin Committed by GitHub
Browse files

Skip softmax activation if num_labels <= 1

parent 83f355bb
...@@ -262,6 +262,7 @@ def run_keras_compile_fit(model_dir, ...@@ -262,6 +262,7 @@ def run_keras_compile_fit(model_dir,
def get_predictions_and_labels(strategy, def get_predictions_and_labels(strategy,
trained_model, trained_model,
eval_input_fn, eval_input_fn,
num_labels,
return_probs=False): return_probs=False):
"""Obtains predictions of trained model on evaluation data. """Obtains predictions of trained model on evaluation data.
...@@ -287,8 +288,11 @@ def get_predictions_and_labels(strategy, ...@@ -287,8 +288,11 @@ def get_predictions_and_labels(strategy,
"""Replicated predictions.""" """Replicated predictions."""
inputs, labels = inputs inputs, labels = inputs
logits = trained_model(inputs, training=False) logits = trained_model(inputs, training=False)
probabilities = tf.nn.softmax(logits) if num_labels > 1:
return probabilities, labels probabilities = tf.nn.softmax(logits)
return probabilities, labels
else:
return logits, labels
outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),)) outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
# outputs: current batch logits as a tuple of shard logits # outputs: current batch logits as a tuple of shard logits
...@@ -446,9 +450,10 @@ def custom_main(custom_callbacks=None, custom_metrics=None): ...@@ -446,9 +450,10 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
include_sample_weights=include_sample_weights) include_sample_weights=include_sample_weights)
if FLAGS.mode == 'predict': if FLAGS.mode == 'predict':
num_labels = input_meta_data.get('num_labels', 1)
with strategy.scope(): with strategy.scope():
classifier_model = bert_models.classifier_model( classifier_model = bert_models.classifier_model(
bert_config, input_meta_data.get('num_labels', 1))[0] bert_config, num_labels)[0]
checkpoint = tf.train.Checkpoint(model=classifier_model) checkpoint = tf.train.Checkpoint(model=classifier_model)
latest_checkpoint_file = ( latest_checkpoint_file = (
FLAGS.predict_checkpoint_path or FLAGS.predict_checkpoint_path or
...@@ -459,7 +464,8 @@ def custom_main(custom_callbacks=None, custom_metrics=None): ...@@ -459,7 +464,8 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
checkpoint.restore( checkpoint.restore(
latest_checkpoint_file).assert_existing_objects_matched() latest_checkpoint_file).assert_existing_objects_matched()
preds, _ = get_predictions_and_labels( preds, _ = get_predictions_and_labels(
strategy, classifier_model, eval_input_fn, return_probs=True) strategy, classifier_model, eval_input_fn,
num_labels, return_probs=True)
output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv') output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
with tf.io.gfile.GFile(output_predict_file, 'w') as writer: with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
logging.info('***** Predict results *****') logging.info('***** Predict results *****')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment