Commit 24ae1f51 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Expose flag to log model flops and parameters.

PiperOrigin-RevId: 435413594
parent 77bf83b4
...@@ -46,9 +46,8 @@ from official.vision.serving import export_saved_model_lib ...@@ -46,9 +46,8 @@ from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string('experiment', None,
flags.DEFINE_string( 'experiment type, e.g. retinanet_resnetfpn_coco')
'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.') flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.') flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string( flags.DEFINE_multi_string(
...@@ -64,8 +63,7 @@ flags.DEFINE_string( ...@@ -64,8 +63,7 @@ flags.DEFINE_string(
'params_override', '', 'params_override', '',
'The JSON/YAML file or string which specifies the parameter to be overriden' 'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.') ' on top of `config_file` template.')
flags.DEFINE_integer( flags.DEFINE_integer('batch_size', None, 'The batch size.')
'batch_size', None, 'The batch size.')
flags.DEFINE_string( flags.DEFINE_string(
'input_type', 'image_tensor', 'input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.') 'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
...@@ -77,6 +75,8 @@ flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint', ...@@ -77,6 +75,8 @@ flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
'The subdirectory for checkpoints.') 'The subdirectory for checkpoints.')
flags.DEFINE_string('export_saved_model_subdir', 'saved_model', flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
'The subdirectory for saved model.') 'The subdirectory for saved model.')
flags.DEFINE_bool('log_model_flops_and_params', False,
'If true, logs model flops and parameters.')
def main(_): def main(_):
...@@ -100,7 +100,8 @@ def main(_): ...@@ -100,7 +100,8 @@ def main(_):
checkpoint_path=FLAGS.checkpoint_path, checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir, export_dir=FLAGS.export_dir,
export_checkpoint_subdir=FLAGS.export_checkpoint_subdir, export_checkpoint_subdir=FLAGS.export_checkpoint_subdir,
export_saved_model_subdir=FLAGS.export_saved_model_subdir) export_saved_model_subdir=FLAGS.export_saved_model_subdir,
log_model_flops_and_params=FLAGS.log_model_flops_and_params)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment