Commit 7cc0970b authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 308563418
parent f852bb33
...@@ -40,23 +40,21 @@ flags.DEFINE_string("export_path", None, ...@@ -40,23 +40,21 @@ flags.DEFINE_string("export_path", None,
def export_tfhub(model_path, hub_destination, model_name): def export_tfhub(model_path, hub_destination, model_name):
"""Restores a tf.keras.Model and saves for TF-Hub.""" """Restores a tf.keras.Model and saves for TF-Hub."""
model = efficientnet_model.EfficientNet.from_name(model_name) model_configs = dict(efficientnet_model.MODEL_CONFIGS)
ckpt = tf.train.Checkpoint(model=model) config = model_configs[model_name]
ckpt.restore(model_path).assert_existing_objects_matched()
image_input = tf.keras.layers.Input( image_input = tf.keras.layers.Input(
shape=(None, None, 3), name="image_input", dtype=tf.float32) shape=(None, None, 3), name="image_input", dtype=tf.float32)
x = image_input * 255.0 x = image_input * 255.0
ouputs = model(x) ouputs = efficientnet_model.efficientnet(x, config)
hub_model = tf.keras.Model(image_input, ouputs) hub_model = tf.keras.Model(image_input, ouputs)
# Exports a SavedModel. ckpt = tf.train.Checkpoint(model=hub_model)
ckpt.restore(model_path).assert_existing_objects_matched()
hub_model.save( hub_model.save(
os.path.join(hub_destination, "classification"), include_optimizer=False) os.path.join(hub_destination, "classification"), include_optimizer=False)
feature_vector_output = hub_model.get_layer(name="efficientnet").get_layer( feature_vector_output = hub_model.get_layer(name="top_pool").get_output_at(0)
name="top_pool").get_output_at(0) hub_model2 = tf.keras.Model(image_input, feature_vector_output)
hub_model2 = tf.keras.Model(model.inputs, feature_vector_output)
# Exports a SavedModel.
hub_model2.save( hub_model2.save(
os.path.join(hub_destination, "feature-vector"), include_optimizer=False) os.path.join(hub_destination, "feature-vector"), include_optimizer=False)
...@@ -67,6 +65,5 @@ def main(argv): ...@@ -67,6 +65,5 @@ def main(argv):
export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name) export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)
if __name__ == "__main__": if __name__ == "__main__":
app.run(main) app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment