Commit 78c43ef1 authored by Gunho Park

Merge branch 'master' of https://github.com/tensorflow/models

parents 67cfc95b e3c7e300
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory methods to build models."""
# Import libraries
import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling import segmentation_model
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.projects.volumetric_models.modeling.decoders import factory as decoder_factory
from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d
def build_segmentation_model_3d(
input_specs: tf.keras.layers.InputSpec,
model_config: hyperparams.Config,
l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
"""Builds Segmentation model."""
norm_activation_config = model_config.norm_activation
backbone = backbone_factory.build_backbone(
input_specs=input_specs,
backbone_config=model_config.backbone,
norm_activation_config=norm_activation_config,
l2_regularizer=l2_regularizer)
decoder = decoder_factory.build_decoder(
input_specs=backbone.output_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
head_config = model_config.head
head = segmentation_heads_3d.SegmentationHead3D(
num_classes=model_config.num_classes,
level=head_config.level,
num_convs=head_config.num_convs,
num_filters=head_config.num_filters,
upsample_factor=head_config.upsample_factor,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
use_batch_normalization=head_config.use_batch_normalization,
kernel_regularizer=l2_regularizer,
output_logits=head_config.output_logits)
model = segmentation_model.SegmentationModel(backbone, decoder, head)
return model
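

# A minimal usage sketch (not part of the library): it mirrors the factory
# test below, assuming the 3D backbones and decoders have been registered via
# their module imports and using the default UNet-3D settings of the
# `SemanticSegmentationModel3D` config.
if __name__ == '__main__':
  # pylint: disable=g-import-not-at-top,unused-import
  from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
  from official.vision.beta.projects.volumetric_models.modeling import backbones
  from official.vision.beta.projects.volumetric_models.modeling import decoders

  # Input layout is [batch, height, width, depth, channels].
  demo_specs = tf.keras.layers.InputSpec(shape=[None, 64, 64, 64, 3])
  demo_config = exp_cfg.SemanticSegmentationModel3D(num_classes=2)
  demo_model = build_segmentation_model_3d(
      input_specs=demo_specs,
      model_config=demo_config,
      l2_regularizer=tf.keras.regularizers.l2(1e-4))
  print(type(demo_model))  # A tf.keras.Model instance.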
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for factory.py."""
from absl.testing import parameterized
import tensorflow as tf
# pylint: disable=unused-import
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.modeling import factory
class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(((128, 128, 128), 5e-5, True),
((64, 64, 64), None, False))
def test_unet3d_builder(self, input_size, weight_decay, use_bn):
num_classes = 3
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], input_size[2], 3])
model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=num_classes)
model_config.head.use_batch_normalization = use_bn
l2_regularizer = (
tf.keras.regularizers.l2(weight_decay) if weight_decay else None)
model = factory.build_segmentation_model_3d(
input_specs=input_specs,
model_config=model_config,
l2_regularizer=l2_regularizer)
self.assertIsInstance(
model, tf.keras.Model,
'Output should be a tf.keras.Model instance but got %s' % type(model))
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Segmentation heads."""
from typing import Any, Union, Sequence, Mapping
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class SegmentationHead3D(tf.keras.layers.Layer):
"""Segmentation head for 3D input."""
def __init__(self,
num_classes: int,
level: Union[int, str],
num_convs: int = 2,
num_filters: int = 256,
upsample_factor: int = 1,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
use_batch_normalization: bool = False,
kernel_regularizer: tf.keras.regularizers.Regularizer = None,
bias_regularizer: tf.keras.regularizers.Regularizer = None,
output_logits: bool = True,
**kwargs):
"""Initialize params to build segmentation head.
Args:
      num_classes: `int` number of mask classification categories. The number
        of classes does not include the background class.
level: `int` or `str`, level to use to build segmentation head.
num_convs: `int` number of stacked convolution before the last prediction
layer.
num_filters: `int` number to specify the number of filters used. Default
is 256.
upsample_factor: `int` number to specify the upsampling factor to generate
finer mask. Default 1 means no upsampling is applied.
activation: `string`, indicating which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: `bool`, whether to use synchronized batch normalization
across different replicas.
norm_momentum: `float`, the momentum parameter of the normalization
layers.
norm_epsilon: `float`, the epsilon parameter of the normalization layers.
use_batch_normalization: A bool of whether to use batch normalization or
not.
kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer
kernel.
bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias.
      output_logits: A `bool` of whether to output logits. Defaults to True.
        If set to False, outputs softmax probabilities instead.
**kwargs: other keyword arguments passed to Layer.
"""
super(SegmentationHead3D, self).__init__(**kwargs)
self._config_dict = {
'num_classes': num_classes,
'level': level,
'num_convs': num_convs,
'num_filters': num_filters,
'upsample_factor': upsample_factor,
'activation': activation,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'use_batch_normalization': use_batch_normalization,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'output_logits': output_logits
}
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)
def build(self, input_shape: Union[tf.TensorShape, Sequence[tf.TensorShape]]):
"""Creates the variables of the segmentation head."""
conv_op = tf.keras.layers.Conv3D
conv_kwargs = {
'kernel_size': (3, 3, 3),
'padding': 'same',
'use_bias': False,
'kernel_initializer': tf.keras.initializers.RandomNormal(stddev=0.01),
'kernel_regularizer': self._config_dict['kernel_regularizer'],
}
final_kernel_size = (1, 1, 1)
bn_op = (
tf.keras.layers.experimental.SyncBatchNormalization
if self._config_dict['use_sync_bn'] else
tf.keras.layers.BatchNormalization)
bn_kwargs = {
'axis': self._bn_axis,
'momentum': self._config_dict['norm_momentum'],
'epsilon': self._config_dict['norm_epsilon'],
}
# Segmentation head layers.
self._convs = []
self._norms = []
for i in range(self._config_dict['num_convs']):
conv_name = 'segmentation_head_conv_{}'.format(i)
self._convs.append(
conv_op(
name=conv_name,
filters=self._config_dict['num_filters'],
**conv_kwargs))
norm_name = 'segmentation_head_norm_{}'.format(i)
if self._config_dict['use_batch_normalization']:
self._norms.append(bn_op(name=norm_name, **bn_kwargs))
self._classifier = conv_op(
name='segmentation_output',
filters=self._config_dict['num_classes'],
kernel_size=final_kernel_size,
padding='valid',
activation=None,
bias_initializer=tf.zeros_initializer(),
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer'])
super(SegmentationHead3D, self).build(input_shape)
def call(self, backbone_output: Mapping[str, tf.Tensor],
decoder_output: Mapping[str, tf.Tensor]) -> tf.Tensor:
"""Forward pass of the segmentation head.
Args:
      backbone_output: a dict of tensors
        - key: `str`, the level of the multilevel features.
        - values: `Tensor`, the feature map tensors, whose shape is [batch,
          height_l, width_l, depth_l, channels].
      decoder_output: a dict of tensors
        - key: `str`, the level of the multilevel features.
        - values: `Tensor`, the feature map tensors, whose shape is [batch,
          height_l, width_l, depth_l, channels].
Returns:
segmentation prediction mask: `Tensor`, the segmentation mask scores
predicted from input feature.
"""
x = decoder_output[str(self._config_dict['level'])]
for i, conv in enumerate(self._convs):
x = conv(x)
if self._norms:
x = self._norms[i](x)
x = self._activation(x)
x = tf.keras.layers.UpSampling3D(size=self._config_dict['upsample_factor'])(
x)
x = self._classifier(x)
return x if self._config_dict['output_logits'] else tf.keras.layers.Softmax(
dtype='float32')(
x)
def get_config(self) -> Mapping[str, Any]:
return self._config_dict
@classmethod
def from_config(cls, config: Mapping[str, Any]):
return cls(**config)
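

# A short standalone sketch (not part of the library): the head reads the
# decoder feature map at its configured `level`, keyed by the level as a
# string, as in the head test further below.
if __name__ == '__main__':
  import numpy as np  # pylint: disable=g-import-not-at-top

  demo_head = SegmentationHead3D(num_classes=4, level=1, num_convs=1)
  demo_features = {'1': np.random.rand(2, 32, 32, 32, 8).astype('float32')}
  # Backbone features are accepted for interface compatibility but are not
  # used by this head.
  demo_logits = demo_head(demo_features, demo_features)
  print(demo_logits.shape)  # (2, 32, 32, 32, 4)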
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for segmentation_heads.py."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d
class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(1, 0, True),
(2, 1, False),
)
def test_forward(self, level, num_convs, use_bn):
head = segmentation_heads_3d.SegmentationHead3D(
num_classes=10,
level=level,
num_convs=num_convs,
use_batch_normalization=use_bn)
backbone_features = {
'1': np.random.rand(2, 128, 128, 128, 16),
'2': np.random.rand(2, 64, 64, 64, 16),
}
decoder_features = {
'1': np.random.rand(2, 128, 128, 128, 16),
'2': np.random.rand(2, 64, 64, 64, 16),
}
logits = head(backbone_features, decoder_features)
if str(level) in decoder_features:
self.assertAllEqual(logits.numpy().shape, [
2, decoder_features[str(level)].shape[1],
decoder_features[str(level)].shape[2],
decoder_features[str(level)].shape[3], 10
])
def test_serialize_deserialize(self):
head = segmentation_heads_3d.SegmentationHead3D(num_classes=10, level=3)
config = head.get_config()
new_head = segmentation_heads_3d.SegmentationHead3D.from_config(config)
self.assertAllEqual(head.get_config(), new_head.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for neural networks."""
from typing import Sequence, Union
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.layers import nn_layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class BasicBlock3DVolume(tf.keras.layers.Layer):
"""A basic 3d convolution block."""
def __init__(self,
filters: Union[int, Sequence[int]],
strides: Union[int, Sequence[int]],
kernel_size: Union[int, Sequence[int]],
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: tf.keras.regularizers.Regularizer = None,
bias_regularizer: tf.keras.regularizers.Regularizer = None,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
use_batch_normalization: bool = False,
**kwargs):
"""Creates a basic 3d convolution block applying one or more convolutions.
Args:
      filters: A list of `int` numbers or an `int` number of filters. Given an
        `int` input, a single convolution is applied; otherwise a series of
        convolutions is applied.
strides: An integer or tuple/list of 3 integers, specifying the strides of
the convolution along each spatial dimension. Can be a single integer to
specify the same value for all spatial dimensions.
kernel_size: An integer or tuple/list of 3 integers, specifying the depth,
height and width of the 3D convolution window. Can be a single integer
to specify the same value for all spatial dimensions.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Defaults to None.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      use_batch_normalization: Whether to use batch normalization.
**kwargs: keyword arguments to be passed.
"""
super().__init__(**kwargs)
if isinstance(filters, int):
self._filters = [filters]
else:
self._filters = filters
self._strides = strides
self._kernel_size = kernel_size
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._activation = activation
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._use_batch_normalization = use_batch_normalization
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape: tf.TensorShape):
"""Builds the basic 3d convolution block."""
self._convs = []
self._norms = []
for filters in self._filters:
self._convs.append(
tf.keras.layers.Conv3D(
filters=filters,
kernel_size=self._kernel_size,
strides=self._strides,
padding='same',
data_format=tf.keras.backend.image_data_format(),
activation=None))
self._norms.append(
self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon))
super(BasicBlock3DVolume, self).build(input_shape)
def get_config(self):
"""Returns the config of the basic 3d convolution block."""
config = {
'filters': self._filters,
'strides': self._strides,
'kernel_size': self._kernel_size,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon,
'use_batch_normalization': self._use_batch_normalization
}
base_config = super(BasicBlock3DVolume, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
"""Runs forward pass on the input tensor."""
x = inputs
for conv, norm in zip(self._convs, self._norms):
x = conv(x)
if self._use_batch_normalization:
x = norm(x)
x = self._activation_fn(x)
return x
@tf.keras.utils.register_keras_serializable(package='Vision')
class ResidualBlock3DVolume(tf.keras.layers.Layer):
"""A residual 3d block."""
def __init__(self,
filters,
strides,
use_projection=False,
se_ratio=None,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A residual 3d block with BN after convolutions.
Args:
      filters: `int` number of filters for the convolutions in this block.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Defaults to None.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super().__init__(**kwargs)
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._se_ratio = se_ratio
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv3D(
filters=self._filters,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv1 = tf.keras.layers.Conv3D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv2 = tf.keras.layers.Conv3D(
filters=self._filters,
kernel_size=3,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
in_filters=self._filters,
out_filters=self._filters,
se_ratio=self._se_ratio,
use_3d_input=True,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._squeeze_excitation = None
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = None
super(ResidualBlock3DVolume, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ResidualBlock3DVolume, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
return self._activation_fn(x + shortcut)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BottleneckBlock3DVolume(tf.keras.layers.Layer):
"""A standard bottleneck block."""
def __init__(self,
filters,
strides,
dilation_rate=1,
use_projection=False,
se_ratio=None,
stochastic_depth_drop_rate=None,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A standard bottleneck 3d block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
dilation_rate: `int` dilation_rate of convolutions. Default to 1.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Defaults to None.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super().__init__(**kwargs)
self._filters = filters
self._strides = strides
self._dilation_rate = dilation_rate
self._use_projection = use_projection
self._se_ratio = se_ratio
self._use_sync_bn = use_sync_bn
self._activation = activation
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv3D(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv1 = tf.keras.layers.Conv3D(
filters=self._filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv2 = tf.keras.layers.Conv3D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
dilation_rate=self._dilation_rate,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv3 = tf.keras.layers.Conv3D(
filters=self._filters * 4,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm3 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
self._squeeze_excitation = nn_layers.SqueezeExcitation(
in_filters=self._filters * 4,
out_filters=self._filters * 4,
se_ratio=self._se_ratio,
use_3d_input=True,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
else:
self._squeeze_excitation = None
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = None
super(BottleneckBlock3DVolume, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'dilation_rate': self._dilation_rate,
'use_projection': self._use_projection,
'se_ratio': self._se_ratio,
'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(BottleneckBlock3DVolume, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
x = self._activation_fn(x)
x = self._conv3(x)
x = self._norm3(x)
if self._squeeze_excitation:
x = self._squeeze_excitation(x)
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
return self._activation_fn(x + shortcut)
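

# A quick shape sanity check (not part of the library), mirroring the block
# tests below: with `strides=2` and `use_projection=True` every spatial
# dimension is halved, and the bottleneck block widens channels by 4x.
if __name__ == '__main__':
  basic_in = tf.keras.Input(shape=(32, 32, 32, 16), batch_size=1)
  bottleneck_in = tf.keras.Input(shape=(32, 32, 32, 64), batch_size=1)
  basic_out = BasicBlock3DVolume(filters=16, strides=1, kernel_size=3)(
      basic_in)
  residual_out = ResidualBlock3DVolume(
      filters=16, strides=2, use_projection=True)(basic_in)
  bottleneck_out = BottleneckBlock3DVolume(
      filters=16, strides=2, use_projection=True)(bottleneck_in)
  print(basic_out.shape)       # (1, 32, 32, 32, 16)
  print(residual_out.shape)    # (1, 16, 16, 16, 16)
  print(bottleneck_out.shape)  # (1, 16, 16, 16, 64)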
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for 3D volumeric convoluion blocks."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d
class NNBlocks3DTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters((128, 128, 32, 1), (256, 256, 16, 2))
def test_bottleneck_block_3d_volume_creation(self, spatial_size, volume_size,
filters, strides):
inputs = tf.keras.Input(
shape=(spatial_size, spatial_size, volume_size, filters * 4),
batch_size=1)
block = nn_blocks_3d.BottleneckBlock3DVolume(
filters=filters, strides=strides, use_projection=True)
features = block(inputs)
self.assertAllEqual([
1, spatial_size // strides, spatial_size // strides,
volume_size // strides, filters * 4
], features.shape.as_list())
@parameterized.parameters((128, 128, 32, 1), (256, 256, 64, 2))
def test_residual_block_3d_volume_creation(self, spatial_size, volume_size,
filters, strides):
inputs = tf.keras.Input(
shape=(spatial_size, spatial_size, volume_size, filters), batch_size=1)
block = nn_blocks_3d.ResidualBlock3DVolume(
filters=filters, strides=strides, use_projection=True)
features = block(inputs)
self.assertAllEqual([
1, spatial_size // strides, spatial_size // strides,
volume_size // strides, filters
], features.shape.as_list())
@parameterized.parameters((128, 128, 64, 1, 3), (256, 256, 128, 2, 1))
def test_basic_block_3d_volume_creation(self, spatial_size, volume_size,
filters, strides, kernel_size):
inputs = tf.keras.Input(
shape=(spatial_size, spatial_size, volume_size, filters), batch_size=1)
block = nn_blocks_3d.BasicBlock3DVolume(
filters=filters, strides=strides, kernel_size=kernel_size)
features = block(inputs)
self.assertAllEqual([
1, spatial_size // strides, spatial_size // strides,
volume_size // strides, filters
], features.shape.as_list())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for segmentation network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.modeling import segmentation_model
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d
class SegmentationNetworkUNet3DTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
([32, 32], 4),
([64, 64], 4),
([64, 64], 2),
([128, 64], 2),
)
def test_segmentation_network_unet3d_creation(self, input_size, depth):
"""Test for creation of a segmentation network."""
num_classes = 2
inputs = np.random.rand(2, input_size[0], input_size[0], input_size[1], 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.UNet3D(model_id=depth)
decoder = decoders.UNet3DDecoder(
model_id=depth, input_specs=backbone.output_specs)
head = segmentation_heads_3d.SegmentationHead3D(
num_classes, level=1, num_convs=0)
model = segmentation_model.SegmentationModel(
backbone=backbone, decoder=decoder, head=head)
logits = model(inputs)
self.assertAllEqual(
[2, input_size[0], input_size[0], input_size[1], num_classes],
logits.numpy().shape)
def test_serialize_deserialize(self):
"""Validate the network can be serialized and deserialized."""
num_classes = 3
backbone = backbones.UNet3D(model_id=4)
decoder = decoders.UNet3DDecoder(
model_id=4, input_specs=backbone.output_specs)
head = segmentation_heads_3d.SegmentationHead3D(
num_classes, level=1, num_convs=0)
model = segmentation_model.SegmentationModel(
backbone=backbone, decoder=decoder, head=head)
config = model.get_config()
new_model = segmentation_model.SegmentationModel.from_config(config)
    # Validate that the config can be serialized to JSON.
_ = new_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(model.get_config(), new_model.get_config())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.common import registry_imports
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as semantic_segmentation_3d_cfg
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Volumetric model export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE=XX
CHECKPOINT_PATH=XX
EXPORT_DIR_PATH=XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=1 \
--input_image_size=128,128,128 \
--num_channels=1
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
from official.vision.beta.serving import export_saved_model_lib
FLAGS = flags.FLAGS
flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. seg_unet3d_test')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden'
' on top of `config_file` template.')
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example`.')
flags.DEFINE_list(
'input_image_size', None,
'The comma-separated string of three integers representing the '
'height, width and depth of the input to the model.')
flags.DEFINE_integer('num_channels', 1,
'The number of channels of input image.')
flags.register_validator(
'input_image_size',
lambda value: value is not None and len(value) == 3,
message='--input_image_size must be comma-separated string of three '
'integers representing the height, width and depth of the input to '
'the model.')
def main(_):
flags.mark_flag_as_required('export_dir')
flags.mark_flag_as_required('checkpoint_path')
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
input_image_size = FLAGS.input_image_size
export_module = semantic_segmentation_3d.SegmentationModule(
params=params,
batch_size=1,
input_image_size=input_image_size,
num_channels=FLAGS.num_channels)
export_saved_model_lib.export_inference_graph(
input_type=FLAGS.input_type,
batch_size=FLAGS.batch_size,
input_image_size=input_image_size,
params=params,
checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir,
num_channels=FLAGS.num_channels,
export_module=export_module,
export_checkpoint_subdir='checkpoint',
export_saved_model_subdir='saved_model')
if __name__ == '__main__':
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""3D semantic segmentation input and model functions for serving/inference."""
from typing import Mapping
import tensorflow as tf
# pylint: disable=unused-import
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.modeling import factory
from official.vision.beta.serving import export_base
class SegmentationModule(export_base.ExportModule):
"""Segmentation Module."""
def _build_model(self) -> tf.keras.Model:
"""Builds and returns a segmentation model."""
num_channels = self.params.task.model.num_channels
input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [num_channels])
return factory.build_segmentation_model_3d(
input_specs=input_specs,
model_config=self.params.task.model,
l2_regularizer=None)
def serve(
self, images: tf.Tensor) -> Mapping[str, tf.Tensor]:
"""Casts an image tensor to float and runs inference.
Args:
images: A uint8 tf.Tensor of shape [batch_size, None, None, None,
num_channels].
Returns:
A dictionary holding segmentation outputs.
"""
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
outputs = self.inference_step(images)
output_key = 'logits' if self.params.task.model.head.output_logits else 'probs'
return {output_key: outputs}
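

# A minimal sketch (not part of the library) of driving the module eagerly,
# assuming the `seg_unet3d_test` experiment config used by the serving test
# below and that `serve` can be called outside a SavedModel signature.
if __name__ == '__main__':
  # pylint: disable=g-import-not-at-top,unused-import
  from official.core import exp_factory
  # Importing the configs module registers the `seg_unet3d_test` experiment.
  from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg

  demo_params = exp_factory.get_exp_config('seg_unet3d_test')
  demo_module = SegmentationModule(
      params=demo_params,
      batch_size=1,
      input_image_size=[32, 32, 32],
      num_channels=2)
  demo_images = tf.zeros([1, 32, 32, 32, 2], dtype=tf.uint8)
  demo_outputs = demo_module.serve(demo_images)
  print(list(demo_outputs.keys()))  # ['logits'] or ['probs']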
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for semantic_segmentation_3d export lib."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
# pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self._num_channels = 2
self._input_image_size = [32, 32, 32]
self._params = exp_factory.get_exp_config('seg_unet3d_test')
input_shape = self._input_image_size + [self._num_channels]
self._image_array = np.zeros(shape=input_shape, dtype=np.uint8)
def _get_segmentation_module(self):
return semantic_segmentation_3d.SegmentationModule(
self._params,
batch_size=1,
input_image_size=self._input_image_size,
num_channels=self._num_channels)
def _export_from_module(self, module, input_type: str, save_directory: str):
signatures = module.get_inference_signatures(
{input_type: 'serving_default'})
tf.saved_model.save(module,
save_directory,
signatures=signatures)
def _get_dummy_input(self, input_type):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
image_tensor = tf.convert_to_tensor(self._image_array, dtype=tf.uint8)
return tf.expand_dims(image_tensor, axis=0)
if input_type == 'image_bytes':
      return [self._image_array.tobytes()]
if input_type == 'tf_example':
      encoded_image = self._image_array.tobytes()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_image])),
})).SerializeToString()
return [example]
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
)
def test_export(self, input_type: str = 'image_tensor'):
tmp_dir = self.get_temp_dir()
module = self._get_segmentation_module()
self._export_from_module(module, input_type, tmp_dir)
# Check if model is successfully exported.
self.assertTrue(tf.io.gfile.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(
tf.io.gfile.exists(
os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(
tf.io.gfile.exists(
os.path.join(tmp_dir, 'variables',
'variables.data-00000-of-00001')))
# Get inference signature from loaded SavedModel.
imported = tf.saved_model.load(tmp_dir)
segmentation_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
image_tensor = self._get_dummy_input(input_type='image_tensor')
# Perform inference using loaded SavedModel and model instance and check if
# outputs equal.
expected_output = module.model(image_tensor, training=False)
out = segmentation_fn(tf.constant(images))
self.assertAllClose(out['logits'].numpy(), expected_output.numpy())
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image segmentation task definition."""
from typing import Any, Dict, Mapping, Optional, Sequence, Union
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.dataloaders import segmentation_input_3d
from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
from official.vision.beta.projects.volumetric_models.losses import segmentation_losses
from official.vision.beta.projects.volumetric_models.modeling import factory
@task_factory.register_task_cls(exp_cfg.SemanticSegmentation3DTask)
class SemanticSegmentation3DTask(base_task.Task):
"""A task for semantic segmentation."""
def build_model(self) -> tf.keras.Model:
"""Builds segmentation model."""
input_specs = tf.keras.layers.InputSpec(
shape=[None] + self.task_config.model.input_size +
[self.task_config.model.num_channels],
dtype=self.task_config.train_data.dtype)
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (
tf.keras.regularizers.l2(l2_weight_decay /
2.0) if l2_weight_decay else None)
model = factory.build_segmentation_model_3d(
input_specs=input_specs,
model_config=self.task_config.model,
l2_regularizer=l2_regularizer)
    # Create a dummy input and call the model once to build its variables.
    # This is needed when launching multiple experiments using the same model
    # directory: if a trained checkpoint already exists, weights may be
    # restored without a forward pass ever running, so the model would never
    # be built. This is only done when spatial partitioning is not enabled;
    # otherwise it would fail with OOM due to the extremely large input.
if (not self.task_config.train_input_partition_dims) and (
not self.task_config.eval_input_partition_dims):
dummy_input = tf.random.uniform(shape=[1] + list(input_specs.shape[1:]))
_ = model(dummy_input)
return model
def initialize(self, model: tf.keras.Model):
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
ckpt_dir_or_file = self.task_config.init_checkpoint
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
# Restoring checkpoint.
if 'all' in self.task_config.init_checkpoint_modules:
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.assert_consumed()
else:
ckpt_items = {}
if 'backbone' in self.task_config.init_checkpoint_modules:
ckpt_items.update(backbone=model.backbone)
if 'decoder' in self.task_config.init_checkpoint_modules:
ckpt_items.update(decoder=model.decoder)
ckpt = tf.train.Checkpoint(**ckpt_items)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
"""Builds classification input."""
decoder = segmentation_input_3d.Decoder(
image_field_key=params.image_field_key,
label_field_key=params.label_field_key)
parser = segmentation_input_3d.Parser(
input_size=params.input_size,
num_classes=params.num_classes,
num_channels=params.num_channels,
image_field_key=params.image_field_key,
label_field_key=params.label_field_key,
dtype=params.dtype,
label_dtype=params.label_dtype)
reader = input_reader.InputReader(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
parser_fn=parser.parse_fn(params.is_training))
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self,
labels: tf.Tensor,
model_outputs: tf.Tensor,
aux_losses=None) -> tf.Tensor:
"""Segmentation loss.
Args:
labels: labels.
model_outputs: Output logits of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
segmentation_loss_fn = segmentation_losses.SegmentationLossDiceScore(
metric_type='adaptive')
total_loss = segmentation_loss_fn(model_outputs, labels)
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
def build_metrics(self,
training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
"""Gets streaming metrics for training/validation."""
metrics = []
num_classes = self.task_config.model.num_classes
if training:
metrics.extend([
tf.keras.metrics.CategoricalAccuracy(
name='train_categorical_accuracy', dtype=tf.float32)
])
else:
self.metrics = [
segmentation_metrics.DiceScore(
num_classes=num_classes,
metric_type='generalized',
per_class_metric=self.task_config.evaluation
.report_per_class_metric,
name='val_generalized_dice',
dtype=tf.float32)
]
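      # Validation metrics are kept on the task (`self.metrics`) and updated
      # in `aggregate_logs` on CPU, outside the distributed validation step.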
return metrics
def train_step(
self,
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[Sequence[tf.keras.metrics.Metric]] = None
) -> Dict[Any, Any]:
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
input_partition_dims = self.task_config.train_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features = strategy.experimental_split_to_logical_devices(
features, input_partition_dims)
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features, training=True)
      # Casting the outputs to float32 is necessary when mixed precision
      # (mixed_float16 or mixed_bfloat16) is used, so that the loss is
      # computed in full float32 precision.
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
if self.task_config.model.head.output_logits:
outputs = tf.nn.softmax(outputs)
# Computes per-replica loss.
loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales the loss, since the default gradient allreduce performs a sum
      # across replicas inside the optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
# Compute all metrics within strategy scope for training.
if metrics:
labels = tf.cast(labels, tf.float32)
outputs = tf.cast(outputs, tf.float32)
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
return logs
def validation_step(
self,
inputs,
model: tf.keras.Model,
metrics: Optional[Sequence[tf.keras.metrics.Metric]] = None
) -> Dict[Any, Any]:
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
input_partition_dims = self.task_config.eval_input_partition_dims
if input_partition_dims:
strategy = tf.distribute.get_strategy()
features = strategy.experimental_split_to_logical_devices(
features, input_partition_dims)
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
if self.task_config.model.head.output_logits:
outputs = tf.nn.softmax(outputs)
loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
logs = {self.loss: loss}
# Compute dice score metrics on CPU.
for metric in self.metrics:
labels = tf.cast(labels, tf.float32)
outputs = tf.cast(outputs, tf.float32)
logs.update({metric.name: (labels, outputs)})
return logs
def inference_step(self, inputs, model: tf.keras.Model) -> tf.Tensor:
"""Performs the forward step."""
return model(inputs, training=False)
def aggregate_logs(
self,
state: Optional[Sequence[Union[segmentation_metrics.DiceScore,
tf.keras.metrics.Metric]]] = None,
step_outputs: Optional[Mapping[str, Any]] = None
) -> Sequence[tf.keras.metrics.Metric]:
"""Aggregates statistics to compute metrics over training.
Args:
state: A sequence of tf.keras.metrics.Metric objects. Each element records
a metric.
step_outputs: A dictionary of [metric_name, (labels, output)] from a step.
Returns:
An updated sequence of tf.keras.metrics.Metric objects.
"""
if state is None:
for metric in self.metrics:
metric.reset_states()
state = self.metrics
for metric in self.metrics:
labels = step_outputs[metric.name][0]
predictions = step_outputs[metric.name][1]
# If `step_output` is distributed, it contains a tuple of Tensors instead
# of a single Tensor, so we need to concatenate them along the batch
# dimension in this case to have a single Tensor.
if isinstance(labels, tuple):
labels = tf.concat(list(labels), axis=0)
if isinstance(predictions, tuple):
predictions = tf.concat(list(predictions), axis=0)
labels = tf.cast(labels, tf.float32)
predictions = tf.cast(predictions, tf.float32)
metric.update_state(labels, predictions)
return state
def reduce_aggregated_logs(
self,
aggregated_logs: Optional[Mapping[str, Any]] = None,
global_step: Optional[tf.Tensor] = None) -> Mapping[str, float]:
"""Reduces logs to obtain per-class metrics if needed.
Args:
aggregated_logs: An optional dictionary containing aggregated logs.
global_step: An optional `tf.Tensor` of current global training steps.
Returns:
The reduced logs containing per-class metrics and overall metrics.
Raises:
ValueError: If `self.metrics` does not contain exactly 1 metric object.
"""
result = {}
if len(self.metrics) != 1:
      raise ValueError('Exactly one metric must be present, but {0} are '
                       'present.'.format(len(self.metrics)))
metric = self.metrics[0].result().numpy()
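    # `metric` is a vector: entry 0 is the overall dice score and entry i
    # (i > 0) is the score for class i - 1, hence the reported name
    # `<metric_name>/class_{i-1}` below.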
if self.task_config.evaluation.report_per_class_metric:
for i, metric_val in enumerate(metric):
metric_name = self.metrics[0].name + '/class_{0}'.format(
i - 1) if i > 0 else self.metrics[0].name
result.update({metric_name: metric_val})
return result
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for semantic segmentation task."""
# pylint: disable=unused-import
import functools
import os
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import optimization
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task
class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self._data_path = os.path.join(data_dir, 'data.tfrecord')
# pylint: disable=g-complex-comprehension
examples = [
tfexample_utils.create_3d_image_test_example(
image_height=32, image_width=32, image_volume=32, image_channel=2)
for _ in range(20)
]
# pylint: enable=g-complex-comprehension
tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
@parameterized.parameters(('seg_unet3d_test',))
def test_task(self, config_name):
config = exp_factory.get_exp_config(config_name)
config.task.train_data.input_path = self._data_path
config.task.train_data.global_batch_size = 4
config.task.train_data.shuffle_buffer_size = 4
config.task.validation_data.input_path = self._data_path
config.task.validation_data.shuffle_buffer_size = 4
config.task.evaluation.report_per_class_metric = True
task = img_seg_task.SemanticSegmentation3DTask(config.task)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
config.task.train_data)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
# Check if training loss is produced.
self.assertIn('loss', logs)
# Obtain distributed outputs.
distributed_outputs = strategy.run(
functools.partial(
task.validation_step,
model=model,
metrics=task.build_metrics(training=False)),
args=(next(iterator),))
outputs = tf.nest.map_structure(strategy.experimental_local_results,
distributed_outputs)
# Check if validation loss is produced.
self.assertIn('loss', outputs)
# Check if state is updated.
state = task.aggregate_logs(state=None, step_outputs=outputs)
self.assertLen(state, 1)
self.assertIsInstance(state[0], segmentation_metrics.DiceScore)
# Check if all metrics are produced.
result = task.reduce_aggregated_logs(aggregated_logs={}, global_step=1)
self.assertIn('val_generalized_dice', result)
self.assertIn('val_generalized_dice/class_0', result)
self.assertIn('val_generalized_dice/class_1', result)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
import gin # pylint: disable=unused-import
from official.common import flags as tfm_flags
from official.vision.beta import train
from official.vision.beta.projects.volumetric_models import registry_imports # pylint: disable=unused-import
def main(_):
train.main(_)
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
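# A typical invocation of this driver (a sketch, not from the source; the
# flag names are those registered by tfm_flags.define_flags, and the
# experiment name matches the test config used elsewhere in this change):
#   python3 train.py \
#     --experiment=seg_unet3d_test \
#     --mode=train_and_eval \
#     --model_dir=/tmp/unet3d \
#     --params_override='{"task": {"train_data": {"input_path": "/path/to/data.tfrecord"}}}'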
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for train."""
import json
import os
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.volumetric_models import train as train_lib
FLAGS = flags.FLAGS
class TrainTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
data_dir = os.path.join(self.get_temp_dir(), 'data')
tf.io.gfile.makedirs(data_dir)
self._data_path = os.path.join(data_dir, 'data.tfrecord')
# pylint: disable=g-complex-comprehension
examples = [
tfexample_utils.create_3d_image_test_example(
image_height=32, image_width=32, image_volume=32, image_channel=2)
for _ in range(2)
]
# pylint: enable=g-complex-comprehension
tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)
def test_run(self):
saved_flag_values = flagsaver.save_flag_values()
train_lib.tfm_flags.define_flags()
FLAGS.mode = 'train'
FLAGS.model_dir = self._model_dir
FLAGS.experiment = 'seg_unet3d_test'
logging.info('Test pipeline correctness.')
params_override = json.dumps({
'runtime': {
'mixed_precision_dtype': 'float32',
},
'trainer': {
'train_steps': 1,
'validation_steps': 1,
},
'task': {
'model': {
'backbone': {
'unet_3d': {
'model_id': 4,
},
},
'decoder': {
'unet_3d_decoder': {
'model_id': 4,
},
},
},
'train_data': {
'input_path': self._data_path,
'file_type': 'tfrecord',
'global_batch_size': 2,
},
'validation_data': {
'input_path': self._data_path,
'file_type': 'tfrecord',
'global_batch_size': 2,
}
}
})
FLAGS.params_override = params_override
train_lib.main('unused_args')
FLAGS.mode = 'eval'
with train_lib.gin.unlock_config():
train_lib.main('unused_args')
flagsaver.restore_flag_values(saved_flag_values)
if __name__ == '__main__':
tf.test.main()
@@ -20,6 +20,7 @@ import tensorflow as tf
from official.vision.beta import configs
from official.vision.beta.modeling import factory
from official.vision.beta.ops import anchor
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.serving import export_base
@@ -130,6 +131,28 @@ class DetectionModule(export_base.ExportModule):
training=False)
if self.params.task.model.detection_generator.apply_nms:
# For RetinaNet model, apply export_config.
# TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
export_config = self.params.task.export_config
# Normalize detection box coordinates to [0, 1].
if export_config.output_normalized_coordinates:
detection_boxes = (
detections['detection_boxes'] /
tf.tile(image_info[:, 2:3, :], [1, 1, 2]))
detections['detection_boxes'] = box_ops.normalize_boxes(
detection_boxes, image_info[:, 0:1, :])
# Cast num_detections and detection_classes to float. This allows the
# model inference to work on chain (go/chain) as chain requires floating
# point outputs.
if export_config.cast_num_detections_to_float:
detections['num_detections'] = tf.cast(
detections['num_detections'], dtype=tf.float32)
if export_config.cast_detection_classes_to_float:
detections['detection_classes'] = tf.cast(
detections['detection_classes'], dtype=tf.float32)
final_outputs = {
'detection_boxes': detections['detection_boxes'],
'detection_scores': detections['detection_scores'],
@@ -139,9 +162,7 @@ class DetectionModule(export_base.ExportModule):
else:
final_outputs = {
'decoded_boxes': detections['decoded_boxes'],
-        'decoded_box_scores': detections['decoded_box_scores'],
-        'cls_outputs': detections['cls_outputs'],
-        'box_outputs': detections['box_outputs']
+        'decoded_box_scores': detections['decoded_box_scores']
}
if 'detection_masks' in detections.keys():
......
@@ -73,6 +73,10 @@ flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
flags.DEFINE_string('export_checkpoint_subdir', 'checkpoint',
'The subdirectory for checkpoints.')
flags.DEFINE_string('export_saved_model_subdir', 'saved_model',
'The subdirectory for saved model.')
def main(_):
@@ -95,8 +99,8 @@ def main(_):
params=params,
checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir,
-      export_checkpoint_subdir='checkpoint',
-      export_saved_model_subdir='saved_model')
+      export_checkpoint_subdir=FLAGS.export_checkpoint_subdir,
+      export_saved_model_subdir=FLAGS.export_saved_model_subdir)
if __name__ == '__main__':
......
@@ -27,6 +27,7 @@ from official.vision.beta import configs
from official.vision.beta.serving import detection
from official.vision.beta.serving import image_classification
from official.vision.beta.serving import semantic_segmentation
from official.vision.beta.serving import video_classification
def export_inference_graph(
@@ -68,7 +69,7 @@ def export_inference_graph(
output_checkpoint_directory = os.path.join(
export_dir, export_checkpoint_subdir)
else:
-    output_checkpoint_directory = export_dir
+    output_checkpoint_directory = None
if export_saved_model_subdir:
output_saved_model_directory = os.path.join(
@@ -99,6 +100,13 @@ def export_inference_graph(
batch_size=batch_size,
input_image_size=input_image_size,
num_channels=num_channels)
elif isinstance(params.task,
configs.video_classification.VideoClassificationTask):
export_module = video_classification.VideoClassificationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size,
num_channels=num_channels)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
@@ -111,6 +119,7 @@ def export_inference_graph(
timestamped=False,
save_options=save_options)
-  ckpt = tf.train.Checkpoint(model=export_module.model)
-  ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
+  if output_checkpoint_directory:
+    ckpt = tf.train.Checkpoint(model=export_module.model)
+    ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
train_utils.serialize_config(params, export_dir)
@@ -13,7 +13,7 @@
# limitations under the License.
# Lint as: python3
"""Detection input and model functions for serving/inference."""
"""Image classification input and model functions for serving/inference."""
import tensorflow as tf
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video classification input and model functions for serving/inference."""
from typing import Mapping, Dict, Text
import tensorflow as tf
from official.vision.beta.dataloaders import video_input
from official.vision.beta.serving import export_base
from official.vision.beta.tasks import video_classification
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class VideoClassificationModule(export_base.ExportModule):
"""Video classification Module."""
def _build_model(self):
input_params = self.params.task.train_data
self._num_frames = input_params.feature_shape[0]
self._stride = input_params.temporal_stride
self._min_resize = input_params.min_image_size
self._crop_size = input_params.feature_shape[1]
self._output_audio = input_params.output_audio
task = video_classification.VideoClassificationTask(self.params.task)
return task.build_model()
def _decode_tf_example(self, encoded_inputs: tf.Tensor):
sequence_description = {
# Each image is a string encoding JPEG.
video_input.IMAGE_KEY:
tf.io.FixedLenSequenceFeature((), tf.string),
}
if self._output_audio:
sequence_description[self._params.task.validation_data.audio_feature] = (
tf.io.VarLenFeature(dtype=tf.float32))
_, decoded_tensors = tf.io.parse_single_sequence_example(
encoded_inputs, {}, sequence_description)
for key, value in decoded_tensors.items():
if isinstance(value, tf.SparseTensor):
decoded_tensors[key] = tf.sparse.to_dense(value)
return decoded_tensors
def _preprocess_image(self, image):
image = video_input.process_image(
image=image,
is_training=False,
num_frames=self._num_frames,
stride=self._stride,
num_test_clips=1,
min_resize=self._min_resize,
crop_size=self._crop_size,
num_crops=1)
image = tf.cast(image, tf.float32) # Use config.
features = {'image': image}
return features
def _preprocess_audio(self, audio):
features = {}
audio = tf.cast(audio, dtype=tf.float32) # Use config.
audio = video_input.preprocess_ops_3d.sample_sequence(
audio, 20, random=False, stride=1)
audio = tf.ensure_shape(
audio, self._params.task.validation_data.audio_feature_shape)
features['audio'] = audio
return features
@tf.function
def inference_from_tf_example(
self, encoded_inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
with tf.device('cpu:0'):
if self._output_audio:
inputs = tf.map_fn(
self._decode_tf_example, (encoded_inputs),
fn_output_signature={
video_input.IMAGE_KEY: tf.string,
self._params.task.validation_data.audio_feature: tf.float32
})
return self.serve(inputs['image'], inputs['audio'])
else:
inputs = tf.map_fn(
self._decode_tf_example, (encoded_inputs),
fn_output_signature={
video_input.IMAGE_KEY: tf.string,
})
return self.serve(inputs[video_input.IMAGE_KEY], tf.zeros([1, 1]))
@tf.function
def inference_from_image_tensors(
self, input_frames: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(input_frames, tf.zeros([1, 1]))
@tf.function
def inference_from_image_audio_tensors(
self, input_frames: tf.Tensor,
input_audio: tf.Tensor) -> Mapping[str, tf.Tensor]:
return self.serve(input_frames, input_audio)
@tf.function
def inference_from_image_bytes(self, inputs: tf.Tensor):
raise NotImplementedError(
        'Video classification does not support image bytes input.')
def serve(self, input_frames: tf.Tensor, input_audio: tf.Tensor):
"""Cast image to float and run inference.
Args:
input_frames: uint8 Tensor of shape [batch_size, None, None, 3]
input_audio: float32
Returns:
Tensor holding classification output logits.
"""
with tf.device('cpu:0'):
inputs = tf.map_fn(
self._preprocess_image, (input_frames),
fn_output_signature={
'image': tf.float32,
})
if self._output_audio:
inputs.update(
tf.map_fn(
self._preprocess_audio, (input_audio),
fn_output_signature={'audio': tf.float32}))
logits = self.inference_step(inputs)
if self.params.task.train_data.is_multilabel:
probs = tf.math.sigmoid(logits)
else:
probs = tf.nn.softmax(logits)
return {'logits': logits, 'probs': probs}
  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Gets defined function signatures.
    Args:
      function_keys: A dictionary with keys as the functions to create
        signatures for and values as the signature keys to use when returned.
    Returns:
      A dictionary with keys as signature keys and values as concrete
      functions that can be used with tf.saved_model.save.
    """
signatures = {}
for key, def_name in function_keys.items():
if key == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[self._batch_size] + self._input_image_size + [3],
dtype=tf.uint8,
name='INPUT_FRAMES')
signatures[
def_name] = self.inference_from_image_tensors.get_concrete_function(
input_signature)
elif key == 'frames_audio':
input_signature = [
tf.TensorSpec(
shape=[self._batch_size] + self._input_image_size + [3],
dtype=tf.uint8,
name='INPUT_FRAMES'),
tf.TensorSpec(
shape=[self._batch_size] +
self.params.task.train_data.audio_feature_shape,
dtype=tf.float32,
name='INPUT_AUDIO')
]
      signatures[
          def_name] = self.inference_from_image_audio_tensors.get_concrete_function(
              *input_signature)
elif key == 'serve_examples' or key == 'tf_example':
input_signature = tf.TensorSpec(
shape=[self._batch_size], dtype=tf.string)
signatures[
def_name] = self.inference_from_tf_example.get_concrete_function(
input_signature)
else:
raise ValueError('Unrecognized `input_type`')
return signatures
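# Hedged usage sketch (not part of the original file): how this module's
# signatures would typically be wired into tf.saved_model.save. The
# constructor arguments below are assumptions inferred from the attributes
# the class reads (`_batch_size`, `_input_image_size`); for video inputs,
# `input_image_size` presumably includes the temporal dimension, e.g.
# [num_frames, height, width], so INPUT_FRAMES receives a 5-D shape.
#
#   module = VideoClassificationModule(
#       params=params, batch_size=1, input_image_size=[8, 224, 224])
#   signatures = module.get_inference_signatures(
#       {'image_tensor': 'serving_default'})
#   tf.saved_model.save(module, '/tmp/video_export', signatures=signatures)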