Commit f635ac87 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Add support to export savedmodel and tflite for QAT models. Only add...

Add support to export savedmodel and tflite for QAT models. Only add classification and segmentation for now.

PiperOrigin-RevId: 457846906
parent 9dc368bc
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export modules for QAT model serving/inference."""
import tensorflow as tf
from official.projects.qat.vision.modeling import factory as qat_factory
from official.vision.serving import image_classification
from official.vision.serving import semantic_segmentation
class ClassificationModule(image_classification.ClassificationModule):
"""Classification Module."""
def _build_model(self):
model = super()._build_model()
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
return qat_factory.build_qat_classification_model(
model, self.params.task.quantization, input_specs,
self.params.task.model)
class SegmentationModule(semantic_segmentation.SegmentationModule):
"""Segmentation Module."""
def _build_model(self):
model = super()._build_model()
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
return qat_factory.build_qat_segmentation_model(
model, self.params.task.quantization, input_specs)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.qat.vision import registry_imports # pylint: disable=unused-import
from official.projects.qat.vision.serving import export_module
from official.vision import configs
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS
_EXPERIMENT = flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
_CHECKPOINT_PATH = flags.DEFINE_string('checkpoint_path', None,
'Checkpoint path.')
_CONFIG_FILE = flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
_PARAMS_OVERRIDE = flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.')
_BATCH_SIZSE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
_IMAGE_TYPE = flags.DEFINE_string(
'input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
_EXPORT_CHECKPOINT_SUBDIR = flags.DEFINE_string(
'export_checkpoint_subdir', 'checkpoint',
'The subdirectory for checkpoints.')
_EXPORT_SAVED_MODEL_SUBDIR = flags.DEFINE_string(
'export_saved_model_subdir', 'saved_model',
'The subdirectory for saved model.')
_LOG_MODEL_FLOPS_AND_PARAMS = flags.DEFINE_bool(
'log_model_flops_and_params', False,
'If true, logs model flops and parameters.')
_INPUT_NAME = flags.DEFINE_string(
'input_name', None,
'Input tensor name in signature def. Default at None which'
'produces input tensor name `inputs`.')
def main(_):
params = exp_factory.get_exp_config(_EXPERIMENT.value)
for config_file in _CONFIG_FILE.value or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if _PARAMS_OVERRIDE.value:
params = hyperparams.override_params_dict(
params, _PARAMS_OVERRIDE.value, is_strict=True)
params.validate()
params.lock()
input_image_size = [int(x) for x in _INPUT_IMAGE_SIZE.value.split(',')]
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module_cls = export_module.ClassificationModule
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
export_module_cls = export_module.SegmentationModule
else:
raise TypeError(f'Export module for {type(params.task)} is not supported.')
module = export_module_cls(
params=params,
batch_size=_BATCH_SIZSE.value,
input_image_size=input_image_size,
input_type=_IMAGE_TYPE.value,
num_channels=3)
export_saved_model_lib.export_inference_graph(
input_type=_IMAGE_TYPE.value,
batch_size=_BATCH_SIZSE.value,
input_image_size=input_image_size,
params=params,
checkpoint_path=_CHECKPOINT_PATH.value,
export_dir=_EXPORT_DIR.value,
export_checkpoint_subdir=_EXPORT_CHECKPOINT_SUBDIR.value,
export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
export_module=module,
log_model_flops_and_params=_LOG_MODEL_FLOPS_AND_PARAMS.value,
input_name=_INPUT_NAME.value)
if __name__ == '__main__':
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binary to convert a saved model to TFLite model for the QAT model."""
from absl import app
from official.projects.qat.vision import registry_imports # pylint: disable=unused-import
from official.vision.serving import export_tflite
if __name__ == '__main__':
app.run(export_tflite.main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment