Commit 37392bef authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 382382393
parent a4fd6472
...@@ -73,6 +73,10 @@ flags.DEFINE_string( ...@@ -73,6 +73,10 @@ flags.DEFINE_string(
'input_image_size', '224,224', 'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width ' 'The comma-separated string of two integers representing the height,width '
'of the input to the model.') 'of the input to the model.')
flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
'The subdirectory for checkpoints.')
flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
'The subdirectory for saved model.')
def main(_): def main(_):
...@@ -95,8 +99,8 @@ def main(_): ...@@ -95,8 +99,8 @@ def main(_):
params=params, params=params,
checkpoint_path=FLAGS.checkpoint_path, checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir, export_dir=FLAGS.export_dir,
export_checkpoint_subdir='checkpoint', export_checkpoint_subdir=FLAGS.export_checkpoint_subdir,
export_saved_model_subdir='saved_model') export_saved_model_subdir=FLAGS.export_saved_model_subdir)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -69,7 +69,7 @@ def export_inference_graph( ...@@ -69,7 +69,7 @@ def export_inference_graph(
output_checkpoint_directory = os.path.join( output_checkpoint_directory = os.path.join(
export_dir, export_checkpoint_subdir) export_dir, export_checkpoint_subdir)
else: else:
output_checkpoint_directory = export_dir output_checkpoint_directory = None
if export_saved_model_subdir: if export_saved_model_subdir:
output_saved_model_directory = os.path.join( output_saved_model_directory = os.path.join(
...@@ -119,6 +119,7 @@ def export_inference_graph( ...@@ -119,6 +119,7 @@ def export_inference_graph(
timestamped=False, timestamped=False,
save_options=save_options) save_options=save_options)
ckpt = tf.train.Checkpoint(model=export_module.model) if output_checkpoint_directory:
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt')) ckpt = tf.train.Checkpoint(model=export_module.model)
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
train_utils.serialize_config(params, export_dir) train_utils.serialize_config(params, export_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment