"docs/basic_usage/offline_engine_api.ipynb" did not exist on "acd1a15921f7405d611a607fee73274d8e77bedc"
Commit 2ee42597 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 381516130
parent afb34072
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for factory.py."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.modeling import factory
from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d # pylint: disable=unused-import
class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
  """Tests building 3D segmentation models through the factory."""

  @parameterized.parameters(((128, 128, 128), 5e-5), ((64, 64, 64), None))
  def test_unet3d_builder(self, input_size, weight_decay):
    """Builds a 3D segmentation model and verifies its type."""
    num_classes = 3
    depth, height, width = input_size
    input_specs = tf.keras.layers.InputSpec(
        shape=[None, depth, height, width, 3])
    model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=num_classes)
    if weight_decay:
      l2_regularizer = tf.keras.regularizers.l2(weight_decay)
    else:
      l2_regularizer = None
    model = factory.build_segmentation_model_3d(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)
    self.assertIsInstance(
        model, tf.keras.Model,
        'Output should be a tf.keras.Model instance but got %s' % type(model))


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Segmentation heads."""
from typing import Any, Union, Sequence, Mapping
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class SegmentationHead3D(tf.keras.layers.Layer):
  """Segmentation head for 3D (volumetric) input.

  Applies `num_convs` Conv3D + normalization + activation stages to the
  decoder feature map at `level`, optionally upsamples the result, and then
  projects it to `num_classes` channels with a 1x1x1 convolution.
  """

  def __init__(self,
               num_classes: int,
               level: Union[int, str],
               num_convs: int = 2,
               num_filters: int = 256,
               upsample_factor: int = 1,
               activation: str = 'relu',
               use_sync_bn: bool = False,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               kernel_regularizer: tf.keras.regularizers.Regularizer = None,
               bias_regularizer: tf.keras.regularizers.Regularizer = None,
               output_logits: bool = True,
               **kwargs):
    """Initialize params to build segmentation head.

    Args:
      num_classes: `int` number of mask classification categories. The number
        of classes does not include background class.
      level: `int` or `str`, level to use to build segmentation head. Used as
        the key into `decoder_output` in `call`.
      num_convs: `int` number of stacked convolution before the last
        prediction layer.
      num_filters: `int` number to specify the number of filters used. Default
        is 256.
      upsample_factor: `int` number to specify the upsampling factor to
        generate finer mask. Default 1 means no upsampling is applied.
      activation: `string`, indicating which activation is used, e.g. 'relu',
        'swish', etc.
      use_sync_bn: `bool`, whether to use synchronized batch normalization
        across different replicas.
      norm_momentum: `float`, the momentum parameter of the normalization
        layers.
      norm_epsilon: `float`, the epsilon parameter of the normalization
        layers.
      kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer
        kernel.
      bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias.
      output_logits: A `bool` of whether to output logits or not. Default
        is True. If set to False, output softmax.
      **kwargs: other keyword arguments passed to Layer.
    """
    super(SegmentationHead3D, self).__init__(**kwargs)
    # All constructor arguments are captured here so that get_config() can
    # return them verbatim for Keras (de)serialization.
    self._config_dict = {
        'num_classes': num_classes,
        'level': level,
        'num_convs': num_convs,
        'num_filters': num_filters,
        'upsample_factor': upsample_factor,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
        'output_logits': output_logits
    }
    # Channel axis depends on the global image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation = tf_utils.get_activation(activation)

  def build(self, input_shape: Union[tf.TensorShape,
                                     Sequence[tf.TensorShape]]):
    """Creates the variables of the segmentation head."""
    conv_op = tf.keras.layers.Conv3D
    # Shared kwargs for the intermediate convs; bias is disabled because each
    # conv is followed by a normalization layer.
    conv_kwargs = {
        'kernel_size': (3, 3, 3),
        'padding': 'same',
        'use_bias': False,
        'kernel_initializer': tf.keras.initializers.RandomNormal(stddev=0.01),
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
    }
    # The final classifier is a pointwise (1x1x1) projection.
    final_kernel_size = (1, 1, 1)
    bn_op = (
        tf.keras.layers.experimental.SyncBatchNormalization
        if self._config_dict['use_sync_bn'] else
        tf.keras.layers.BatchNormalization)
    bn_kwargs = {
        'axis': self._bn_axis,
        'momentum': self._config_dict['norm_momentum'],
        'epsilon': self._config_dict['norm_epsilon'],
    }

    # Segmentation head layers: num_convs conv/norm pairs plus a classifier.
    self._convs = []
    self._norms = []
    for i in range(self._config_dict['num_convs']):
      conv_name = 'segmentation_head_conv_{}'.format(i)
      self._convs.append(
          conv_op(
              name=conv_name,
              filters=self._config_dict['num_filters'],
              **conv_kwargs))
      norm_name = 'segmentation_head_norm_{}'.format(i)
      self._norms.append(bn_op(name=norm_name, **bn_kwargs))

    # Final per-voxel classifier; uses a bias since no norm follows it.
    self._classifier = conv_op(
        name='segmentation_output',
        filters=self._config_dict['num_classes'],
        kernel_size=final_kernel_size,
        padding='valid',
        activation=None,
        bias_initializer=tf.zeros_initializer(),
        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        bias_regularizer=self._config_dict['bias_regularizer'])

    super(SegmentationHead3D, self).build(input_shape)

  def call(self, backbone_output: Mapping[str, tf.Tensor],
           decoder_output: Mapping[str, tf.Tensor]) -> tf.Tensor:
    """Forward pass of the segmentation head.

    Only the decoder feature at `level` is used; `backbone_output` is part of
    the call signature but is not read here.

    Args:
      backbone_output: a dict of tensors
        - key: `str`, the level of the multilevel features.
        - values: `Tensor`, volumetric feature map tensors, presumably of
          shape [batch, depth_l, height_l, width_l, channels] — TODO confirm
          against the backbone.
      decoder_output: a dict of tensors
        - key: `str`, the level of the multilevel features.
        - values: `Tensor`, volumetric feature map tensors of shape
          [batch, depth_l, height_l, width_l, channels].

    Returns:
      segmentation prediction mask: `Tensor`, the segmentation mask scores
        (logits, or softmax probabilities if `output_logits` is False)
        predicted from input feature.
    """
    x = decoder_output[str(self._config_dict['level'])]
    for conv, norm in zip(self._convs, self._norms):
      x = conv(x)
      x = norm(x)
      x = self._activation(x)
    x = tf.keras.layers.UpSampling3D(
        size=self._config_dict['upsample_factor'])(x)
    x = self._classifier(x)
    # Softmax is forced to float32 so probabilities stay numerically stable
    # when the rest of the model runs in a reduced-precision policy.
    return x if self._config_dict['output_logits'] else tf.keras.layers.Softmax(
        dtype='float32')(x)

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments for serialization."""
    return self._config_dict

  @classmethod
  def from_config(cls, config: Mapping[str, Any]):
    """Rebuilds the layer from a `get_config()` dictionary."""
    return cls(**config)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for segmentation_heads.py."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d
class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the 3D segmentation head."""

  @parameterized.parameters(
      (1, 0),
      (2, 1),
  )
  def test_forward(self, level, num_convs):
    """Runs a forward pass and checks the logits shape."""
    head = segmentation_heads_3d.SegmentationHead3D(
        num_classes=10, level=level, num_convs=num_convs)
    backbone_features = {
        key: np.random.rand(2, size, size, size, 16)
        for key, size in (('1', 128), ('2', 64))
    }
    decoder_features = {
        key: np.random.rand(2, size, size, size, 16)
        for key, size in (('1', 128), ('2', 64))
    }
    logits = head(backbone_features, decoder_features)
    if str(level) in decoder_features:
      source = decoder_features[str(level)]
      expected_shape = [2] + list(source.shape[1:4]) + [10]
      self.assertAllEqual(logits.numpy().shape, expected_shape)

  def test_serialize_deserialize(self):
    """Round-trips the head through its config."""
    original = segmentation_heads_3d.SegmentationHead3D(
        num_classes=10, level=3)
    restored = segmentation_heads_3d.SegmentationHead3D.from_config(
        original.get_config())
    self.assertAllEqual(original.get_config(), restored.get_config())


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for neural networks."""
from typing import Sequence, Union
# Import libraries
import tensorflow as tf
from official.modeling import tf_utils
from official.vision.beta.modeling.layers import nn_layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class BasicBlock3DVolume(tf.keras.layers.Layer):
  """A basic 3d convolution block.

  Applies one Conv3D (+ optional batch normalization) + activation stage per
  entry in `filters`.
  """

  def __init__(self,
               filters: Union[int, Sequence[int]],
               strides: Union[int, Sequence[int]],
               kernel_size: Union[int, Sequence[int]],
               kernel_initializer: str = 'VarianceScaling',
               kernel_regularizer: tf.keras.regularizers.Regularizer = None,
               bias_regularizer: tf.keras.regularizers.Regularizer = None,
               activation: str = 'relu',
               use_sync_bn: bool = False,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               use_batch_normalization: bool = False,
               **kwargs):
    """Creates a basic 3d convolution block applying one or more convolutions.

    Args:
      filters: A list of `int` numbers or an `int` number of filters. Given an
        `int` input, a single convolution is applied; otherwise a series of
        convolutions are applied.
      strides: An integer or tuple/list of 3 integers, specifying the strides
        of the convolution along each spatial dimension. Can be a single
        integer to specify the same value for all spatial dimensions.
      kernel_size: An integer or tuple/list of 3 integers, specifying the
        depth, height and width of the 3D convolution window. Can be a single
        integer to specify the same value for all spatial dimensions.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Default to None.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      use_batch_normalization: Whether to use batch normalization or not.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    # Normalize `filters` to a list so build() creates one conv per entry.
    if isinstance(filters, int):
      self._filters = [filters]
    else:
      self._filters = filters
    self._strides = strides
    self._kernel_size = kernel_size
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._activation = activation
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._use_batch_normalization = use_batch_normalization
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # Channel axis depends on the global image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape: tf.TensorShape):
    """Builds the basic 3d convolution block."""
    self._convs = []
    self._norms = []
    # NOTE: `strides` is applied to every conv in the series, so with more
    # than one `filters` entry the spatial dims shrink once per conv.
    for filters in self._filters:
      self._convs.append(
          tf.keras.layers.Conv3D(
              filters=filters,
              kernel_size=self._kernel_size,
              strides=self._strides,
              padding='same',
              data_format=tf.keras.backend.image_data_format(),
              # Bug fix: the initializer and regularizers passed to the
              # constructor were previously stored (and serialized via
              # get_config) but never applied to the conv layers.
              kernel_initializer=self._kernel_initializer,
              kernel_regularizer=self._kernel_regularizer,
              bias_regularizer=self._bias_regularizer,
              activation=None))
      self._norms.append(
          # Bug fix: forward the configured momentum/epsilon instead of
          # silently falling back to the BatchNormalization defaults.
          self._norm(
              axis=self._bn_axis,
              momentum=self._norm_momentum,
              epsilon=self._norm_epsilon))
    super(BasicBlock3DVolume, self).build(input_shape)

  def get_config(self):
    """Returns the config of the basic 3d convolution block."""
    config = {
        'filters': self._filters,
        'strides': self._strides,
        'kernel_size': self._kernel_size,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'use_batch_normalization': self._use_batch_normalization
    }
    base_config = super(BasicBlock3DVolume, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs: tf.Tensor, training: bool = None) -> tf.Tensor:
    """Runs forward pass on the input tensor."""
    x = inputs
    for conv, norm in zip(self._convs, self._norms):
      x = conv(x)
      # Batch normalization is optional; the norm layers are built either way
      # but only used when enabled.
      if self._use_batch_normalization:
        x = norm(x)
      x = self._activation_fn(x)
    return x
@tf.keras.utils.register_keras_serializable(package='Vision')
class ResidualBlock3DVolume(tf.keras.layers.Layer):
  """A residual 3d block.

  Two 3x3x3 convolutions with batch normalization, optional
  squeeze-and-excitation and stochastic depth, plus an identity or projection
  shortcut. The output has `filters` channels.
  """

  def __init__(self,
               filters,
               strides,
               use_projection=False,
               se_ratio=None,
               stochastic_depth_drop_rate=None,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """A residual 3d block with BN after convolutions.

    Args:
      filters: `int` number of filters for both convolutions in this block;
        this is also the number of output channels.
      strides: `int` block stride. If greater than 1, this block will
        ultimately downsample the input.
      use_projection: `bool` for whether this block should use a projection
        shortcut (versus the default identity shortcut). This is usually
        `True` for the first block of a block group, which may change the
        number of filters and the resolution.
      se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
      stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
        the stochastic depth layer.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Default to None.
      activation: `str` name of the activation function.
      norm_momentum: `float` normalization momentum for the moving average.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    self._filters = filters
    self._strides = strides
    self._use_projection = use_projection
    self._se_ratio = se_ratio
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._kernel_initializer = kernel_initializer
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer

    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # Channel axis depends on the global image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    """Creates the convolutions, norms and optional SE/stochastic depth."""
    if self._use_projection:
      # 1x1x1 projection on the shortcut path so its channel count and
      # spatial size match the residual path.
      self._shortcut = tf.keras.layers.Conv3D(
          filters=self._filters,
          kernel_size=1,
          strides=self._strides,
          use_bias=False,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon)

    # First conv carries the block's stride; the second keeps resolution.
    self._conv1 = tf.keras.layers.Conv3D(
        filters=self._filters,
        kernel_size=3,
        strides=self._strides,
        padding='same',
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm1 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

    self._conv2 = tf.keras.layers.Conv3D(
        filters=self._filters,
        kernel_size=3,
        strides=1,
        padding='same',
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm2 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

    # SE is only enabled for a ratio in (0, 1].
    if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
      self._squeeze_excitation = nn_layers.SqueezeExcitation(
          in_filters=self._filters,
          out_filters=self._filters,
          se_ratio=self._se_ratio,
          use_3d_input=True,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)
    else:
      self._squeeze_excitation = None

    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = nn_layers.StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      self._stochastic_depth = None

    super(ResidualBlock3DVolume, self).build(input_shape)

  def get_config(self):
    """Returns the constructor arguments for serialization."""
    config = {
        'filters': self._filters,
        'strides': self._strides,
        'use_projection': self._use_projection,
        'se_ratio': self._se_ratio,
        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super(ResidualBlock3DVolume, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    """Runs the residual block; returns activation(residual + shortcut)."""
    shortcut = inputs
    if self._use_projection:
      shortcut = self._shortcut(shortcut)
      shortcut = self._norm0(shortcut)

    x = self._conv1(inputs)
    x = self._norm1(x)
    x = self._activation_fn(x)

    x = self._conv2(x)
    x = self._norm2(x)

    if self._squeeze_excitation:
      x = self._squeeze_excitation(x)

    # Stochastic depth randomly drops the residual branch during training.
    if self._stochastic_depth:
      x = self._stochastic_depth(x, training=training)

    return self._activation_fn(x + shortcut)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BottleneckBlock3DVolume(tf.keras.layers.Layer):
  """A standard bottleneck block.

  1x1x1 reduce -> 3x3x3 -> 1x1x1 expand (to 4x `filters`), each followed by
  batch normalization, with optional squeeze-and-excitation, stochastic depth
  and an identity or projection shortcut.
  """

  def __init__(self,
               filters,
               strides,
               dilation_rate=1,
               use_projection=False,
               se_ratio=None,
               stochastic_depth_drop_rate=None,
               kernel_initializer='VarianceScaling',
               kernel_regularizer=None,
               bias_regularizer=None,
               activation='relu',
               use_sync_bn=False,
               norm_momentum=0.99,
               norm_epsilon=0.001,
               **kwargs):
    """A standard bottleneck 3d block with BN after convolutions.

    Args:
      filters: `int` number of filters for the first two convolutions. Note
        that the third and final convolution will use 4 times as many filters.
      strides: `int` block stride. If greater than 1, this block will
        ultimately downsample the input.
      dilation_rate: `int` dilation_rate of convolutions. Default to 1.
      use_projection: `bool` for whether this block should use a projection
        shortcut (versus the default identity shortcut). This is usually
        `True` for the first block of a block group, which may change the
        number of filters and the resolution.
      se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
      stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
        the stochastic depth layer.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv3D.
        Default to None.
      activation: `str` name of the activation function.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)

    self._filters = filters
    self._strides = strides
    self._dilation_rate = dilation_rate
    self._use_projection = use_projection
    self._se_ratio = se_ratio
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._kernel_initializer = kernel_initializer
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer

    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization
    # Channel axis depends on the global image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1
    self._activation_fn = tf_utils.get_activation(activation)

  def build(self, input_shape):
    """Creates the convolutions, norms and optional SE/stochastic depth."""
    if self._use_projection:
      # Projection shortcut matches the expanded (4x) output channels.
      self._shortcut = tf.keras.layers.Conv3D(
          filters=self._filters * 4,
          kernel_size=1,
          strides=self._strides,
          use_bias=False,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon)

    # 1x1x1 channel reduction.
    self._conv1 = tf.keras.layers.Conv3D(
        filters=self._filters,
        kernel_size=1,
        strides=1,
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm1 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

    # 3x3x3 spatial conv; carries the block's stride and dilation.
    self._conv2 = tf.keras.layers.Conv3D(
        filters=self._filters,
        kernel_size=3,
        strides=self._strides,
        dilation_rate=self._dilation_rate,
        padding='same',
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm2 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

    # 1x1x1 channel expansion to 4x filters.
    self._conv3 = tf.keras.layers.Conv3D(
        filters=self._filters * 4,
        kernel_size=1,
        strides=1,
        use_bias=False,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    self._norm3 = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon)

    # SE is only enabled for a ratio in (0, 1].
    if self._se_ratio and self._se_ratio > 0 and self._se_ratio <= 1:
      self._squeeze_excitation = nn_layers.SqueezeExcitation(
          in_filters=self._filters * 4,
          out_filters=self._filters * 4,
          se_ratio=self._se_ratio,
          use_3d_input=True,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer)
    else:
      self._squeeze_excitation = None

    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = nn_layers.StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      self._stochastic_depth = None

    super(BottleneckBlock3DVolume, self).build(input_shape)

  def get_config(self):
    """Returns the constructor arguments for serialization."""
    config = {
        'filters': self._filters,
        'strides': self._strides,
        'dilation_rate': self._dilation_rate,
        'use_projection': self._use_projection,
        'se_ratio': self._se_ratio,
        'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super(BottleneckBlock3DVolume, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    """Runs the bottleneck block; returns activation(residual + shortcut)."""
    shortcut = inputs
    if self._use_projection:
      shortcut = self._shortcut(shortcut)
      shortcut = self._norm0(shortcut)

    x = self._conv1(inputs)
    x = self._norm1(x)
    x = self._activation_fn(x)

    x = self._conv2(x)
    x = self._norm2(x)
    x = self._activation_fn(x)

    x = self._conv3(x)
    x = self._norm3(x)

    if self._squeeze_excitation:
      x = self._squeeze_excitation(x)

    # Stochastic depth randomly drops the residual branch during training.
    if self._stochastic_depth:
      x = self._stochastic_depth(x, training=training)

    return self._activation_fn(x + shortcut)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for 3D volumeric convoluion blocks."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d
class NNBlocks3DTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for 3D volumetric building blocks."""

  @parameterized.parameters((128, 128, 32, 1), (256, 256, 16, 2))
  def test_bottleneck_block_3d_volume_creation(self, spatial_size, volume_size,
                                               filters, strides):
    """Bottleneck block scales spatial dims by stride, channels stay 4x."""
    channels = filters * 4
    inputs = tf.keras.Input(
        shape=(spatial_size, spatial_size, volume_size, channels),
        batch_size=1)
    block = nn_blocks_3d.BottleneckBlock3DVolume(
        filters=filters, strides=strides, use_projection=True)
    outputs = block(inputs)
    expected = [
        1, spatial_size // strides, spatial_size // strides,
        volume_size // strides, channels
    ]
    self.assertAllEqual(expected, outputs.shape.as_list())

  @parameterized.parameters((128, 128, 32, 1), (256, 256, 64, 2))
  def test_residual_block_3d_volume_creation(self, spatial_size, volume_size,
                                             filters, strides):
    """Residual block scales spatial dims by stride, channels unchanged."""
    inputs = tf.keras.Input(
        shape=(spatial_size, spatial_size, volume_size, filters), batch_size=1)
    block = nn_blocks_3d.ResidualBlock3DVolume(
        filters=filters, strides=strides, use_projection=True)
    outputs = block(inputs)
    expected = [
        1, spatial_size // strides, spatial_size // strides,
        volume_size // strides, filters
    ]
    self.assertAllEqual(expected, outputs.shape.as_list())

  @parameterized.parameters((128, 128, 64, 1, 3), (256, 256, 128, 2, 1))
  def test_basic_block_3d_volume_creation(self, spatial_size, volume_size,
                                          filters, strides, kernel_size):
    """Basic block scales spatial dims by stride, channels unchanged."""
    inputs = tf.keras.Input(
        shape=(spatial_size, spatial_size, volume_size, filters), batch_size=1)
    block = nn_blocks_3d.BasicBlock3DVolume(
        filters=filters, strides=strides, kernel_size=kernel_size)
    outputs = block(inputs)
    expected = [
        1, spatial_size // strides, spatial_size // strides,
        volume_size // strides, filters
    ]
    self.assertAllEqual(expected, outputs.shape.as_list())


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for segmentation network."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.modeling import segmentation_model
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.modeling import decoders
from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d
class SegmentationNetworkUNet3DTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the 3D UNet segmentation network."""

  @parameterized.parameters(
      ([32, 32], 4),
      ([64, 64], 4),
      ([64, 64], 2),
      ([128, 64], 2),
  )
  def test_segmentation_network_unet3d_creation(self, input_size, depth):
    """Test for creation of a segmentation network."""
    num_classes = 2
    spatial, volume = input_size
    inputs = np.random.rand(2, spatial, spatial, volume, 3)
    tf.keras.backend.set_image_data_format('channels_last')

    backbone = backbones.UNet3D(model_id=depth)
    decoder = decoders.UNet3DDecoder(
        model_id=depth, input_specs=backbone.output_specs)
    head = segmentation_heads_3d.SegmentationHead3D(
        num_classes, level=1, num_convs=0)
    model = segmentation_model.SegmentationModel(
        backbone=backbone, decoder=decoder, head=head)

    logits = model(inputs)
    self.assertAllEqual([2, spatial, spatial, volume, num_classes],
                        logits.numpy().shape)

  def test_serialize_deserialize(self):
    """Validate the network can be serialized and deserialized."""
    num_classes = 3
    backbone = backbones.UNet3D(model_id=4)
    decoder = decoders.UNet3DDecoder(
        model_id=4, input_specs=backbone.output_specs)
    head = segmentation_heads_3d.SegmentationHead3D(
        num_classes, level=1, num_convs=0)
    model = segmentation_model.SegmentationModel(
        backbone=backbone, decoder=decoder, head=head)

    restored = segmentation_model.SegmentationModel.from_config(
        model.get_config())
    # The config must also be expressible as JSON.
    _ = restored.to_json()
    # A successful round trip leaves the config unchanged.
    self.assertAllEqual(model.get_config(), restored.get_config())


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.common import registry_imports
from official.vision.beta.projects.volumetric_models.modeling import backbones
from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Volumetric model export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=1 \
--input_image_size=128,128,128 \
--num_channels=1
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
from official.vision.beta.serving import export_saved_model_lib
FLAGS = flags.FLAGS

# Experiment selection and output locations.
flags.DEFINE_string(
    'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')

# Config overrides: files first, then the params_override string on top.
flags.DEFINE_multi_string(
    'config_file',
    default=None,
    help='YAML/JSON files which specifies overrides. The override order '
    'follows the order of args. Note that each file '
    'can be used as an override template to override the default parameters '
    'specified in Python. If the same parameter is specified in both '
    '`--config_file` and `--params_override`, `config_file` will be used '
    'first, followed by params_override.')
flags.DEFINE_string(
    'params_override', '',
    'The JSON/YAML file or string which specifies the parameter to be '
    'overridden on top of `config_file` template.')

# Serving-signature shape parameters.
flags.DEFINE_integer(
    'batch_size', None, 'The batch size.')
flags.DEFINE_string(
    'input_type', 'image_tensor',
    'One of `image_tensor`, `image_bytes`, `tf_example`.')
flags.DEFINE_list(
    'input_image_size', None,
    'The comma-separated string of three integers representing the '
    'height, width and depth of the input to the model.')
flags.DEFINE_integer('num_channels', 1,
                     'The number of channels of input image.')

# --input_image_size is mandatory and must describe a 3-D volume.
flags.register_validator(
    'input_image_size',
    lambda value: value is not None and len(value) == 3,
    message='--input_image_size must be comma-separated string of three '
    'integers representing the height, width and depth of the input to '
    'the model.')
def main(_):
  """Exports a trained 3D segmentation model in SavedModel format."""
  # These flags have no usable defaults; fail fast with a clear message
  # instead of an obscure error deeper in the pipeline.
  flags.mark_flag_as_required('experiment')
  flags.mark_flag_as_required('export_dir')
  flags.mark_flag_as_required('checkpoint_path')

  # Build the experiment config, then layer file-based and string-based
  # overrides on top, in that order.
  params = exp_factory.get_exp_config(FLAGS.experiment)
  for config_file in FLAGS.config_file or []:
    params = hyperparams.override_params_dict(
        params, config_file, is_strict=True)
  if FLAGS.params_override:
    params = hyperparams.override_params_dict(
        params, FLAGS.params_override, is_strict=True)

  params.validate()
  params.lock()

  # flags.DEFINE_list yields a list of strings; the model input spec needs
  # integer dimensions.
  input_image_size = [int(x) for x in FLAGS.input_image_size]

  # NOTE(review): the export module is constructed with batch_size=1 while
  # the inference graph below receives FLAGS.batch_size — confirm this
  # asymmetry is intentional.
  export_module = semantic_segmentation_3d.SegmentationModule(
      params=params,
      batch_size=1,
      input_image_size=input_image_size,
      num_channels=FLAGS.num_channels)
  export_saved_model_lib.export_inference_graph(
      input_type=FLAGS.input_type,
      batch_size=FLAGS.batch_size,
      input_image_size=input_image_size,
      params=params,
      checkpoint_path=FLAGS.checkpoint_path,
      export_dir=FLAGS.export_dir,
      num_channels=FLAGS.num_channels,
      export_module=export_module,
      export_checkpoint_subdir='checkpoint',
      export_saved_model_subdir='saved_model')


if __name__ == '__main__':
  app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""3D semantic segmentation input and model functions for serving/inference."""
from typing import Mapping
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling import factory
from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d # pylint: disable=unused-import
from official.vision.beta.serving import export_base
class SegmentationModule(export_base.ExportModule):
  """Export module for 3D semantic segmentation models."""

  def _build_model(self) -> tf.keras.Model:
    """Builds and returns a segmentation model from the task config."""
    model_config = self.params.task.model
    input_shape = (
        [self._batch_size] + self._input_image_size +
        [model_config.num_channels])
    return factory.build_segmentation_model_3d(
        input_specs=tf.keras.layers.InputSpec(shape=input_shape),
        model_config=model_config,
        l2_regularizer=None)

  def serve(
      self, images: tf.Tensor) -> Mapping[str, tf.Tensor]:
    """Casts an image tensor to float and runs inference.

    Args:
      images: A uint8 tf.Tensor of shape [batch_size, None, None, None,
        num_channels].

    Returns:
      A dictionary holding segmentation outputs, keyed by 'logits' or
      'probs' depending on the head configuration.
    """
    # Pin the dtype conversion to the CPU so the accelerator only sees
    # float input.
    with tf.device('cpu:0'):
      images = tf.cast(images, dtype=tf.float32)

    outputs = self.inference_step(images)

    if self.params.task.model.head.output_logits:
      output_key = 'logits'
    else:
      output_key = 'probs'
    return {output_key: outputs}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for semantic_segmentation_3d export lib."""
import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.core import exp_factory
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg # pylint: disable=unused-import
from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d # pylint: disable=unused-import
from official.vision.beta.projects.volumetric_models.serving import semantic_segmentation_3d
class SemanticSegmentationExportTest(tf.test.TestCase, parameterized.TestCase):
  """Tests exporting and serving a 3D segmentation SavedModel."""

  def setUp(self):
    super().setUp()
    self._num_channels = 2
    self._input_image_size = [32, 32, 32]
    self._params = exp_factory.get_exp_config('seg_unet3d_test')
    input_shape = self._input_image_size + [self._num_channels]
    self._image_array = np.zeros(shape=input_shape, dtype=np.uint8)

  def _get_segmentation_module(self):
    """Builds a SegmentationModule for a batch-1 export."""
    return semantic_segmentation_3d.SegmentationModule(
        self._params,
        batch_size=1,
        input_image_size=self._input_image_size,
        num_channels=self._num_channels)

  def _export_from_module(self, module, input_type: str, save_directory: str):
    """Exports `module` as a SavedModel with the given serving input type."""
    signatures = module.get_inference_signatures(
        {input_type: 'serving_default'})
    tf.saved_model.save(module,
                        save_directory,
                        signatures=signatures)

  def _get_dummy_input(self, input_type):
    """Gets dummy input for the given input type.

    Args:
      input_type: One of `image_tensor`, `image_bytes`, `tf_example`.

    Returns:
      A batched dummy input matching the requested serving signature.

    Raises:
      ValueError: If `input_type` is not one of the supported types.
    """
    if input_type == 'image_tensor':
      image_tensor = tf.convert_to_tensor(self._image_array, dtype=tf.uint8)
      return tf.expand_dims(image_tensor, axis=0)
    if input_type == 'image_bytes':
      # `ndarray.tostring` is deprecated (removed in recent NumPy);
      # `tobytes` is the supported, byte-identical spelling.
      return [self._image_array.tobytes()]
    if input_type == 'tf_example':
      encoded_image = self._image_array.tobytes()
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'image/encoded':
                      tf.train.Feature(
                          bytes_list=tf.train.BytesList(value=[encoded_image])),
              })).SerializeToString()
      return [example]
    raise ValueError('Unsupported input_type: %s' % input_type)

  @parameterized.parameters(
      {'input_type': 'image_tensor'},
      {'input_type': 'image_bytes'},
      {'input_type': 'tf_example'},
  )
  def test_export(self, input_type: str = 'image_tensor'):
    tmp_dir = self.get_temp_dir()
    module = self._get_segmentation_module()

    self._export_from_module(module, input_type, tmp_dir)

    # Check if model is successfully exported.
    self.assertTrue(tf.io.gfile.exists(os.path.join(tmp_dir, 'saved_model.pb')))
    self.assertTrue(
        tf.io.gfile.exists(
            os.path.join(tmp_dir, 'variables', 'variables.index')))
    self.assertTrue(
        tf.io.gfile.exists(
            os.path.join(tmp_dir, 'variables',
                         'variables.data-00000-of-00001')))

    # Get inference signature from loaded SavedModel.
    imported = tf.saved_model.load(tmp_dir)
    segmentation_fn = imported.signatures['serving_default']

    images = self._get_dummy_input(input_type)
    image_tensor = self._get_dummy_input(input_type='image_tensor')

    # Perform inference using loaded SavedModel and model instance and check if
    # outputs equal.
    expected_output = module.model(image_tensor, training=False)
    out = segmentation_fn(tf.constant(images))
    self.assertAllClose(out['logits'].numpy(), expected_output.numpy())


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image segmentation task definition."""
from typing import Any, Dict, Mapping, Optional, Sequence, Union
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.dataloaders import segmentation_input_3d
from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
from official.vision.beta.projects.volumetric_models.losses import segmentation_losses
from official.vision.beta.projects.volumetric_models.modeling import factory
@task_factory.register_task_cls(exp_cfg.SemanticSegmentation3DTask)
class SemanticSegmentation3DTask(base_task.Task):
  """A task for 3D semantic segmentation: model, inputs, losses, metrics."""

  def build_model(self) -> tf.keras.Model:
    """Builds segmentation model."""
    input_specs = tf.keras.layers.InputSpec(
        shape=[None] + self.task_config.model.input_size +
        [self.task_config.model.num_channels],
        dtype=self.task_config.train_data.dtype)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (
        tf.keras.regularizers.l2(l2_weight_decay /
                                 2.0) if l2_weight_decay else None)

    model = factory.build_segmentation_model_3d(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)

    # Create a dummy input and call model instance to initialize the model.
    # This is needed when launching multiple experiments using the same model
    # directory. Since there is already a trained model, forward pass will not
    # run and the model will never be built. This is only done when spatial
    # partitioning is not enabled; otherwise it will fail with OOM due to
    # extremely large input.
    if (not self.task_config.train_input_partition_dims) and (
        not self.task_config.eval_input_partition_dims):
      dummy_input = tf.random.uniform(shape=[1] + list(input_specs.shape[1:]))
      _ = model(dummy_input)
    return model

  def initialize(self, model: tf.keras.Model):
    """Loads pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    else:
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)
      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self, params, input_context=None) -> tf.data.Dataset:
    """Builds the segmentation input pipeline for training or evaluation."""
    decoder = segmentation_input_3d.Decoder(
        image_field_key=params.image_field_key,
        label_field_key=params.label_field_key)
    parser = segmentation_input_3d.Parser(
        input_size=params.input_size,
        num_classes=params.num_classes,
        num_channels=params.num_channels,
        image_field_key=params.image_field_key,
        label_field_key=params.label_field_key,
        dtype=params.dtype,
        label_dtype=params.label_dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)
    return dataset

  def build_losses(self,
                   labels: tf.Tensor,
                   model_outputs: tf.Tensor,
                   aux_losses=None) -> tf.Tensor:
    """Segmentation loss.

    Args:
      labels: labels.
      model_outputs: Output logits of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    segmentation_loss_fn = segmentation_losses.SegmentationLossDiceScore(
        metric_type='adaptive')

    total_loss = segmentation_loss_fn(model_outputs, labels)
    if aux_losses:
      total_loss += tf.add_n(aux_losses)
    return total_loss

  def build_metrics(self,
                    training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
    """Gets streaming metrics for training/validation."""
    metrics = []
    num_classes = self.task_config.model.num_classes
    if training:
      metrics.extend([
          tf.keras.metrics.CategoricalAccuracy(
              name='train_categorical_accuracy', dtype=tf.float32)
      ])
    else:
      # Evaluation metrics are stashed on `self.metrics` and updated on the
      # host in `aggregate_logs` rather than inside the replica step, so the
      # returned list is intentionally empty in this branch.
      self.metrics = [
          segmentation_metrics.DiceScore(
              num_classes=num_classes,
              metric_type='generalized',
              per_class_metric=self.task_config.evaluation
              .report_per_class_metric,
              name='val_generalized_dice',
              dtype=tf.float32)
      ]
    return metrics

  def train_step(
      self,
      inputs,
      model: tf.keras.Model,
      optimizer: tf.keras.optimizers.Optimizer,
      metrics: Optional[Sequence[tf.keras.metrics.Metric]] = None
  ) -> Dict[Any, Any]:
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.train_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
      if self.task_config.model.head.output_logits:
        outputs = tf.nn.softmax(outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}

    # Compute all metrics within strategy scope for training.
    if metrics:
      labels = tf.cast(labels, tf.float32)
      outputs = tf.cast(outputs, tf.float32)
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs

  def validation_step(
      self,
      inputs,
      model: tf.keras.Model,
      metrics: Optional[Sequence[tf.keras.metrics.Metric]] = None
  ) -> Dict[Any, Any]:
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    if self.task_config.model.head.output_logits:
      outputs = tf.nn.softmax(outputs)
    loss = self.build_losses(
        model_outputs=outputs, labels=labels, aux_losses=model.losses)
    logs = {self.loss: loss}

    # Dice score metrics are computed on CPU later in `aggregate_logs`; here
    # we only stage the (labels, outputs) pairs keyed by metric name.
    for metric in self.metrics:
      labels = tf.cast(labels, tf.float32)
      outputs = tf.cast(outputs, tf.float32)
      logs.update({metric.name: (labels, outputs)})
    return logs

  def inference_step(self, inputs, model: tf.keras.Model) -> tf.Tensor:
    """Performs the forward step."""
    return model(inputs, training=False)

  def aggregate_logs(
      self,
      state: Optional[Sequence[Union[segmentation_metrics.DiceScore,
                                     tf.keras.metrics.Metric]]] = None,
      step_outputs: Optional[Mapping[str, Any]] = None
  ) -> Sequence[tf.keras.metrics.Metric]:
    """Aggregates statistics to compute metrics over training.

    Args:
      state: A sequence of tf.keras.metrics.Metric objects. Each element
        records a metric.
      step_outputs: A dictionary of [metric_name, (labels, output)] from a
        step.

    Returns:
      An updated sequence of tf.keras.metrics.Metric objects.
    """
    if state is None:
      for metric in self.metrics:
        metric.reset_states()
      state = self.metrics

    for metric in self.metrics:
      labels = step_outputs[metric.name][0]
      predictions = step_outputs[metric.name][1]

      # If `step_output` is distributed, it contains a tuple of Tensors
      # instead of a single Tensor, so we need to concatenate them along the
      # batch dimension in this case to have a single Tensor.
      if isinstance(labels, tuple):
        labels = tf.concat(list(labels), axis=0)
      if isinstance(predictions, tuple):
        predictions = tf.concat(list(predictions), axis=0)

      labels = tf.cast(labels, tf.float32)
      predictions = tf.cast(predictions, tf.float32)
      metric.update_state(labels, predictions)
    return state

  def reduce_aggregated_logs(
      self,
      aggregated_logs: Optional[Mapping[str, Any]] = None,
      global_step: Optional[tf.Tensor] = None) -> Mapping[str, float]:
    """Reduces logs to obtain per-class metrics if needed.

    Args:
      aggregated_logs: An optional dictionary containing aggregated logs.
      global_step: An optional `tf.Tensor` of current global training steps.

    Returns:
      The reduced logs containing per-class metrics and overall metrics.

    Raises:
      ValueError: If `self.metrics` does not contain exactly 1 metric object.
    """
    result = {}
    if len(self.metrics) != 1:
      raise ValueError('Exactly one metric must be present, but {0} are '
                       'present.'.format(len(self.metrics)))

    metric = self.metrics[0].result().numpy()
    if self.task_config.evaluation.report_per_class_metric:
      # Index 0 is the overall score; subsequent entries are per-class scores
      # named `<metric>/class_<i>` starting from class 0.
      for i, metric_val in enumerate(metric):
        metric_name = self.metrics[0].name + '/class_{0}'.format(
            i - 1) if i > 0 else self.metrics[0].name
        result.update({metric_name: metric_val})
    return result
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for semantic segmentation task."""
# pylint: disable=unused-import
import functools
import os
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.modeling import optimization
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d
from official.vision.beta.projects.volumetric_models.tasks import semantic_segmentation_3d as img_seg_task
class SemanticSegmentationTaskTest(tf.test.TestCase, parameterized.TestCase):
  """Exercises the 3D semantic segmentation task end to end."""

  def setUp(self):
    super().setUp()
    data_dir = os.path.join(self.get_temp_dir(), 'data')
    tf.io.gfile.makedirs(data_dir)
    self._data_path = os.path.join(data_dir, 'data.tfrecord')
    # Synthesize a handful of small 3D volumes for the input pipeline.
    # pylint: disable=g-complex-comprehension
    examples = [
        tfexample_utils.create_3d_image_test_example(
            image_height=32, image_width=32, image_volume=32, image_channel=2)
        for _ in range(20)
    ]
    # pylint: enable=g-complex-comprehension
    tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)

  @parameterized.parameters(('seg_unet3d_test',))
  def test_task(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    config.task.train_data.input_path = self._data_path
    config.task.train_data.global_batch_size = 4
    config.task.train_data.shuffle_buffer_size = 4
    config.task.validation_data.input_path = self._data_path
    config.task.validation_data.shuffle_buffer_size = 4
    config.task.evaluation.report_per_class_metric = True

    task = img_seg_task.SemanticSegmentation3DTask(config.task)
    model = task.build_model()
    train_metrics = task.build_metrics()
    strategy = tf.distribute.get_strategy()
    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
                                                   config.task.train_data)
    iterator = iter(dataset)

    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    # A single training step should report a loss.
    train_logs = task.train_step(
        next(iterator), model, optimizer, metrics=train_metrics)
    self.assertIn('loss', train_logs)

    # Run one distributed validation step and collect per-replica results.
    replica_outputs = strategy.run(
        functools.partial(
            task.validation_step,
            model=model,
            metrics=task.build_metrics(training=False)),
        args=(next(iterator),))
    outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                    replica_outputs)
    self.assertIn('loss', outputs)

    # Aggregation should update exactly one dice-score metric.
    state = task.aggregate_logs(state=None, step_outputs=outputs)
    self.assertLen(state, 1)
    self.assertIsInstance(state[0], segmentation_metrics.DiceScore)

    # The reduced logs should include overall and per-class dice scores.
    result = task.reduce_aggregated_logs(aggregated_logs={}, global_step=1)
    self.assertIn('val_generalized_dice', result)
    self.assertIn('val_generalized_dice/class_0', result)
    self.assertIn('val_generalized_dice/class_1', result)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from absl import app
import gin # pylint: disable=unused-import
from official.common import flags as tfm_flags
from official.vision.beta import train
from official.vision.beta.projects.volumetric_models import registry_imports # pylint: disable=unused-import
def main(argv):
  """Thin wrapper that forwards to the shared beta vision trainer."""
  train.main(argv)


if __name__ == '__main__':
  # Register the standard Model Garden flags before handing off to absl.
  tfm_flags.define_flags()
  app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for train."""
import json
import os
from absl import flags
from absl import logging
from absl.testing import flagsaver
import tensorflow as tf
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.volumetric_models import train as train_lib
FLAGS = flags.FLAGS
class TrainTest(tf.test.TestCase):
  """Smoke-tests the training driver on tiny synthetic data."""

  def setUp(self):
    super().setUp()
    tmp_dir = self.get_temp_dir()
    self._model_dir = os.path.join(tmp_dir, 'model_dir')
    tf.io.gfile.makedirs(self._model_dir)

    data_dir = os.path.join(tmp_dir, 'data')
    tf.io.gfile.makedirs(data_dir)
    self._data_path = os.path.join(data_dir, 'data.tfrecord')
    # Two tiny 32^3 two-channel volumes are enough for a smoke test.
    # pylint: disable=g-complex-comprehension
    examples = [
        tfexample_utils.create_3d_image_test_example(
            image_height=32, image_width=32, image_volume=32, image_channel=2)
        for _ in range(2)
    ]
    # pylint: enable=g-complex-comprehension
    tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)

  def test_run(self):
    saved_flag_values = flagsaver.save_flag_values()
    train_lib.tfm_flags.define_flags()
    FLAGS.mode = 'train'
    FLAGS.model_dir = self._model_dir
    FLAGS.experiment = 'seg_unet3d_test'

    logging.info('Test pipeline correctness.')
    # Override the default experiment with a one-step trainer reading the
    # synthetic tfrecord produced in setUp.
    overrides = {
        'runtime': {
            'mixed_precision_dtype': 'float32',
        },
        'trainer': {
            'train_steps': 1,
            'validation_steps': 1,
        },
        'task': {
            'model': {
                'backbone': {
                    'unet_3d': {
                        'model_id': 4,
                    },
                },
                'decoder': {
                    'unet_3d_decoder': {
                        'model_id': 4,
                    },
                },
            },
            'train_data': {
                'input_path': self._data_path,
                'file_type': 'tfrecord',
                'global_batch_size': 2,
            },
            'validation_data': {
                'input_path': self._data_path,
                'file_type': 'tfrecord',
                'global_batch_size': 2,
            }
        }
    }
    FLAGS.params_override = json.dumps(overrides)

    # One training run, then one evaluation pass with the same config.
    train_lib.main('unused_args')
    FLAGS.mode = 'eval'
    with train_lib.gin.unlock_config():
      train_lib.main('unused_args')

    flagsaver.restore_flag_values(saved_flag_values)


if __name__ == '__main__':
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment