Commit 614d3d93 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 404687702
parent 56c34e5f
...@@ -50,6 +50,11 @@ flags.DEFINE_enum( ...@@ -50,6 +50,11 @@ flags.DEFINE_enum(
['tf_checkpoint', 'keras_checkpoint'], ['tf_checkpoint', 'keras_checkpoint'],
'tf_checkpoint is for ckpt files from tf.train.Checkpoint.save() method' 'tf_checkpoint is for ckpt files from tf.train.Checkpoint.save() method'
'keras_checkpoint is for ckpt files from keras.Model.save_weights() method') 'keras_checkpoint is for ckpt files from keras.Model.save_weights() method')
flags.DEFINE_bool(
'export_keras_model', False,
'Export SavedModel format: if False, export TF SavedModel with'
'tf.saved_model API; if True, export Keras SavedModel with tf.keras.Model'
'API.')
flags.DEFINE_string('output_dir', None, 'Directory to output exported files.') flags.DEFINE_string('output_dir', None, 'Directory to output exported files.')
flags.DEFINE_integer( flags.DEFINE_integer(
'image_size', 224, 'image_size', 224,
...@@ -161,7 +166,10 @@ def run_export(): ...@@ -161,7 +166,10 @@ def run_export():
# Export saved model. # Export saved model.
saved_model_path = os.path.join(export_config.output_dir, saved_model_path = os.path.join(export_config.output_dir,
export_config.model_name) export_config.model_name)
model_for_inference.save(saved_model_path) if FLAGS.export_keras_model:
model_for_inference.save(saved_model_path)
else:
tf.saved_model.save(model_for_inference, saved_model_path)
print('SavedModel exported to {}'.format(saved_model_path)) print('SavedModel exported to {}'.format(saved_model_path))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment