Commit 0214f22b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 443781710
parent 5e5e6f6e
...@@ -32,7 +32,8 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta): ...@@ -32,7 +32,8 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
input_image_size: List[int], input_image_size: List[int],
input_type: str = 'image_tensor', input_type: str = 'image_tensor',
num_channels: int = 3, num_channels: int = 3,
model: Optional[tf.keras.Model] = None): model: Optional[tf.keras.Model] = None,
input_name: Optional[str] = None):
"""Initializes a module for export. """Initializes a module for export.
Args: Args:
...@@ -43,12 +44,14 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta): ...@@ -43,12 +44,14 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
input_type: The input signature type. input_type: The input signature type.
num_channels: The number of the image channels. num_channels: The number of the image channels.
model: A tf.keras.Model instance to be exported. model: A tf.keras.Model instance to be exported.
input_name: A customized input tensor name.
""" """
self.params = params self.params = params
self._batch_size = batch_size self._batch_size = batch_size
self._input_image_size = input_image_size self._input_image_size = input_image_size
self._num_channels = num_channels self._num_channels = num_channels
self._input_type = input_type self._input_type = input_type
self._input_name = input_name
if model is None: if model is None:
model = self._build_model() # pylint: disable=assignment-from-none model = self._build_model() # pylint: disable=assignment-from-none
super().__init__(params=params, model=model) super().__init__(params=params, model=model)
...@@ -163,19 +166,20 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta): ...@@ -163,19 +166,20 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
input_signature = tf.TensorSpec( input_signature = tf.TensorSpec(
shape=[self._batch_size] + [None] * len(self._input_image_size) + shape=[self._batch_size] + [None] * len(self._input_image_size) +
[self._num_channels], [self._num_channels],
dtype=tf.uint8) dtype=tf.uint8,
name=self._input_name)
signatures[ signatures[
def_name] = self.inference_from_image_tensors.get_concrete_function( def_name] = self.inference_from_image_tensors.get_concrete_function(
input_signature) input_signature)
elif key == 'image_bytes': elif key == 'image_bytes':
input_signature = tf.TensorSpec( input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string) shape=[self._batch_size], dtype=tf.string, name=self._input_name)
signatures[ signatures[
def_name] = self.inference_from_image_bytes.get_concrete_function( def_name] = self.inference_from_image_bytes.get_concrete_function(
input_signature) input_signature)
elif key == 'serve_examples' or key == 'tf_example': elif key == 'serve_examples' or key == 'tf_example':
input_signature = tf.TensorSpec( input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string) shape=[self._batch_size], dtype=tf.string, name=self._input_name)
signatures[ signatures[
def_name] = self.inference_from_tf_example.get_concrete_function( def_name] = self.inference_from_tf_example.get_concrete_function(
input_signature) input_signature)
...@@ -183,7 +187,8 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta): ...@@ -183,7 +187,8 @@ class ExportModule(export_base.ExportModule, metaclass=abc.ABCMeta):
input_signature = tf.TensorSpec( input_signature = tf.TensorSpec(
shape=[self._batch_size] + self._input_image_size + shape=[self._batch_size] + self._input_image_size +
[self._num_channels], [self._num_channels],
dtype=tf.float32) dtype=tf.float32,
name=self._input_name)
signatures[def_name] = self.inference_for_tflite.get_concrete_function( signatures[def_name] = self.inference_for_tflite.get_concrete_function(
input_signature) input_signature)
else: else:
......
...@@ -45,11 +45,12 @@ from official.vision.serving import export_saved_model_lib ...@@ -45,11 +45,12 @@ from official.vision.serving import export_saved_model_lib
# Command-line flags for the export script. Each absl `DEFINE_*` call returns
# a flag holder object; the parsed value is read later via `.value`.
FLAGS = flags.FLAGS

_EXPERIMENT = flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
_CHECKPOINT_PATH = flags.DEFINE_string('checkpoint_path', None,
                                       'Checkpoint path.')
_CONFIG_FILE = flags.DEFINE_multi_string(
'config_file', 'config_file',
default=None, default=None,
help='YAML/JSON files which specifies overrides. The override order ' help='YAML/JSON files which specifies overrides. The override order '
...@@ -58,49 +59,57 @@ flags.DEFINE_multi_string( ...@@ -58,49 +59,57 @@ flags.DEFINE_multi_string(
'specified in Python. If the same parameter is specified in both ' 'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used ' '`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.') 'first, followed by params_override.')
_PARAMS_OVERRIDE = flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of `config_file` template.')
# NOTE(review): renamed from the typo'd `_BATCH_SIZSE`; the holder is private
# and only read inside this module.
_BATCH_SIZE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
# NOTE(review): renamed from `_IMAGE_TYPE` to match the `input_type` flag.
_INPUT_TYPE = flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
    'input_image_size', '224,224',
    'The comma-separated string of two integers representing the height,width '
    'of the input to the model.')
_EXPORT_CHECKPOINT_SUBDIR = flags.DEFINE_string(
    'export_checkpoint_subdir', 'checkpoint',
    'The subdirectory for checkpoints.')
_EXPORT_SAVED_MODEL_SUBDIR = flags.DEFINE_string(
    'export_saved_model_subdir', 'saved_model',
    'The subdirectory for saved model.')
_LOG_MODEL_FLOPS_AND_PARAMS = flags.DEFINE_bool(
    'log_model_flops_and_params', False,
    'If true, logs model flops and parameters.')
_INPUT_NAME = flags.DEFINE_string(
    'input_name', None,
    'Input tensor name in signature def. Default at None, which '
    'produces input tensor name `inputs`.')


def main(_):
  """Builds the experiment config from flags and exports the inference graph.

  Resolution order for the config: the registered experiment template, then
  each `--config_file` in order, then `--params_override` (strict overrides
  in both cases). The locked config is handed to
  `export_saved_model_lib.export_inference_graph`.
  """
  params = exp_factory.get_exp_config(_EXPERIMENT.value)
  # `.value` is None when the flag is unset; `or []` keeps the loop a no-op.
  for config_file in _CONFIG_FILE.value or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=True)
  if _PARAMS_OVERRIDE.value:
    params = hyperparams.override_params_dict(
        params, _PARAMS_OVERRIDE.value, is_strict=True)

  params.validate()
  params.lock()

  export_saved_model_lib.export_inference_graph(
      input_type=_INPUT_TYPE.value,
      batch_size=_BATCH_SIZE.value,
      # '224,224' -> [224, 224]; height,width per the flag's help text.
      input_image_size=[int(x) for x in _INPUT_IMAGE_SIZE.value.split(',')],
      params=params,
      checkpoint_path=_CHECKPOINT_PATH.value,
      export_dir=_EXPORT_DIR.value,
      export_checkpoint_subdir=_EXPORT_CHECKPOINT_SUBDIR.value,
      export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
      log_model_flops_and_params=_LOG_MODEL_FLOPS_AND_PARAMS.value,
      input_name=_INPUT_NAME.value)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -43,7 +43,8 @@ def export_inference_graph( ...@@ -43,7 +43,8 @@ def export_inference_graph(
export_saved_model_subdir: Optional[str] = None, export_saved_model_subdir: Optional[str] = None,
save_options: Optional[tf.saved_model.SaveOptions] = None, save_options: Optional[tf.saved_model.SaveOptions] = None,
log_model_flops_and_params: bool = False, log_model_flops_and_params: bool = False,
checkpoint: Optional[tf.train.Checkpoint] = None): checkpoint: Optional[tf.train.Checkpoint] = None,
input_name: Optional[str] = None):
"""Exports inference graph for the model specified in the exp config. """Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved Saved model is stored at export_dir/saved_model, checkpoint is saved
...@@ -69,6 +70,8 @@ def export_inference_graph( ...@@ -69,6 +70,8 @@ def export_inference_graph(
and model parameters to model_params.txt. and model parameters to model_params.txt.
checkpoint: An optional tf.train.Checkpoint. If provided, the export module checkpoint: An optional tf.train.Checkpoint. If provided, the export module
will use it to read the weights. will use it to read the weights.
input_name: The input tensor name, default at `None` which produces input
tensor name `inputs`.
""" """
if export_checkpoint_subdir: if export_checkpoint_subdir:
...@@ -92,7 +95,8 @@ def export_inference_graph( ...@@ -92,7 +95,8 @@ def export_inference_graph(
batch_size=batch_size, batch_size=batch_size,
input_image_size=input_image_size, input_image_size=input_image_size,
input_type=input_type, input_type=input_type,
num_channels=num_channels) num_channels=num_channels,
input_name=input_name)
elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance( elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
params.task, configs.maskrcnn.MaskRCNNTask): params.task, configs.maskrcnn.MaskRCNNTask):
export_module = detection.DetectionModule( export_module = detection.DetectionModule(
...@@ -100,7 +104,8 @@ def export_inference_graph( ...@@ -100,7 +104,8 @@ def export_inference_graph(
batch_size=batch_size, batch_size=batch_size,
input_image_size=input_image_size, input_image_size=input_image_size,
input_type=input_type, input_type=input_type,
num_channels=num_channels) num_channels=num_channels,
input_name=input_name)
elif isinstance(params.task, elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask): configs.semantic_segmentation.SemanticSegmentationTask):
export_module = semantic_segmentation.SegmentationModule( export_module = semantic_segmentation.SegmentationModule(
...@@ -108,7 +113,8 @@ def export_inference_graph( ...@@ -108,7 +113,8 @@ def export_inference_graph(
batch_size=batch_size, batch_size=batch_size,
input_image_size=input_image_size, input_image_size=input_image_size,
input_type=input_type, input_type=input_type,
num_channels=num_channels) num_channels=num_channels,
input_name=input_name)
elif isinstance(params.task, elif isinstance(params.task,
configs.video_classification.VideoClassificationTask): configs.video_classification.VideoClassificationTask):
export_module = video_classification.VideoClassificationModule( export_module = video_classification.VideoClassificationModule(
...@@ -116,7 +122,8 @@ def export_inference_graph( ...@@ -116,7 +122,8 @@ def export_inference_graph(
batch_size=batch_size, batch_size=batch_size,
input_image_size=input_image_size, input_image_size=input_image_size,
input_type=input_type, input_type=input_type,
num_channels=num_channels) num_channels=num_channels,
input_name=input_name)
else: else:
raise ValueError('Export module not implemented for {} task.'.format( raise ValueError('Export module not implemented for {} task.'.format(
type(params.task))) type(params.task)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment