Unverified Commit 44f6d511 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 686a287d 8bc5a1a5
......@@ -125,6 +125,8 @@ class TrainTest(parameterized.TestCase, tf.test.TestCase):
interaction=interaction,
use_orbit=use_orbit,
strategy=strategy)
default_mode = FLAGS.mode
# Training.
FLAGS.mode = 'train'
train.main('unused_args')
......@@ -134,6 +136,7 @@ class TrainTest(parameterized.TestCase, tf.test.TestCase):
# Evaluation.
FLAGS.mode = 'eval'
train.main('unused_args')
FLAGS.mode = default_mode
if __name__ == '__main__':
......
......@@ -26,3 +26,5 @@ pycocotools
seqeval
sentencepiece
sacrebleu
# Projects/vit dependencies
immutabledict
# Docs generation scripts for TensorFlow Models
The scripts here are used to generate api-reference pages for tensorflow.org.
The scripts require tensorflow_docs, which can be installed directly from
github:
```
$> pip install -U git+https://github.com/tensorflow/docs
$> python build_all_api_docs.py --output_dir=/tmp/tfm_docs
```
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.tools.build_docs."""
import os
import shutil
import tensorflow as tf
from official.utils.docs import build_all_api_docs
class BuildDocsTest(tf.test.TestCase):
def setUp(self):
super(BuildDocsTest, self).setUp()
self.workdir = self.get_temp_dir()
if os.path.exists(self.workdir):
shutil.rmtree(self.workdir)
os.makedirs(self.workdir)
def test_api_gen(self):
build_all_api_docs.gen_api_docs(
code_url_prefix="https://github.com/tensorflow/models/blob/master/tensorflow_models",
site_path="tf_modeling/api_docs/python",
output_dir=self.workdir,
project_short_name="tfm",
project_full_name="TensorFlow Modeling",
search_hints=True)
# Check that the "defined in" section is working
with open(os.path.join(self.workdir, "tfm.md")) as f:
content = f.read()
self.assertIn("__init__.py", content)
if __name__ == "__main__":
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Common library for API docs builder."""
import tensorflow as tf
from tensorflow_docs.api_generator import doc_controls
def hide_module_model_and_layer_methods():
"""Hide methods and properties defined in the base classes of Keras layers.
We hide all methods and properties of the base classes, except:
- `__init__` is always documented.
- `call` is always documented, as it can carry important information for
complex layers.
"""
module_contents = list(tf.Module.__dict__.items())
model_contents = list(tf.keras.Model.__dict__.items())
layer_contents = list(tf.keras.layers.Layer.__dict__.items())
for name, obj in module_contents + layer_contents + model_contents:
if name == '__init__':
# Always document __init__.
continue
if name == 'call':
# Always document `call`.
if hasattr(obj, doc_controls._FOR_SUBCLASS_IMPLEMENTERS): # pylint: disable=protected-access
delattr(obj, doc_controls._FOR_SUBCLASS_IMPLEMENTERS) # pylint: disable=protected-access
continue
# Otherwise, exclude from documentation.
if isinstance(obj, property):
obj = obj.fget
if isinstance(obj, (staticmethod, classmethod)):
obj = obj.__func__
try:
doc_controls.do_not_doc_in_subclasses(obj)
except AttributeError:
pass
......@@ -17,28 +17,25 @@ r"""Tool to generate api_docs for tensorflow_models/official library.
Example:
$> pip install -U git+https://github.com/tensorflow/docs
$> python build_nlp_api_docs \
--output_dir=/tmp/api_docs
$> python build_orbit_api_docs.py --output_dir=/tmp/api_docs
"""
import pathlib
from absl import app
from absl import flags
from absl import logging
import orbit
import tensorflow as tf
from tensorflow_docs.api_generator import doc_controls
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
import tensorflow_models as tfm
import build_api_docs_lib
FLAGS = flags.FLAGS
flags.DEFINE_string('output_dir', None, 'Where to write the resulting docs to.')
flags.DEFINE_string(
'code_url_prefix',
'https://github.com/tensorflow/models/blob/master/tensorflow_models/nlp',
'The url prefix for links to code.')
flags.DEFINE_string('code_url_prefix',
'https://github.com/tensorflow/models/blob/master/orbit',
'The url prefix for links to code.')
flags.DEFINE_bool('search_hints', True,
'Include metadata search hints in the generated files')
......@@ -47,51 +44,57 @@ flags.DEFINE_string('site_path', '/api_docs/python',
'Path prefix in the _toc.yaml')
PROJECT_SHORT_NAME = 'tfm.nlp'
PROJECT_FULL_NAME = 'TensorFlow Official Models - NLP Modeling Library'
PROJECT_SHORT_NAME = 'orbit'
PROJECT_FULL_NAME = 'Orbit'
def custom_filter(path, parent, children):
if len(path) == 1:
# Don't filter the contents of the top level `tfm.vision` package.
return children
else:
return public_api.explicit_package_contents_filter(path, parent, children)
def hide_module_model_and_layer_methods():
"""Hide methods and properties defined in the base classes of Keras layers.
We hide all methods and properties of the base classes, except:
- `__init__` is always documented.
- `call` is always documented, as it can carry important information for
complex layers.
"""
module_contents = list(tf.Module.__dict__.items())
model_contents = list(tf.keras.Model.__dict__.items())
layer_contents = list(tf.keras.layers.Layer.__dict__.items())
def gen_api_docs(code_url_prefix, site_path, output_dir, project_short_name,
project_full_name, search_hints):
"""Generates api docs for the tensorflow docs package."""
build_api_docs_lib.hide_module_model_and_layer_methods()
del tfm.nlp.layers.MultiHeadAttention
del tfm.nlp.layers.EinsumDense
for name, obj in module_contents + layer_contents + model_contents:
if name == '__init__':
# Always document __init__.
continue
if name == 'call':
# Always document `call`.
if hasattr(obj, doc_controls._FOR_SUBCLASS_IMPLEMENTERS): # pylint: disable=protected-access
delattr(obj, doc_controls._FOR_SUBCLASS_IMPLEMENTERS) # pylint: disable=protected-access
continue
# Otherwise, exclude from documentation.
if isinstance(obj, property):
obj = obj.fget
url_parts = code_url_prefix.strip('/').split('/')
url_parts = url_parts[:url_parts.index('tensorflow_models')]
url_parts.append('official')
if isinstance(obj, (staticmethod, classmethod)):
obj = obj.__func__
official_url_prefix = '/'.join(url_parts)
try:
doc_controls.do_not_doc_in_subclasses(obj)
except AttributeError:
pass
nlp_base_dir = pathlib.Path(tfm.nlp.__file__).parent
# The `layers` submodule (and others) are actually defined in the `official`
# package. Find the path to `official`.
official_base_dir = [
p for p in pathlib.Path(tfm.vision.layers.__file__).parents
if p.name == 'official'
][0]
def gen_api_docs(code_url_prefix, site_path, output_dir, project_short_name,
project_full_name, search_hints):
"""Generates api docs for the tensorflow docs package."""
doc_generator = generate_lib.DocGenerator(
root_title=project_full_name,
py_modules=[(project_short_name, tfm.nlp)],
base_dir=[nlp_base_dir, official_base_dir],
code_url_prefix=[
code_url_prefix,
official_url_prefix,
],
py_modules=[(project_short_name, orbit)],
code_url_prefix=code_url_prefix,
search_hints=search_hints,
site_path=site_path,
callbacks=[custom_filter],
callbacks=[public_api.explicit_package_contents_filter],
)
doc_generator.build(output_dir)
......
......@@ -17,8 +17,7 @@ r"""Tool to generate api_docs for tensorflow_models/official library.
Example:
$> pip install -U git+https://github.com/tensorflow/docs
$> python build_nlp_api_docs \
--output_dir=/tmp/api_docs
$> python build_nlp_api_docs.py --output_dir=/tmp/api_docs
"""
import pathlib
......@@ -26,11 +25,13 @@ import pathlib
from absl import app
from absl import flags
from absl import logging
import tensorflow as tf
from tensorflow_docs.api_generator import doc_controls
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
import tensorflow_models as tfm
from official.utils.docs import build_api_docs_lib
FLAGS = flags.FLAGS
......@@ -48,7 +49,43 @@ flags.DEFINE_string('site_path', '/api_docs/python',
PROJECT_SHORT_NAME = 'tfm'
PROJECT_FULL_NAME = 'TensorFlow Official Models - Modeling Library'
PROJECT_FULL_NAME = 'TensorFlow Modeling Library'
def hide_module_model_and_layer_methods():
"""Hide methods and properties defined in the base classes of Keras layers.
We hide all methods and properties of the base classes, except:
- `__init__` is always documented.
- `call` is always documented, as it can carry important information for
complex layers.
"""
module_contents = list(tf.Module.__dict__.items())
model_contents = list(tf.keras.Model.__dict__.items())
layer_contents = list(tf.keras.layers.Layer.__dict__.items())
for name, obj in module_contents + layer_contents + model_contents:
if name == '__init__':
# Always document __init__.
continue
if name == 'call':
# Always document `call`.
if hasattr(obj, doc_controls._FOR_SUBCLASS_IMPLEMENTERS): # pylint: disable=protected-access
delattr(obj, doc_controls._FOR_SUBCLASS_IMPLEMENTERS) # pylint: disable=protected-access
continue
# Otherwise, exclude from documentation.
if isinstance(obj, property):
obj = obj.fget
if isinstance(obj, (staticmethod, classmethod)):
obj = obj.__func__
try:
doc_controls.do_not_doc_in_subclasses(obj)
except AttributeError:
pass
def custom_filter(path, parent, children):
......@@ -62,7 +99,7 @@ def custom_filter(path, parent, children):
def gen_api_docs(code_url_prefix, site_path, output_dir, project_short_name,
project_full_name, search_hints):
"""Generates api docs for the tensorflow docs package."""
build_api_docs_lib.hide_module_model_and_layer_methods()
hide_module_model_and_layer_methods()
del tfm.nlp.layers.MultiHeadAttention
del tfm.nlp.layers.EinsumDense
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Tool to generate api_docs for tensorflow_models/official library.
Example:
$> pip install -U git+https://github.com/tensorflow/docs
$> python build_vision_api_docs \
--output_dir=/tmp/api_docs
"""
import pathlib
from absl import app
from absl import flags
from absl import logging
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import public_api
import tensorflow_models as tfm
import build_api_docs_lib
FLAGS = flags.FLAGS
flags.DEFINE_string('output_dir', None, 'Where to write the resulting docs to.')
flags.DEFINE_string(
'code_url_prefix',
'https://github.com/tensorflow/models/blob/master/tensorflow_models/vision',
'The url prefix for links to code.')
flags.DEFINE_bool('search_hints', True,
'Include metadata search hints in the generated files')
flags.DEFINE_string('site_path', 'tfvision/api_docs/python',
'Path prefix in the _toc.yaml')
PROJECT_SHORT_NAME = 'tfm.vision'
PROJECT_FULL_NAME = 'TensorFlow Official Models - Vision Modeling Library'
def custom_filter(path, parent, children):
if len(path) == 1:
# Don't filter the contents of the top level `tfm.vision` package.
return children
else:
return public_api.explicit_package_contents_filter(path, parent, children)
def gen_api_docs(code_url_prefix, site_path, output_dir, project_short_name,
project_full_name, search_hints):
"""Generates api docs for the tensorflow docs package."""
build_api_docs_lib.hide_module_model_and_layer_methods()
url_parts = code_url_prefix.strip('/').split('/')
url_parts = url_parts[:url_parts.index('tensorflow_models')]
url_parts.append('official')
official_url_prefix = '/'.join(url_parts)
vision_base_dir = pathlib.Path(tfm.vision.__file__).parent
# The `layers` submodule (and others) are actually defined in the `official`
# package. Find the path to `official`.
official_base_dir = [
p for p in pathlib.Path(tfm.vision.layers.__file__).parents
if p.name == 'official'
][0]
doc_generator = generate_lib.DocGenerator(
root_title=project_full_name,
py_modules=[(project_short_name, tfm.vision)],
base_dir=[
vision_base_dir,
official_base_dir,
],
code_url_prefix=[code_url_prefix, official_url_prefix],
search_hints=search_hints,
site_path=site_path,
callbacks=[custom_filter],
)
doc_generator.build(output_dir)
logging.info('Output docs to: %s', output_dir)
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
gen_api_docs(
code_url_prefix=FLAGS.code_url_prefix,
site_path=FLAGS.site_path,
output_dir=FLAGS.output_dir,
project_short_name=PROJECT_SHORT_NAME,
project_full_name=PROJECT_FULL_NAME,
search_hints=FLAGS.search_hints)
if __name__ == '__main__':
flags.mark_flag_as_required('output_dir')
app.run(main)
......@@ -53,7 +53,7 @@ depth, label smoothing and dropout.
| ResNet-RS-200 | 256x256 | 93.4 | 83.5 | 96.6 | [config](https://github.com/tensorflow/models/blob/master/official/vision/configs/experiments/image_classification/imagenet_resnetrs200_i256.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-200-i256.tar.gz) |
| ResNet-RS-270 | 256x256 | 130.1 | 83.6 | 96.6 | [config](https://github.com/tensorflow/models/blob/master/official/vision/configs/experiments/image_classification/imagenet_resnetrs270_i256.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-270-i256.tar.gz) |
| ResNet-RS-350 | 256x256 | 164.3 | 83.7 | 96.7 | [config](https://github.com/tensorflow/models/blob/master/official/vision/configs/experiments/image_classification/imagenet_resnetrs350_i256.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-350-i256.tar.gz) |
| ResNet-RS-350 | 320x320 | 164.3 | 84.2 | 96.9 | [config](https://github.com/tensorflow/models/blob/master/official/vision/configs/experiments/image_classification/imagenet_resnetrs420_i256.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-350-i320.tar.gz) |
| ResNet-RS-350 | 320x320 | 164.3 | 84.2 | 96.9 | [config](https://github.com/tensorflow/models/blob/master/official/vision/configs/experiments/image_classification/imagenet_resnetrs350_i320.yaml) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/resnet-rs/resnet-rs-350-i320.tar.gz) |
#### Vision Transformer (ViT)
......@@ -100,7 +100,7 @@ evaluated on [COCO](https://cocodataset.org/) val2017.
| Backbone | Resolution | Epochs | FLOPs (B) | Params (M) | Box AP | Download |
| ------------ |:-------------:| -------:|--------------:|-----------:|-------:|---------:|
| R50-FPN | 640x640 | 12 | 97.0 | 34.0 | 34.3 | config|
| R50-FPN | 640x640 | 72 | 97.0 | 34.0 | 36.8 | config \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/retinanet/retinanet-resnet50fpn.tar.gz) |
| R50-FPN | 640x640 | 72 | 97.0 | 34.0 | 36.8 | [config](https://github.com/tensorflow/models/blob/master/official/vision/configs/retinanet.py#L187-L258) \| [ckpt](https://storage.cloud.google.com/tf_model_garden/vision/retinanet/retinanet-resnet50fpn.tar.gz) |
#### RetinaNet (Trained from scratch) with training features including:
......
......@@ -13,7 +13,6 @@
# limitations under the License.
"""Vision package definition."""
# Lint as: python3
# pylint: disable=unused-import
from official.vision import configs
from official.vision import tasks
......@@ -17,12 +17,12 @@
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.configs import common
from official.vision.beta.projects.centernet.configs import backbones
from official.vision.beta.projects.centernet.modeling import centernet_model
from official.vision.beta.projects.centernet.modeling.backbones import hourglass
from official.vision.beta.projects.centernet.modeling.heads import centernet_head
from official.vision.beta.projects.centernet.modeling.layers import detection_generator
from official.vision.configs import common
class CenterNetTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -32,7 +32,7 @@ from official.vision.beta.projects.centernet.ops import loss_ops
from official.vision.beta.projects.centernet.ops import target_assigner
from official.vision.dataloaders import tf_example_decoder
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders.google import tf_example_label_map_decoder
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.modeling.backbones import factory
......
......@@ -59,6 +59,8 @@ class TfExampleDecoder(common.TfExampleDecoder):
"""A simple TF Example decoder config."""
# Setting this to true will enable decoding category_mask and instance_mask.
include_panoptic_masks: bool = True
panoptic_category_mask_key: str = 'image/panoptic/category_mask'
panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'
@dataclasses.dataclass
......
......@@ -24,23 +24,30 @@ from official.vision.ops import preprocess_ops
class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
"""Tensorflow Example proto decoder."""
def __init__(self, regenerate_source_id,
mask_binarize_threshold, include_panoptic_masks):
def __init__(
self,
regenerate_source_id: bool,
mask_binarize_threshold: float,
include_panoptic_masks: bool,
panoptic_category_mask_key: str = 'image/panoptic/category_mask',
panoptic_instance_mask_key: str = 'image/panoptic/instance_mask'):
super(TfExampleDecoder, self).__init__(
include_mask=True,
regenerate_source_id=regenerate_source_id,
mask_binarize_threshold=None)
self._include_panoptic_masks = include_panoptic_masks
self._panoptic_category_mask_key = panoptic_category_mask_key
self._panoptic_instance_mask_key = panoptic_instance_mask_key
keys_to_features = {
'image/segmentation/class/encoded':
tf.io.FixedLenFeature((), tf.string, default_value='')}
if include_panoptic_masks:
keys_to_features.update({
'image/panoptic/category_mask':
panoptic_category_mask_key:
tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/panoptic/instance_mask':
panoptic_instance_mask_key:
tf.io.FixedLenFeature((), tf.string, default_value='')})
self._segmentation_keys_to_features = keys_to_features
......@@ -56,10 +63,10 @@ class TfExampleDecoder(tf_example_decoder.TfExampleDecoder):
if self._include_panoptic_masks:
category_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/category_mask'],
parsed_tensors[self._panoptic_category_mask_key],
channels=1)
instance_mask = tf.io.decode_image(
parsed_tensors['image/panoptic/instance_mask'],
parsed_tensors[self._panoptic_instance_mask_key],
channels=1)
category_mask.set_shape([None, None, 1])
instance_mask.set_shape([None, None, 1])
......
......@@ -123,7 +123,9 @@ class PanopticMaskRCNNTask(maskrcnn.MaskRCNNTask):
decoder = panoptic_maskrcnn_input.TfExampleDecoder(
regenerate_source_id=decoder_cfg.regenerate_source_id,
mask_binarize_threshold=decoder_cfg.mask_binarize_threshold,
include_panoptic_masks=decoder_cfg.include_panoptic_masks)
include_panoptic_masks=decoder_cfg.include_panoptic_masks,
panoptic_category_mask_key=decoder_cfg.panoptic_category_mask_key,
panoptic_instance_mask_key=decoder_cfg.panoptic_instance_mask_key)
else:
raise ValueError('Unknown decoder type: {}!'.format(params.decoder.type))
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for YOLO export modules."""
from typing import List, Optional
import tensorflow as tf
from official.core import config_definitions as cfg
from official.vision import configs
from official.vision.beta.projects.yolo.configs.yolo import YoloTask
from official.vision.beta.projects.yolo.modeling import factory as yolo_factory
from official.vision.beta.projects.yolo.modeling.backbones import darknet # pylint: disable=unused-import
from official.vision.beta.projects.yolo.modeling.decoders import yolo_decoder # pylint: disable=unused-import
from official.vision.beta.projects.yolo.serving import model_fn as yolo_model_fn
from official.vision.dataloaders import classification_input
from official.vision.modeling import factory
from official.vision.serving import export_base_v2 as export_base
from official.vision.serving import export_utils
def create_classification_export_module(
params: cfg.ExperimentConfig,
input_type: str,
batch_size: int,
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
"""Creates classification export module."""
input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels)
input_specs = tf.keras.layers.InputSpec(shape=[batch_size] +
input_image_size + [num_channels])
model = factory.build_classification_model(
input_specs=input_specs,
model_config=params.task.model,
l2_regularizer=None)
def preprocess_fn(inputs):
image_tensor = export_utils.parse_image(inputs, input_type,
input_image_size, num_channels)
# If input_type is `tflite`, do not apply image preprocessing.
if input_type == 'tflite':
return image_tensor
def preprocess_image_fn(inputs):
return classification_input.Parser.inference_fn(inputs, input_image_size,
num_channels)
images = tf.map_fn(
preprocess_image_fn,
elems=image_tensor,
fn_output_signature=tf.TensorSpec(
shape=input_image_size + [num_channels], dtype=tf.float32))
return images
def postprocess_fn(logits):
probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs}
export_module = export_base.ExportModule(
params,
model=model,
input_signature=input_signature,
preprocessor=preprocess_fn,
postprocessor=postprocess_fn)
return export_module
def create_yolo_export_module(
params: cfg.ExperimentConfig,
input_type: str,
batch_size: int,
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
"""Creates YOLO export module."""
input_signature = export_utils.get_image_input_signatures(
input_type, batch_size, input_image_size, num_channels)
input_specs = tf.keras.layers.InputSpec(shape=[batch_size] +
input_image_size + [num_channels])
model, _ = yolo_factory.build_yolo(
input_specs=input_specs,
model_config=params.task.model,
l2_regularization=None)
def preprocess_fn(inputs):
image_tensor = export_utils.parse_image(inputs, input_type,
input_image_size, num_channels)
# If input_type is `tflite`, do not apply image preprocessing.
if input_type == 'tflite':
return image_tensor
def preprocess_image_fn(inputs):
image = tf.cast(inputs, dtype=tf.float32)
image = image / 255.
(image, image_info) = yolo_model_fn.letterbox(
image,
input_image_size,
letter_box=params.task.validation_data.parser.letter_box)
return image, image_info
images_spec = tf.TensorSpec(shape=input_image_size + [3], dtype=tf.float32)
image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)
images, image_info = tf.nest.map_structure(
tf.identity,
tf.map_fn(
preprocess_image_fn,
elems=image_tensor,
fn_output_signature=(images_spec, image_info_spec),
parallel_iterations=32))
return images, image_info
def inference_steps(inputs, model):
images, image_info = inputs
detection = model(images, training=False)
detection['bbox'] = yolo_model_fn.undo_info(
detection['bbox'],
detection['num_detections'],
image_info,
expand=False)
final_outputs = {
'detection_boxes': detection['bbox'],
'detection_scores': detection['confidence'],
'detection_classes': detection['classes'],
'num_detections': detection['num_detections']
}
return final_outputs
export_module = export_base.ExportModule(
params,
model=model,
input_signature=input_signature,
preprocessor=preprocess_fn,
inference_step=inference_steps)
return export_module
def get_export_module(params: cfg.ExperimentConfig,
input_type: str,
batch_size: Optional[int],
input_image_size: List[int],
num_channels: int = 3) -> export_base.ExportModule:
"""Factory for export modules."""
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = create_classification_export_module(params, input_type,
batch_size,
input_image_size,
num_channels)
elif isinstance(params.task, YoloTask):
export_module = create_yolo_export_module(params, input_type, batch_size,
input_image_size, num_channels)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
return export_module
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""YOLO model export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
CONFIG_FILE_PATH = XX
export_saved_model --export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--config_file=${CONFIG_FILE_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.projects.yolo.configs import yolo as cfg # pylint: disable=unused-import
from official.vision.beta.projects.yolo.serving import export_module_factory
from official.vision.beta.projects.yolo.tasks import yolo as task # pylint: disable=unused-import
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS
flags.DEFINE_string('experiment', 'scaled_yolo',
'experiment type, e.g. scaled_yolo')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.')
flags.DEFINE_integer('batch_size', 1, 'The batch size.')
flags.DEFINE_string('input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example`.')
flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
def main(_):
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
input_image_size = [int(x) for x in FLAGS.input_image_size.split(',')]
export_module = export_module_factory.get_export_module(
params=params,
input_type=FLAGS.input_type,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
num_channels=3)
export_saved_model_lib.export_inference_graph(
input_type=FLAGS.input_type,
batch_size=FLAGS.batch_size,
input_image_size=input_image_size,
params=params,
checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir,
export_module=export_module)
if __name__ == '__main__':
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""YOLO input and model functions for serving/inference."""
from typing import List, Tuple
import tensorflow as tf
from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.ops import box_ops
def letterbox(image: tf.Tensor,
desired_size: List[int],
letter_box: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
"""Letter box an image for image serving."""
with tf.name_scope('letter_box'):
image_size = tf.cast(preprocessing_ops.get_image_shape(image), tf.float32)
scaled_size = tf.cast(desired_size, image_size.dtype)
if letter_box:
scale = tf.minimum(scaled_size[0] / image_size[0],
scaled_size[1] / image_size[1])
scaled_size = tf.round(image_size * scale)
else:
scale = 1.0
# Computes 2D image_scale.
image_scale = scaled_size / image_size
image_offset = tf.cast((desired_size - scaled_size) * 0.5, tf.int32)
offset = (scaled_size - desired_size) * 0.5
scaled_image = tf.image.resize(
image, tf.cast(scaled_size, tf.int32), method='nearest')
output_image = tf.image.pad_to_bounding_box(scaled_image, image_offset[0],
image_offset[1],
desired_size[0],
desired_size[1])
image_info = tf.stack([
image_size,
tf.cast(desired_size, dtype=tf.float32), image_scale,
tf.cast(offset, tf.float32)
])
return output_image, image_info
def undo_info(boxes: tf.Tensor,
num_detections: int,
info: tf.Tensor,
expand: bool = True) -> tf.Tensor:
"""Clip and normalize boxes for serving."""
mask = tf.sequence_mask(num_detections, maxlen=tf.shape(boxes)[1])
boxes = tf.cast(tf.expand_dims(mask, axis=-1), boxes.dtype) * boxes
if expand:
info = tf.cast(tf.expand_dims(info, axis=0), boxes.dtype)
inshape = tf.expand_dims(info[:, 1, :], axis=1)
ogshape = tf.expand_dims(info[:, 0, :], axis=1)
scale = tf.expand_dims(info[:, 2, :], axis=1)
offset = tf.expand_dims(info[:, 3, :], axis=1)
boxes = box_ops.denormalize_boxes(boxes, inshape)
boxes += tf.tile(offset, [1, 1, 2])
boxes /= tf.tile(scale, [1, 1, 2])
boxes = box_ops.clip_boxes(boxes, ogshape)
boxes = box_ops.normalize_boxes(boxes, ogshape)
return boxes
......@@ -36,7 +36,7 @@ from official.vision.beta.projects.yolo.ops import mosaic
from official.vision.beta.projects.yolo.ops import preprocessing_ops
from official.vision.beta.projects.yolo.tasks import task_utils
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders.google import tf_example_label_map_decoder
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.ops import box_ops
......
......@@ -19,12 +19,16 @@ task:
is_training: true
global_batch_size: 2048
dtype: 'float16'
# Autotuning the prefetch buffer size causes OOMs, so set it to a reasonable
# static value: 32. See b/218880025.
prefetch_buffer_size: 32
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
global_batch_size: 2048
dtype: 'float16'
drop_remainder: false
prefetch_buffer_size: 32
trainer:
train_steps: 56160
validation_steps: 25
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment