Commit 63e20270 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Update resnet hub export

PiperOrigin-RevId: 274917111
parent 048f5a95
...@@ -19,6 +19,8 @@ from __future__ import division ...@@ -19,6 +19,8 @@ from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
from __future__ import print_function from __future__ import print_function
import os
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -39,14 +41,17 @@ def export_tfhub(model_path, hub_destination): ...@@ -39,14 +41,17 @@ def export_tfhub(model_path, hub_destination):
"""Restores a tf.keras.Model and saves for TF-Hub.""" """Restores a tf.keras.Model and saves for TF-Hub."""
model = resnet_model.resnet50(num_classes=imagenet_preprocessing.NUM_CLASSES) model = resnet_model.resnet50(num_classes=imagenet_preprocessing.NUM_CLASSES)
model.load_weights(model_path) model.load_weights(model_path)
model.save(os.path.join(hub_destination, "classification"), include_optimizer=False)
# Extracts a sub-model to use pooling feature vector as model output. # Extracts a sub-model to use pooling feature vector as model output.
image_input = model.get_layer(index=0).get_output_at(0) image_input = model.get_layer(index=0).get_output_at(0)
feature_vector_output = model.get_layer(name='reduce_mean').get_output_at(0) feature_vector_output = model.get_layer(name="reduce_mean").get_output_at(0)
hub_model = tf.keras.Model(image_input, feature_vector_output) hub_model = tf.keras.Model(image_input, feature_vector_output)
# Exports a SavedModel. # Exports a SavedModel.
hub_model.save(hub_destination, include_optimizer=False) hub_model.save(
os.path.join(hub_destination, "feature-vector"),
include_optimizer=False)
def main(argv): def main(argv):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment