Commit 9a052f52 authored by Fan Yang, committed by A. Unique TensorFlower

Refactor decoder factory to allow registering other decoders.

PiperOrigin-RevId: 383944185
parent c6afac2c
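To illustrate the registration flow this change enables, here is a minimal sketch of a custom decoder module. `my_decoder.py`, `MyDecoder`, and `build_my_decoder` are hypothetical names, not part of this commit; the builder signature mirrors the builders added in this diff.

```
# my_decoder.py (hypothetical example, not part of this change).
from typing import Mapping, Optional

import tensorflow as tf

from official.modeling import hyperparams
from official.vision.beta.projects.volumetric_models.modeling.decoders import factory


class MyDecoder(tf.keras.layers.Layer):
  """A hypothetical decoder used only to illustrate registration."""

  def call(self, inputs):
    return inputs


@factory.register_decoder_builder('my_decoder')
def build_my_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
    **kwargs) -> tf.keras.layers.Layer:
  """Builds MyDecoder; config and regularizer are unused in this sketch."""
  del input_specs, model_config, l2_regularizer, kwargs
  return MyDecoder()
```

Importing this module (for example from `modeling/decoders/__init__.py`) runs the decorator, after which `factory.build_decoder` can construct it whenever `model_config.decoder.type == 'my_decoder'` (a matching decoder config type would also have to exist in the configs).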
@@ -12,49 +12,105 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder registry and factory method.
One can register a new decoder by the following two steps:

1. Import the factory and register the builder in the decoder file.
2. Import the decoder class and add the import to decoders/__init__.py.

```
# my_decoder.py

from modeling.decoders import factory


class MyDecoder():
  ...


@factory.register_decoder_builder('my_decoder')
def build_my_decoder():
  return MyDecoder()


# decoders/__init__.py adds the import:
from modeling.decoders.my_decoder import MyDecoder
```

If the MyDecoder class should only be used by a specific binary, do not import
the decoder module in decoders/__init__.py; instead, import it in the place
that uses it.
"""
from typing import Union, Mapping, Optional

# Import libraries
import tensorflow as tf

from official.core import registry
from official.modeling import hyperparams

_REGISTERED_DECODER_CLS = {}


def register_decoder_builder(key: str):
  """Decorates a builder of a decoder class.

  The builder should be a callable (a class or a function).
  This decorator supports registration of a decoder builder as follows:

  ```
  class MyDecoder(tf.keras.Model):
    pass

  @register_decoder_builder('mydecoder')
  def builder(input_specs, config, l2_reg):
    return MyDecoder(...)

  # Builds a MyDecoder object.
  my_decoder = build_decoder(input_specs, config, l2_reg)
  ```

  Args:
    key: A `str` of the key to look up the builder.

  Returns:
    A callable for use as a decorator that registers the decorated builder
    under `key`.
  """
  return registry.register(_REGISTERED_DECODER_CLS, key)


@register_decoder_builder('identity')
def build_identity(
    input_specs: Optional[Mapping[str, tf.TensorShape]] = None,
    model_config: Optional[hyperparams.Config] = None,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None) -> None:
  """Builds an identity decoder, i.e. no decoder at all."""
  del input_specs, model_config, l2_regularizer  # Unused by identity decoder.
  return None


def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None,
    **kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]:
  """Builds decoder from a config.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A `OneOfConfig` of model config.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Defaults to
      None.
    **kwargs: Additional keyword args to be passed to the decoder builder.

  Returns:
    An instance of the decoder.
  """
  decoder_builder = registry.lookup(_REGISTERED_DECODER_CLS,
                                    model_config.decoder.type)

  return decoder_builder(
      input_specs=input_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer,
      **kwargs)
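For context, `official.core.registry` used above essentially maintains a key-to-builder mapping. The following is a simplified stand-in with the same call shape, not the actual library implementation:

```
# Simplified stand-in for the registry pattern used by the factory
# (not the real official.core.registry implementation).
def register(registered_collection, reg_key):
  """Returns a decorator that stores the decorated callable under reg_key."""

  def decorator(fn_or_cls):
    if reg_key in registered_collection:
      raise KeyError(f'{reg_key} is already registered.')
    registered_collection[reg_key] = fn_or_cls
    return fn_or_cls

  return decorator


def lookup(registered_collection, reg_key):
  """Returns the callable registered under reg_key, or raises LookupError."""
  if reg_key not in registered_collection:
    raise LookupError(f'{reg_key} is not registered.')
  return registered_collection[reg_key]
```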
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for factory functions."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.vision.beta.projects.volumetric_models.configs import decoders as decoders_cfg
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_exp
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.modeling.decoders import factory
class FactoryTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(model_id=[2, 3],))
  def test_unet_3d_decoder_creation(self, model_id):
    """Test creation of UNet 3D decoder."""
    # Create test input for decoders based on input model_id.
    input_specs = {}
    for level in range(model_id):
      input_specs[str(level + 1)] = tf.TensorShape(
          [1, 128 // (2**level), 128 // (2**level), 128 // (2**level), 1])

    network = decoders.UNet3DDecoder(
        model_id=model_id,
        input_specs=input_specs,
        use_sync_bn=True,
        use_batch_normalization=True,
        use_deconvolution=True)

    model_config = semantic_segmentation_3d_exp.SemanticSegmentationModel3D()
    model_config.num_classes = 2
    model_config.num_channels = 1
    model_config.input_size = [None, None, None]
    model_config.decoder = decoders_cfg.Decoder(
        type='unet_3d_decoder',
        unet_3d_decoder=decoders_cfg.UNet3DDecoder(model_id=model_id))
    factory_network = factory.build_decoder(
        input_specs=input_specs, model_config=model_config)

    network_config = network.get_config()
    factory_network_config = factory_network.get_config()
    print(network_config)
    print(factory_network_config)

    self.assertEqual(network_config, factory_network_config)

  def test_identity_creation(self):
    """Test creation of identity decoder."""
    model_config = semantic_segmentation_3d_exp.SemanticSegmentationModel3D()
    model_config.num_classes = 2
    model_config.num_channels = 3
    model_config.input_size = [None, None, None]
    model_config.decoder = decoders_cfg.Decoder(
        type='identity', identity=decoders_cfg.Identity())
    factory_network = factory.build_decoder(
        input_specs=None, model_config=model_config)

    self.assertIsNone(factory_network)


if __name__ == '__main__':
  tf.test.main()
@@ -19,10 +19,13 @@ Ronneberger. 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
Annotation. arXiv:1606.06650.
"""
from typing import Any, Dict, Mapping, Optional, Sequence

import tensorflow as tf

from official.modeling import hyperparams
from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d
from official.vision.beta.projects.volumetric_models.modeling.decoders import factory

layers = tf.keras.layers
@@ -152,3 +155,39 @@ class UNet3DDecoder(tf.keras.Model):
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_decoder_builder('unet_3d_decoder')
def build_unet_3d_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
  """Builds UNet3D decoder from a config.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A `OneOfConfig` of model config.
    l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Defaults
      to None.

  Returns:
    A `tf.keras.Model` instance of the UNet3D decoder.
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  assert decoder_type == 'unet_3d_decoder', (f'Inconsistent decoder type '
                                             f'{decoder_type}')
  norm_activation_config = model_config.norm_activation
  return UNet3DDecoder(
      model_id=decoder_cfg.model_id,
      input_specs=input_specs,
      pool_size=decoder_cfg.pool_size,
      kernel_regularizer=l2_regularizer,
      activation=norm_activation_config.activation,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_batch_normalization=decoder_cfg.use_batch_normalization,
      use_deconvolution=decoder_cfg.use_deconvolution)
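With this builder registered, construction goes through the factory instead of a hard-coded if/elif chain. A minimal usage sketch mirroring the factory test earlier in this change (shapes and config values are illustrative):

```
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.configs import decoders as decoders_cfg
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
# Importing the decoders package registers the builders as a side effect.
from official.vision.beta.projects.volumetric_models.modeling import decoders  # pylint: disable=unused-import
from official.vision.beta.projects.volumetric_models.modeling.decoders import factory

# Backbone output specs for a 2-level UNet (illustrative shapes).
input_specs = {
    '1': tf.TensorShape([1, 128, 128, 128, 1]),
    '2': tf.TensorShape([1, 64, 64, 64, 1]),
}
model_config = exp_cfg.SemanticSegmentationModel3D()
model_config.decoder = decoders_cfg.Decoder(
    type='unet_3d_decoder',
    unet_3d_decoder=decoders_cfg.UNet3DDecoder(model_id=2))
# Dispatches to build_unet_3d_decoder through the registry.
decoder = factory.build_decoder(
    input_specs=input_specs, model_config=model_config)
```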
@@ -17,9 +17,11 @@
from absl.testing import parameterized
import tensorflow as tf

# pylint: disable=unused-import
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.modeling import factory


class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -17,4 +17,5 @@
# pylint: disable=unused-import
from official.common import registry_imports
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d
@@ -18,8 +18,10 @@ from typing import Mapping
import tensorflow as tf

# pylint: disable=unused-import
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.modeling import factory
from official.vision.beta.serving import export_base
...
@@ -20,9 +20,11 @@ from absl.testing import parameterized
import numpy as np
import tensorflow as tf

# pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
...
@@ -28,7 +28,8 @@ from official.core import exp_factory
from official.modeling import optimization
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task
...