Commit 12acb414 authored by Pengchong Jin, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 358194199
parent 852e098a
@@ -36,6 +36,7 @@ class ClassificationModel(tf.keras.Model):
       use_sync_bn: bool = False,
       norm_momentum: float = 0.99,
       norm_epsilon: float = 0.001,
+      skip_logits_layer: bool = False,
       **kwargs):
     """Classification initialization function.
@@ -55,6 +56,7 @@ class ClassificationModel(tf.keras.Model):
       norm_momentum: `float` normalization momentum for the moving average.
       norm_epsilon: `float` small float added to variance to avoid dividing by
         zero.
+      skip_logits_layer: `bool`, whether to skip the prediction layer.
       **kwargs: keyword arguments to be passed.
     """
     self._self_setattr_tracking = False
@@ -88,12 +90,13 @@ class ClassificationModel(tf.keras.Model):
     if add_head_batch_norm:
       x = self._norm(axis=axis, momentum=norm_momentum, epsilon=norm_epsilon)(x)
     x = tf.keras.layers.GlobalAveragePooling2D()(x)
-    x = tf.keras.layers.Dropout(dropout_rate)(x)
-    x = tf.keras.layers.Dense(
-        num_classes, kernel_initializer=kernel_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer)(
-            x)
+    if not skip_logits_layer:
+      x = tf.keras.layers.Dropout(dropout_rate)(x)
+      x = tf.keras.layers.Dense(
+          num_classes, kernel_initializer=kernel_initializer,
+          kernel_regularizer=self._kernel_regularizer,
+          bias_regularizer=self._bias_regularizer)(
+              x)

     super(ClassificationModel, self).__init__(
         inputs=inputs, outputs=x, **kwargs)
...
@@ -41,7 +41,8 @@ from official.vision.beta.modeling.layers import roi_sampler
 def build_classification_model(
     input_specs: tf.keras.layers.InputSpec,
     model_config: classification_cfg.ImageClassificationModel,
-    l2_regularizer: tf.keras.regularizers.Regularizer = None):
+    l2_regularizer: tf.keras.regularizers.Regularizer = None,
+    skip_logits_layer: bool = False):
   """Builds the classification model."""
   backbone = backbones.factory.build_backbone(
       input_specs=input_specs,
@@ -58,7 +59,8 @@ def build_classification_model(
       add_head_batch_norm=model_config.add_head_batch_norm,
       use_sync_bn=norm_activation_config.use_sync_bn,
       norm_momentum=norm_activation_config.norm_momentum,
-      norm_epsilon=norm_activation_config.norm_epsilon)
+      norm_epsilon=norm_activation_config.norm_epsilon,
+      skip_logits_layer=skip_logits_layer)
   return model
...
@@ -53,7 +53,7 @@ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
     self._model = model

   @abc.abstractmethod
-  def build_model(self):
+  def build_model(self, **kwargs):
     """Builds model and sets self._model."""

   @abc.abstractmethod
...
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export the image classification as a TF-Hub SavedModel."""
# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.serving import image_classification
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. resnet_imagenet')
flags.DEFINE_string(
'checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_string(
'export_path', None, 'The export directory.')
flags.DEFINE_multi_string(
'config_file',
None,
    'YAML/JSON files which specify overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` is applied '
    'first, followed by `--params_override`.')
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameters to be '
    'overridden on top of the `config_file` template.')
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_image_size',
'224,224',
    'A comma-separated string of two integers representing the height and '
    'width of the model input.')
flags.DEFINE_boolean(
'skip_logits_layer',
False,
'Whether to skip the prediction layer and only output the feature vector.')


def export_model_to_tfhub(params,
batch_size,
input_image_size,
skip_logits_layer,
checkpoint_path,
export_path):
"""Export an image classification model to TF-Hub."""
export_module = image_classification.ClassificationModule(
params=params, batch_size=batch_size, input_image_size=input_image_size)
model = export_module.build_model(skip_logits_layer=skip_logits_layer)
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.restore(checkpoint_path).assert_existing_objects_matched()
model.save(export_path, include_optimizer=False, save_format='tf')


def main(_):
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
export_model_to_tfhub(
params=params,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
skip_logits_layer=FLAGS.skip_logits_layer,
checkpoint_path=FLAGS.checkpoint_path,
export_path=FLAGS.export_path)


if __name__ == '__main__':
app.run(main)
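
For illustration only, a minimal sketch (not part of this commit) of driving the new export path programmatically rather than through absl flags, assuming the export_model_to_tfhub function above is in scope; the experiment name, checkpoint path, and export directory are placeholders:

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory

# Placeholder experiment; any registered image classification experiment works.
params = exp_factory.get_exp_config('resnet_imagenet')
params.validate()
params.lock()

# skip_logits_layer=True exports the feature-vector variant: the Dropout and
# Dense prediction head is dropped before the SavedModel is written.
export_model_to_tfhub(
    params=params,
    batch_size=None,
    input_image_size=[224, 224],
    skip_logits_layer=True,
    checkpoint_path='/tmp/resnet_ckpt/ckpt-best',  # placeholder path
    export_path='/tmp/resnet_feature_vector')  # placeholder path
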
@@ -29,14 +29,15 @@ STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)

 class ClassificationModule(export_base.ExportModule):
   """classification Module."""

-  def build_model(self):
+  def build_model(self, skip_logits_layer=False):
     input_specs = tf.keras.layers.InputSpec(
         shape=[self._batch_size] + self._input_image_size + [3])

     self._model = factory.build_classification_model(
         input_specs=input_specs,
         model_config=self._params.task.model,
-        l2_regularizer=None)
+        l2_regularizer=None,
+        skip_logits_layer=skip_logits_layer)

     return self._model
...
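
As a hedged, illustrative sketch (not part of this commit), the new skip_logits_layer argument can also be exercised directly through the model factory; the experiment name, input size, and the factory import path below are assumptions:

import tensorflow as tf

from official.common import registry_imports  # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.modeling import factory  # assumed import path

params = exp_factory.get_exp_config('resnet_imagenet')  # assumed experiment
input_specs = tf.keras.layers.InputSpec(shape=[None, 224, 224, 3])

# skip_logits_layer=True omits the Dropout and Dense prediction layers, so the
# model outputs the globally average-pooled backbone features instead of
# per-class logits.
feature_model = factory.build_classification_model(
    input_specs=input_specs,
    model_config=params.task.model,
    l2_regularizer=None,
    skip_logits_layer=True)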