Commit 37392bef authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 382382393
parent a4fd6472
......@@ -73,6 +73,10 @@ flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
'The subdirectory for checkpoints.')
flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
'The subdirectory for saved model.')
def main(_):
......@@ -95,8 +99,8 @@ def main(_):
params=params,
checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir,
export_checkpoint_subdir='checkpoint',
export_saved_model_subdir='saved_model')
export_checkpoint_subdir=FLAGS.export_checkpoint_subdir,
export_saved_model_subdir=FLAGS.export_saved_model_subdir)
if __name__ == '__main__':
......
......@@ -69,7 +69,7 @@ def export_inference_graph(
output_checkpoint_directory = os.path.join(
export_dir, export_checkpoint_subdir)
else:
output_checkpoint_directory = export_dir
output_checkpoint_directory = None
if export_saved_model_subdir:
output_saved_model_directory = os.path.join(
......@@ -119,6 +119,7 @@ def export_inference_graph(
timestamped=False,
save_options=save_options)
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
if output_checkpoint_directory:
ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
train_utils.serialize_config(params, export_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment