Unverified Commit 83f355bb authored by Kevin's avatar Kevin Committed by GitHub
Browse files

Consider regression model in predict mode

parent 0ba5a72b
...@@ -448,7 +448,7 @@ def custom_main(custom_callbacks=None, custom_metrics=None): ...@@ -448,7 +448,7 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
if FLAGS.mode == 'predict': if FLAGS.mode == 'predict':
with strategy.scope(): with strategy.scope():
classifier_model = bert_models.classifier_model( classifier_model = bert_models.classifier_model(
bert_config, input_meta_data['num_labels'])[0] bert_config, input_meta_data.get('num_labels', 1))[0]
checkpoint = tf.train.Checkpoint(model=classifier_model) checkpoint = tf.train.Checkpoint(model=classifier_model)
latest_checkpoint_file = ( latest_checkpoint_file = (
FLAGS.predict_checkpoint_path or FLAGS.predict_checkpoint_path or
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment