"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "e928537042dc2212b369bf981345420170e7b7a4"
Commit 2ee42597 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 381516130
parent afb34072
...@@ -207,7 +207,7 @@ A brief look at each component in the code: ...@@ -207,7 +207,7 @@ A brief look at each component in the code:
* [ffn_layer.py](ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers. * [ffn_layer.py](ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
Other files: Other files:
* [beam_search_v1.py](beam_search_v1.py) contains the beam search implementation, which is used during model inference to find high scoring translations. * [beam_search.py](beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
### Model Trainer ### Model Trainer
[transformer_main.py](transformer_main.py) creates an `TransformerTask` to train and evaluate the model using tf.keras. [transformer_main.py](transformer_main.py) creates an `TransformerTask` to train and evaluate the model using tf.keras.
......
# Volumetric Models
**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.
This folder contains an implementation of volumetric models, i.e., the UNet 3D
model, for 3D semantic segmentation.
## Modeling
Following the style of TF-Vision, a UNet 3D model is implemented as a backbone
and a decoder.
## Backbone
The backbone is the left U-shape of the complete UNet model. It takes a batch
of images as input, and outputs a dictionary in the form of `{level: features}`.
`features` in the output is a tensor of feature maps.
## Decoder
The decoder is the right U-shape of the complete UNet model. It takes the output
dictionary from the backbone and connects the feature maps from each level to
the decoder's decoding branches. The final output is the raw segmentation
predictions.
An additional head is attached to the output of the decoder to optionally
perform more operations and then generate the prediction map of logits.
The `factory.py` file builds and connects the backbone, decoder and head
together to form the complete UNet model.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Backbones configurations."""
from typing import Optional, Sequence
import dataclasses
from official.modeling import hyperparams
@dataclasses.dataclass
class UNet3D(hyperparams.Config):
  """Configuration for the 3D UNet backbone."""
  # Depth of the UNet encoder; greater depth adds more max-pooling levels.
  model_id: int = 4
  # Pooling window of the 3D max-pooling ops, one entry per spatial dim.
  pool_size: Sequence[int] = (2, 2, 2)
  # Kernel size of the 3D convolutions.
  kernel_size: Sequence[int] = (3, 3, 3)
  # Number of filters of the first convolution; deeper levels use multiples.
  base_filters: int = 32
  # Whether to apply batch normalization after convolutions.
  use_batch_normalization: bool = True
@dataclasses.dataclass
class Backbone(hyperparams.OneOfConfig):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone be used, one of the fields below.
    unet_3d: 3D UNet backbone config.
  """
  type: Optional[str] = None
  unet_3d: UNet3D = UNet3D()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Decoders configurations."""
from typing import Optional, Sequence
import dataclasses
from official.modeling import hyperparams
@dataclasses.dataclass
class UNet3DDecoder(hyperparams.Config):
  """Configuration for the 3D UNet decoder."""
  # Depth of the UNet decoder; should match the backbone's `model_id`
  # so each encoder level has a decoding branch — TODO confirm.
  model_id: int = 4
  # Pooling window mirrored by the decoder's upsampling, per spatial dim.
  pool_size: Sequence[int] = (2, 2, 2)
  # Kernel size of the 3D convolutions.
  kernel_size: Sequence[int] = (3, 3, 3)
  # Whether to apply batch normalization after convolutions.
  use_batch_normalization: bool = True
  # Whether to upsample with deconvolution (transposed convolution).
  use_deconvolution: bool = True
@dataclasses.dataclass
class Decoder(hyperparams.OneOfConfig):
  """Configuration for decoders.

  Attributes:
    type: 'str', type of decoder be used, one of the fields below.
    unet_3d_decoder: 3D UNet decoder config.
  """
  type: Optional[str] = None
  unet_3d_decoder: UNet3DDecoder = UNet3DDecoder()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Semantic segmentation configuration definition."""
from typing import List, Optional, Union
import dataclasses
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common
from official.vision.beta.projects.volumetric_models.configs import backbones
from official.vision.beta.projects.volumetric_models.configs import decoders
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  # Spatial size of the output label, e.g. [height, width, volume].
  output_size: List[int] = dataclasses.field(default_factory=list)
  # Spatial size of the input image, e.g. [height, width, volume].
  input_size: List[int] = dataclasses.field(default_factory=list)
  # Number of segmentation classes.
  num_classes: int = 0
  # Number of channels of the input image.
  num_channels: int = 1
  input_path: str = ''
  global_batch_size: int = 0
  is_training: bool = True
  # Data type images are cast to by the parser.
  dtype: str = 'float32'
  # Data type the raw label bytes are decoded as.
  label_dtype: str = 'float32'
  # TFExample feature key holding the raw image bytes.
  image_field_key: str = 'image/encoded'
  # TFExample feature key holding the raw label bytes.
  label_field_key: str = 'image/class/label'
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  drop_remainder: bool = False
  file_type: str = 'tfrecord'
@dataclasses.dataclass
class SegmentationHead3D(hyperparams.Config):
  """Segmentation head config."""
  num_classes: int = 0
  # Feature level the head attaches to — presumably a decoder output level;
  # confirm against the head implementation.
  level: int = 1
  # Number of extra convolutions before the final prediction layer.
  num_convs: int = 0
  num_filters: int = 256
  upsample_factor: int = 1
  # If True, the head outputs raw logits.
  output_logits: bool = True
@dataclasses.dataclass
class SemanticSegmentationModel3D(hyperparams.Config):
  """Semantic segmentation model config."""
  num_classes: int = 0
  num_channels: int = 1
  # Spatial size of the model input, e.g. [height, width, volume].
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead3D = SegmentationHead3D()
  # Defaults to the 3D UNet backbone/decoder pair defined in this project.
  backbone: backbones.Backbone = backbones.Backbone(
      type='unet_3d', unet_3d=backbones.UNet3D())
  decoder: decoders.Decoder = decoders.Decoder(
      type='unet_3d_decoder', unet_3d_decoder=decoders.UNet3DDecoder())
  norm_activation: common.NormActivation = common.NormActivation()
@dataclasses.dataclass
class Losses(hyperparams.Config):
  """Loss config for 3D semantic segmentation."""
  # Supported `loss_type` are `adaptive` and `generalized`.
  loss_type: str = 'adaptive'
  # L2 regularization weight.
  l2_weight_decay: float = 0.0
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  """Evaluation config."""
  report_per_class_metric: bool = False  # Whether to report per-class metrics.
@dataclasses.dataclass
class SemanticSegmentation3DTask(cfg.TaskConfig):
  """The model config."""
  model: SemanticSegmentationModel3D = SemanticSegmentationModel3D()
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False)
  losses: Losses = Losses()
  evaluation: Evaluation = Evaluation()
  # Partition dims for the input — presumably for TPU spatial partitioning;
  # empty list means no partitioning. Confirm with the trainer.
  train_input_partition_dims: List[int] = dataclasses.field(
      default_factory=list)
  eval_input_partition_dims: List[int] = dataclasses.field(default_factory=list)
  init_checkpoint: Optional[str] = None
  # Which modules to load from `init_checkpoint`.
  init_checkpoint_modules: Union[
      str, List[str]] = 'all'  # all, backbone, and/or decoder
@exp_factory.register_config_factory('seg_unet3d_test')
def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Image segmentation on a dummy dataset with 3D UNet for testing purpose."""
  train_batch_size = 2
  eval_batch_size = 2
  steps_per_epoch = 10

  # Tiny model: depth-2 UNet over a 32^3 two-channel volume, two classes.
  model = SemanticSegmentationModel3D(
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      backbone=backbones.Backbone(
          type='unet_3d', unet_3d=backbones.UNet3D(model_id=2)),
      decoder=decoders.Decoder(
          type='unet_3d_decoder',
          unet_3d_decoder=decoders.UNet3DDecoder(model_id=2)),
      head=SegmentationHead3D(num_convs=0, num_classes=2),
      norm_activation=common.NormActivation(
          activation='relu', use_sync_bn=False))

  train_data = DataConfig(
      input_path='train.tfrecord',
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      is_training=True,
      global_batch_size=train_batch_size)
  validation_data = DataConfig(
      input_path='val.tfrecord',
      num_classes=2,
      input_size=[32, 32, 32],
      num_channels=2,
      is_training=False,
      global_batch_size=eval_batch_size)

  # Constant tiny learning rate with plain SGD: enough to exercise the
  # training loop without caring about convergence.
  trainer = cfg.TrainerConfig(
      steps_per_loop=steps_per_epoch,
      summary_interval=steps_per_epoch,
      checkpoint_interval=steps_per_epoch,
      train_steps=10,
      validation_steps=10,
      validation_interval=steps_per_epoch,
      optimizer_config=optimization.OptimizationConfig({
          'optimizer': {
              'type': 'sgd',
          },
          'learning_rate': {
              'type': 'constant',
              'constant': {
                  'learning_rate': 0.000001
              }
          }
      }))

  return cfg.ExperimentConfig(
      task=SemanticSegmentation3DTask(
          model=model,
          train_data=train_data,
          validation_data=validation_data,
          losses=Losses(loss_type='adaptive')),
      trainer=trainer,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for semantic_segmentation."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):
  """Checks the registered 3D segmentation experiment configs are valid."""

  @parameterized.parameters(
      ('seg_unet3d_test',),)
  def test_semantic_segmentation_configs(self, config_name):
    # Build the registered experiment config and verify the task/model types.
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.SemanticSegmentation3DTask)
    self.assertIsInstance(config.task.model,
                          exp_cfg.SemanticSegmentationModel3D)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    # Violating the `is_training != None` restriction must fail validation.
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data parser and processing for 3D segmentation datasets."""
from typing import Any, Dict, Sequence, Tuple
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
class Decoder(decoder.Decoder):
  """A tf.Example decoder for segmentation task."""

  def __init__(self,
               image_field_key: str = 'image/encoded',
               label_field_key: str = 'image/class/label'):
    # Both image and label are stored as raw byte strings; an absent
    # feature decodes to the empty string.
    raw_bytes_feature = tf.io.FixedLenFeature([], tf.string, default_value='')
    self._keys_to_features = {
        image_field_key: raw_bytes_feature,
        label_field_key: raw_bytes_feature,
    }

  def decode(self, serialized_example: tf.string) -> Dict[str, tf.Tensor]:
    """Parses one serialized tf.Example into a feature-name -> tensor dict."""
    return tf.io.parse_single_example(
        serialized_example, self._keys_to_features)
class Parser(parser.Parser):
  """Parser to parse an image and its annotations into a dictionary of tensors."""

  def __init__(self,
               input_size: Sequence[int],
               num_classes: int,
               num_channels: int = 3,
               image_field_key: str = 'image/encoded',
               label_field_key: str = 'image/class/label',
               dtype: str = 'float32',
               label_dtype: str = 'float32'):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      input_size: The input tensor size of [height, width, volume] of input
        image.
      num_classes: The number of classes to be segmented.
      num_channels: The channel of input images.
      image_field_key: A `str` of the key name to encoded image in TFExample.
      label_field_key: A `str` of the key name to label in TFExample.
      dtype: The data type. One of {`bfloat16`, `float32`, `float16`}.
      label_dtype: The data type of input label.
    """
    self._input_size = input_size
    self._num_classes = num_classes
    self._num_channels = num_channels
    self._image_field_key = image_field_key
    self._label_field_key = label_field_key
    self._dtype = dtype
    self._label_dtype = label_dtype

  def _prepare_image_and_label(
      self, data: Dict[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Prepares normalized image and label.

    Decodes the raw image/label bytes, restores their spatial shapes and
    casts both to the parser's dtype.
    """
    # Raw bytes -> flat tensors. The image is always decoded as float32;
    # the cast to self._dtype happens below.
    image = tf.io.decode_raw(data[self._image_field_key],
                             tf.as_dtype(tf.float32))
    label = tf.io.decode_raw(data[self._label_field_key],
                             tf.as_dtype(self._label_dtype))
    # Restore shapes: the image gains a channel axis, the label a class axis.
    image_size = list(self._input_size) + [self._num_channels]
    image = tf.reshape(image, image_size)
    label_size = list(self._input_size) + [self._num_classes]
    label = tf.reshape(label, label_size)
    image = tf.cast(image, dtype=self._dtype)
    # NOTE(review): the label is cast to self._dtype rather than
    # self._label_dtype; this also makes the int64 branch below unreachable
    # whenever self._dtype is a float type — confirm this is intended.
    label = tf.cast(label, dtype=self._dtype)
    # TPU doesn't support tf.int64 well, use tf.int32 directly.
    if label.dtype == tf.int64:
      label = tf.cast(label, dtype=tf.int32)
    return image, label

  def _parse_train_data(self, data: Dict[str,
                                         Any]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses a decoded tf.Example dict for training."""
    image, labels = self._prepare_image_and_label(data)
    # Cast image as self._dtype
    image = tf.cast(image, dtype=self._dtype)
    return image, labels

  def _parse_eval_data(self, data: Dict[str,
                                        Any]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses a decoded tf.Example dict for evaluation."""
    image, labels = self._prepare_image_and_label(data)
    # Cast image as self._dtype
    image = tf.cast(image, dtype=self._dtype)
    return image, labels
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for segmentation_input_3d.py."""
import os
from absl.testing import parameterized
import tensorflow as tf
from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import input_reader
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.volumetric_models.dataloaders import segmentation_input_3d
class InputReaderTest(parameterized.TestCase, tf.test.TestCase):
  """Tests reading and parsing 3D segmentation TFRecords end to end."""

  def setUp(self):
    super().setUp()
    # Write 20 synthetic 32x32x32x2 examples to a temporary TFRecord file.
    data_dir = os.path.join(self.get_temp_dir(), 'data')
    tf.io.gfile.makedirs(data_dir)
    self._data_path = os.path.join(data_dir, 'data.tfrecord')
    # pylint: disable=g-complex-comprehension
    examples = [
        tfexample_utils.create_3d_image_test_example(
            image_height=32, image_width=32, image_volume=32, image_channel=2)
        for _ in range(20)
    ]
    # pylint: enable=g-complex-comprehension
    tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)

  @parameterized.parameters(([32, 32, 32], 2, 2))
  def testSegmentationInputReader(self, input_size, num_classes, num_channels):
    params = cfg.DataConfig(
        input_path=self._data_path, global_batch_size=2, is_training=False)

    decoder = segmentation_input_3d.Decoder()
    parser = segmentation_input_3d.Parser(
        input_size=input_size,
        num_classes=num_classes,
        num_channels=num_channels)

    # Read one batch through the full decoder -> parser pipeline.
    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn('tfrecord'),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))
    dataset = reader.read()
    iterator = iter(dataset)
    image, labels = next(iterator)

    # Checks image shape.
    self.assertEqual(
        list(image.numpy().shape),
        [2, input_size[0], input_size[1], input_size[2], num_channels])
    # Checks label shape.
    self.assertEqual(
        list(labels.numpy().shape),
        [2, input_size[0], input_size[1], input_size[2], num_classes])


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metrics for segmentation."""
from typing import Optional
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.losses import segmentation_losses
class DiceScore:
  """Dice score metric for semantic segmentation.

  This class follows the same function interface as tf.keras.metrics.Metric
  but does not derive from tf.keras.metrics.Metric or utilize its functions.
  The reason is a tf.keras.metrics.Metric object does not run well on CPU
  while created on GPU, when running with MirroredStrategy. The same
  interface allows for minimal change to the upstream tasks.

  Attributes:
    name: The name of the metric.
    dtype: The dtype of the metric, for example, tf.float32.
  """

  def __init__(self,
               num_classes: int,
               metric_type: Optional[str] = None,
               per_class_metric: bool = False,
               name: Optional[str] = None,
               dtype: Optional[str] = None):
    """Constructs segmentation evaluator class.

    Args:
      num_classes: The number of classes.
      metric_type: An optional `str` of type of dice scores.
      per_class_metric: Whether to report per-class metric.
      name: A `str`, name of the metric instance.
      dtype: The data type of the metric result.
    """
    self._num_classes = num_classes
    self._per_class_metric = per_class_metric
    self._dice_op_overall = segmentation_losses.SegmentationLossDiceScore(
        metric_type=metric_type)
    # Running sum of overall dice scores and the number of update calls.
    self._dice_scores_overall = tf.Variable(0.0)
    self._count = tf.Variable(0.0)
    if self._per_class_metric:
      # Always use raw dice score for per-class metrics, so metric_type is None
      # by default.
      self._dice_op_per_class = segmentation_losses.SegmentationLossDiceScore()
      self._dice_scores_per_class = [
          tf.Variable(0.0) for _ in range(num_classes)
      ]
    self.name = name
    self.dtype = dtype

  def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor):
    """Updates metric state.

    Args:
      y_true: The true labels of size [batch, width, height, volume,
        num_classes].
      y_pred: The prediction of size [batch, width, height, volume,
        num_classes].

    Raises:
      ValueError: If number of classes from groundtruth label does not equal to
        `num_classes`.
    """
    if self._num_classes != y_true.get_shape()[-1]:
      raise ValueError(
          'The number of classes from groundtruth labels and `num_classes` '
          'should equal, but they are {0} and {1}.'.format(
              self._num_classes,
              y_true.get_shape()[-1]))

    self._count.assign_add(1.)
    # The loss op returns (1 - dice score), so convert back before summing.
    self._dice_scores_overall.assign_add(1 -
                                         self._dice_op_overall(y_pred, y_true))
    if self._per_class_metric:
      for class_id in range(self._num_classes):
        self._dice_scores_per_class[class_id].assign_add(
            1 -
            self._dice_op_per_class(y_pred[..., class_id], y_true[...,
                                                                  class_id]))

  def result(self) -> tf.Tensor:
    """Computes and returns the metric.

    The first one is `generalized` or `adaptive` overall dice score, depending
    on `metric_type`. If `per_class_metric` is True, `num_classes` elements are
    also appended to the overall metric, as the per-class raw dice scores.

    Returns:
      The resulting dice scores.
    """
    if self._per_class_metric:
      dice_scores = [
          tf.math.divide_no_nan(self._dice_scores_overall, self._count)
      ]
      for class_id in range(self._num_classes):
        dice_scores.append(
            tf.math.divide_no_nan(self._dice_scores_per_class[class_id],
                                  self._count))
      return tf.stack(dice_scores)
    else:
      return tf.math.divide_no_nan(self._dice_scores_overall, self._count)

  def reset_states(self):
    """Resets the metrics to the initial state."""
    # Reset the existing variables in place rather than rebinding new
    # tf.Variable objects: any tf.function or distribution strategy that has
    # captured the original variables would otherwise keep reading the stale,
    # never-reset copies.
    self._count.assign(0.0)
    self._dice_scores_overall.assign(0.0)
    if self._per_class_metric:
      for class_id in range(self._num_classes):
        self._dice_scores_per_class[class_id].assign(0.0)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for segmentation_losses.py."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics
class SegmentationMetricsTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters((1, 'generalized', 0.5, [0.74, 0.67]),
(1, 'adaptive', 0.5, [0.93, 0.67]),
(2, None, 0.5, [0.67, 0.67, 0.67]),
(3, 'generalized', 0.5, [0.7, 0.67, 0.67, 0.67]))
def test_forward_dice_score(self, num_classes, metric_type, output,
expected_score):
metric = segmentation_metrics.DiceScore(
num_classes=num_classes, metric_type=metric_type, per_class_metric=True)
y_pred = tf.constant(
output, shape=[2, 128, 128, 128, num_classes], dtype=tf.float32)
y_true = tf.ones(shape=[2, 128, 128, 128, num_classes], dtype=tf.float32)
metric.update_state(y_true=y_true, y_pred=y_pred)
actual_score = metric.result().numpy()
self.assertAllClose(
actual_score,
expected_score,
atol=1e-2,
msg='Output metric {} does not match expected metric {}.'.format(
actual_score, expected_score))
def test_num_classes_not_equal(self):
metric = segmentation_metrics.DiceScore(num_classes=4)
y_pred = tf.constant(0.5, shape=[2, 128, 128, 128, 2], dtype=tf.float32)
y_true = tf.ones(shape=[2, 128, 128, 128, 2], dtype=tf.float32)
with self.assertRaisesRegex(
ValueError,
'The number of classes from groundtruth labels and `num_classes` '
'should equal'):
metric.update_state(y_true=y_true, y_pred=y_pred)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses used for segmentation models."""
from typing import Optional, Sequence
import tensorflow as tf
class SegmentationLossDiceScore(object):
  """Semantic segmentation loss using generalized dice score.

  Dice score (DSC) is a similarity measure that equals twice the number of
  elements common to both sets divided by the sum of the number of elements
  in each set. It is commonly used to evaluate segmentation performance to
  measure the overlap of predicted and groundtruth regions.
  (https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient)

  Generalized dice score is the dice score weighted by the volume of
  groundtruth labels per class. Adaptive dice score adds weights to
  generalized dice score. It assigns larger weights to lower dice score, so
  that wrong predictions contribute more to the total loss. Model will then
  be trained to focus more on these hard examples.
  """

  def __init__(self,
               metric_type: Optional[str] = None,
               axis: Optional[Sequence[int]] = (1, 2, 3)):
    """Initializes dice score loss object.

    Args:
      metric_type: An optional `str` specifying the type of the dice score to
        compute. Compute generalized or adaptive dice score if metric type is
        `generalized` or `adaptive`; otherwise compute original dice score.
      axis: An optional sequence of `int` specifying the axis to perform reduce
        ops for raw dice score.
    """
    # Note: the original implementation also set an unused `_dice_score`
    # attribute here; it was never read and has been removed.
    self._metric_type = metric_type
    self._axis = axis

  def __call__(self, logits: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
    """Computes and returns a loss based on 1 - dice score.

    Args:
      logits: A Tensor of the prediction.
      labels: A Tensor of the groundtruth label.

    Returns:
      The loss value of (1 - dice score).

    Raises:
      ValueError: If `labels` or `logits` has rank below 2.
    """
    labels = tf.cast(labels, logits.dtype)
    if labels.get_shape().ndims < 2 or logits.get_shape().ndims < 2:
      raise ValueError('The labels and logits must be at least rank 2.')
    epsilon = tf.keras.backend.epsilon()

    # Reduce over every axis except the trailing one, giving per-class
    # intersections and unions (the last axis is treated as the class axis).
    axis = list(range(len(logits.shape) - 1))
    intersection = tf.reduce_sum(labels * logits, axis=axis)
    union = tf.reduce_sum(labels + logits, axis=axis)

    if self._metric_type == 'generalized':
      # Weight each class by the inverse squared volume of its groundtruth
      # labels; epsilon guards against division by zero for empty classes.
      w = tf.math.reciprocal(
          tf.square(tf.reduce_sum(labels, axis=axis)) + epsilon)
      # Calculate the weighted dice score and normalizer.
      dice = 2 * tf.reduce_sum(w * intersection) + epsilon
      normalizer = tf.reduce_sum(w * union) + epsilon
      dice = tf.cast(dice, dtype=tf.float32)
      normalizer = tf.cast(normalizer, dtype=tf.float32)
      return 1 - tf.reduce_mean(dice / normalizer)
    elif self._metric_type == 'adaptive':
      dice = 2.0 * (intersection + epsilon) / (union + epsilon)
      # Calculate weights based on Dice scores: exp(-dice) gives larger
      # weights to classes with lower scores.
      weights = tf.exp(-1.0 * dice)
      # Multiply weights by corresponding scores and get sum.
      weighted_dice = tf.reduce_sum(weights * dice)
      # Normalize by the class count times exp(-1), the weight of a perfect
      # dice score of 1.
      normalizer = tf.cast(tf.size(input=dice), dtype=tf.float32) * tf.exp(-1.0)
      weighted_dice = tf.cast(weighted_dice, dtype=tf.float32)
      return 1 - tf.reduce_mean(weighted_dice / normalizer)
    else:
      # Raw dice score; reductions use the configured `self._axis` here
      # rather than the all-but-last-axis reduction above.
      summation = tf.reduce_sum(
          labels, axis=self._axis) + tf.reduce_sum(
              logits, axis=self._axis)
      dice = (2 * tf.reduce_sum(labels * logits, axis=self._axis) + epsilon) / (
          summation + epsilon)
      return 1 - tf.reduce_mean(dice)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for segmentation_losses.py."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.losses import segmentation_losses
class SegmentationLossDiceScoreTest(parameterized.TestCase, tf.test.TestCase):
  """Tests the raw, generalized and adaptive dice score losses."""

  @parameterized.parameters((None, 0.5, 0.3), ('generalized', 0.5, 0.3),
                            ('adaptive', 0.5, 0.07))
  def test_supported_loss(self, metric_type, output, expected_score):
    loss = segmentation_losses.SegmentationLossDiceScore(
        metric_type=metric_type)
    # Constant 0.5 prediction against an all-ones single-class label volume.
    logits = tf.constant(output, shape=[1, 128, 128, 128, 1], dtype=tf.float32)
    labels = tf.ones(shape=[1, 128, 128, 128, 1], dtype=tf.float32)
    actual_score = loss(logits=logits, labels=labels)
    self.assertAlmostEqual(actual_score.numpy(), expected_score, places=1)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Backbones package definition."""
from official.vision.beta.projects.volumetric_models.modeling.backbones.unet_3d import UNet3D
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of 3D UNet Model encoder part.
[1] Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf
Ronneberger. 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
Annotation. arXiv:1606.06650.
"""
from typing import Any, Mapping, Sequence
# Import libraries
import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class UNet3D(tf.keras.Model):
  """Class to build 3D UNet backbone.

  This is the contracting (encoder) path of a 3D U-Net [1]: `model_id`
  levels, each applying a pair of 3D convolutions (BasicBlock3DVolume),
  with max pooling between levels to downsample. The built model outputs a
  dict of {level (str): feature map}, where '1' is the full-resolution
  level and the filter count doubles at each deeper level.
  """

  def __init__(
      self,
      model_id: int,
      input_specs: tf.keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, None, 3]),
      pool_size: Sequence[int] = (2, 2, 2),
      kernel_size: Sequence[int] = (3, 3, 3),
      base_filters: int = 32,
      kernel_regularizer: tf.keras.regularizers.Regularizer = None,
      activation: str = 'relu',
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      use_sync_bn: bool = False,
      use_batch_normalization: bool = False,
      **kwargs):
    """3D UNet backbone initialization function.

    Args:
      model_id: The depth of UNet3D backbone model. The greater the depth, the
        more max pooling layers will be added to the model. Lowering the depth
        may reduce the amount of memory required for training.
      input_specs: The specs of the input tensor. It specifies a 5D input of
        [batch, height, width, volume, channel] for `channel_last` data format
        or [batch, channel, height, width, volume] for `channel_first` data
        format.
      pool_size: The pooling size for the max pooling operations.
      kernel_size: The kernel size for 3D convolution.
      base_filters: The number of filters that the first layer in the
        convolution network will have. Following layers will contain a multiple
        of this number. Lowering this number will likely reduce the amount of
        memory required to train the model.
      kernel_regularizer: A tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      activation: The name of the activation function.
      norm_momentum: The normalization momentum for the moving average.
      norm_epsilon: A float added to variance to avoid dividing by zero.
      use_sync_bn: If True, use synchronized batch normalization.
      use_batch_normalization: If set to True, use batch normalization after
        convolution and before activation. Default to False.
      **kwargs: Keyword arguments to be passed.
    """
    self._model_id = model_id
    self._input_specs = input_specs
    self._pool_size = pool_size
    self._kernel_size = kernel_size
    self._activation = activation
    self._base_filters = base_filters
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._use_sync_bn = use_sync_bn
    # Synchronized BN aggregates batch statistics across replicas; plain BN
    # uses per-replica statistics.
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._kernel_regularizer = kernel_regularizer
    self._use_batch_normalization = use_batch_normalization

    # Build 3D UNet as a functional graph; `endpoints` collects per-level
    # features which become the model outputs.
    inputs = tf.keras.Input(
        shape=input_specs.shape[1:], dtype=input_specs.dtype)
    x = inputs
    endpoints = {}

    # Add levels with max pooling to downsample input.
    for layer_depth in range(model_id):
      # Two convolutions are applied sequentially without downsampling.
      # Filter count doubles with depth: base_filters * 2**layer_depth.
      filter_num = base_filters * (2**layer_depth)
      x2 = nn_blocks_3d.BasicBlock3DVolume(
          filters=[filter_num, filter_num * 2],
          strides=(1, 1, 1),
          kernel_size=self._kernel_size,
          kernel_regularizer=self._kernel_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon,
          use_batch_normalization=self._use_batch_normalization)(
              x)
      # Downsample between levels; the deepest level is not pooled.
      if layer_depth < model_id - 1:
        x = layers.MaxPool3D(
            pool_size=pool_size,
            strides=(2, 2, 2),
            padding='valid',
            data_format=tf.keras.backend.image_data_format())(
                x2)
      else:
        x = x2
      # Endpoint keys are 1-indexed level strings ('1' = full resolution).
      endpoints[str(layer_depth + 1)] = x2

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super(UNet3D, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def get_config(self) -> Mapping[str, Any]:
    """Returns the config dict used to re-create this backbone.

    Note that `input_specs` is not included in the returned config.
    """
    return {
        'model_id': self._model_id,
        'pool_size': self._pool_size,
        'kernel_size': self._kernel_size,
        'activation': self._activation,
        'base_filters': self._base_filters,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'use_sync_bn': self._use_sync_bn,
        'kernel_regularizer': self._kernel_regularizer,
        'use_batch_normalization': self._use_batch_normalization
    }

  @classmethod
  def from_config(cls, config: Mapping[str, Any], custom_objects=None):
    """Creates a UNet3D instance from a config produced by `get_config`."""
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """Returns a dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
@factory.register_backbone_builder('unet_3d')
def build_unet3d(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds a 3D UNet backbone from the given configs.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` describing the model input.
    backbone_config: A OneOfConfig whose type must be 'unet_3d'; its payload
      supplies model_id, pool_size, base_filters and batch-norm usage.
    norm_activation_config: Config with activation and normalization settings.
    l2_regularizer: Optional kernel regularizer passed through to the model.

  Returns:
    A `tf.keras.Model` instance of the UNet3D backbone.
  """
  backbone_cfg = backbone_config.get()
  # Guard against being dispatched with a mismatched config type.
  assert backbone_config.type == 'unet_3d', (
      f'Inconsistent backbone type {backbone_config.type}')
  return UNet3D(
      model_id=backbone_cfg.model_id,
      input_specs=input_specs,
      pool_size=backbone_cfg.pool_size,
      base_filters=backbone_cfg.base_filters,
      kernel_regularizer=l2_regularizer,
      activation=norm_activation_config.activation,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_batch_normalization=backbone_cfg.use_batch_normalization)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for 3D UNet backbone."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d
class UNet3DTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the UNet3D backbone."""

  @parameterized.parameters(
      ([128, 64], 4),
      ([256, 128], 6),
  )
  def test_network_creation(self, input_size, model_id):
    """Test creation of UNet3D family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    backbone = unet_3d.UNet3D(model_id=model_id)
    # `input_size` holds [spatial size, volume size]; 3-channel 5D input.
    image = tf.keras.Input(
        shape=(input_size[0], input_size[0], input_size[1], 3), batch_size=1)
    endpoints = backbone(image)
    # Each deeper level halves spatial/volume dims and doubles channels.
    for depth in range(model_id):
      scale = 2**depth
      self.assertAllEqual([
          1, input_size[0] / scale, input_size[0] / scale,
          input_size[1] / scale, 64 * scale
      ], endpoints[str(depth + 1)].shape.as_list())

  def test_serialize_deserialize(self):
    """Round-trips the backbone through its config and JSON."""
    # Create a network object that sets all of its config options.
    config = {
        'model_id': 4,
        'pool_size': (2, 2, 2),
        'kernel_size': (3, 3, 3),
        'activation': 'relu',
        'base_filters': 32,
        'kernel_regularizer': None,
        'norm_momentum': 0.99,
        'norm_epsilon': 0.001,
        'use_sync_bn': False,
        'use_batch_normalization': True,
    }
    network = unet_3d.UNet3D(**config)
    self.assertEqual(network.get_config(), dict(config))
    # Create another network object from the first object's config.
    new_network = unet_3d.UNet3D.from_config(network.get_config())
    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()
    # If serialization succeeded, the new config should match the old one.
    self.assertAllEqual(network.get_config(), new_network.get_config())
if __name__ == '__main__':
  # Runs all test cases defined in this module.
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Decoders package definition."""
from official.vision.beta.projects.volumetric_models.modeling.decoders.unet_3d_decoder import UNet3DDecoder
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""factory method."""
from typing import Mapping
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling import decoders
def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds decoder from a config.

  Args:
    input_specs: `dict` input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A OneOfConfig. Model config.
    l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to None.

  Returns:
    A tf.keras.Model instance of the decoder, or None when the decoder type
    is 'identity'.

  Raises:
    ValueError: If the configured decoder type is not supported.
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  norm_activation_config = model_config.norm_activation
  if decoder_type == 'identity':
    # An 'identity' decoder means the backbone output is used directly.
    decoder = None
  elif decoder_type == 'unet_3d_decoder':
    decoder = decoders.UNet3DDecoder(
        model_id=decoder_cfg.model_id,
        input_specs=input_specs,
        pool_size=decoder_cfg.pool_size,
        kernel_regularizer=l2_regularizer,
        activation=norm_activation_config.activation,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        use_sync_bn=norm_activation_config.use_sync_bn,
        use_batch_normalization=decoder_cfg.use_batch_normalization,
        use_deconvolution=decoder_cfg.use_deconvolution)
  else:
    # Fixed grammar of the error message ("not implement" -> "not implemented").
    raise ValueError('Decoder {!r} not implemented'.format(decoder_type))
  return decoder
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of 3D UNet Model decoder part.
[1] Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf
Ronneberger. 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
Annotation. arXiv:1606.06650.
"""
from typing import Any, Sequence, Dict, Mapping
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='Vision')
class UNet3DDecoder(tf.keras.Model):
  """Class to build 3D UNet decoder.

  This is the expanding (decoder) path of a 3D U-Net [1]: starting from the
  deepest backbone level, each step upsamples (or transpose-convolves) the
  features, concatenates them with the matching backbone level's features
  (skip connection), and applies a pair of 3D convolutions. The built model
  outputs {'1': features} at full resolution.
  """

  def __init__(self,
               model_id: int,
               input_specs: Mapping[str, tf.TensorShape],
               pool_size: Sequence[int] = (2, 2, 2),
               kernel_size: Sequence[int] = (3, 3, 3),
               kernel_regularizer: tf.keras.regularizers.Regularizer = None,
               activation: str = 'relu',
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               use_sync_bn: bool = False,
               use_batch_normalization: bool = False,
               use_deconvolution: bool = False,
               **kwargs):
    """3D UNet decoder initialization function.

    Args:
      model_id: The depth of UNet3D backbone model. The greater the depth, the
        more max pooling layers will be added to the model. Lowering the depth
        may reduce the amount of memory required for training.
      input_specs: The input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      pool_size: The pooling size for the max pooling operations.
      kernel_size: The kernel size for 3D convolution.
      kernel_regularizer: A tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      activation: The name of the activation function.
      norm_momentum: The normalization momentum for the moving average.
      norm_epsilon: A float added to variance to avoid dividing by zero.
      use_sync_bn: If True, use synchronized batch normalization.
      use_batch_normalization: If set to True, use batch normalization after
        convolution and before activation. Default to False.
      use_deconvolution: If set to True, the model will use transpose
        convolution (deconvolution) instead of up-sampling. This increases the
        amount memory required during training. Default to False.
      **kwargs: Keyword arguments to be passed.
    """
    # Full constructor-argument dict; returned verbatim by `get_config`.
    self._config_dict = {
        'model_id': model_id,
        'input_specs': input_specs,
        'pool_size': pool_size,
        'kernel_size': kernel_size,
        'kernel_regularizer': kernel_regularizer,
        'activation': activation,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'use_sync_bn': use_sync_bn,
        'use_batch_normalization': use_batch_normalization,
        'use_deconvolution': use_deconvolution
    }
    # Synchronized BN aggregates batch statistics across replicas.
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._use_batch_normalization = use_batch_normalization

    # Channel axis depends on the global Keras image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      channel_dim = -1
    else:
      channel_dim = 1

    # Build 3D UNet decoder as a functional graph over one Input per level.
    inputs = self._build_input_pyramid(input_specs, model_id)

    # Add levels with up-convolution or up-sampling, walking from the
    # deepest level (model_id) back up to level 1.
    x = inputs[str(model_id)]
    for layer_depth in range(model_id - 1, 0, -1):
      # Apply deconvolution or upsampling.
      if use_deconvolution:
        x = layers.Conv3DTranspose(
            filters=x.get_shape().as_list()[channel_dim],
            kernel_size=pool_size,
            strides=(2, 2, 2))(
                x)
      else:
        x = layers.UpSampling3D(size=pool_size)(x)
      # Concatenate upsampled features with input features from one layer up
      # (skip connection); cast keeps dtypes consistent under mixed precision.
      x = tf.concat([x, tf.cast(inputs[str(layer_depth)], dtype=x.dtype)],
                    axis=channel_dim)
      # Match the channel count of the corresponding backbone level.
      filter_num = inputs[str(layer_depth)].get_shape().as_list()[channel_dim]
      x = nn_blocks_3d.BasicBlock3DVolume(
          filters=[filter_num, filter_num],
          strides=(1, 1, 1),
          kernel_size=kernel_size,
          kernel_regularizer=kernel_regularizer,
          activation=activation,
          use_sync_bn=use_sync_bn,
          norm_momentum=norm_momentum,
          norm_epsilon=norm_epsilon,
          use_batch_normalization=use_batch_normalization)(
              x)

    # Only the full-resolution output (level '1') is exposed.
    feats = {'1': x}
    self._output_specs = {l: feats[l].get_shape() for l in feats}

    super(UNet3DDecoder, self).__init__(inputs=inputs, outputs=feats, **kwargs)

  def _build_input_pyramid(self, input_specs: Dict[str, tf.TensorShape],
                           depth: int) -> Dict[str, tf.Tensor]:
    """Builds input pyramid features.

    Creates one `tf.keras.Input` per backbone level, dropping the batch
    dimension from each spec.

    NOTE(review): despite the message, this only rejects the case where the
    backbone has MORE levels than `depth`; fewer levels would surface later
    as a missing-key error — confirm this asymmetry is intended.
    """
    assert isinstance(input_specs, dict)
    if len(input_specs.keys()) > depth:
      raise ValueError(
          'Backbone depth should be equal to 3D UNet decoder\'s depth.')

    inputs = {}
    for level, spec in input_specs.items():
      inputs[level] = tf.keras.Input(shape=spec[1:])
    return inputs

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor-argument dict used to re-create this decoder."""
    return self._config_dict

  @classmethod
  def from_config(cls, config: Mapping[str, Any], custom_objects=None):
    """Creates a UNet3DDecoder instance from a config from `get_config`."""
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for 3D UNet decoder."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d
from official.vision.beta.projects.volumetric_models.modeling.decoders import unet_3d_decoder
class UNet3DDecoderTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the UNet3D decoder."""

  @parameterized.parameters(
      ([128, 64], 4),
      ([256, 128], 6),
  )
  def test_network_creation(self, input_size, model_id):
    """Test creation of UNet3D family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    # `input_size` consists of [spatial size, volume size].
    image = tf.keras.Input(
        shape=(input_size[0], input_size[0], input_size[1], 3), batch_size=1)
    encoder = unet_3d.UNet3D(model_id=model_id)
    decoder = unet_3d_decoder.UNet3DDecoder(
        model_id=model_id, input_specs=encoder.output_specs)
    feats = decoder(encoder(image))
    # The decoder must restore the full input resolution at level '1'.
    self.assertIn('1', feats)
    self.assertAllEqual([1, input_size[0], input_size[0], input_size[1], 64],
                        feats['1'].shape.as_list())

  def test_serialize_deserialize(self):
    """Round-trips the decoder through its config and JSON."""
    # Create a network object that sets all of its config options.
    config = {
        'model_id': 4,
        'input_specs': unet_3d.UNet3D(model_id=4).output_specs,
        'pool_size': (2, 2, 2),
        'kernel_size': (3, 3, 3),
        'kernel_regularizer': None,
        'activation': 'relu',
        'norm_momentum': 0.99,
        'norm_epsilon': 0.001,
        'use_sync_bn': False,
        'use_batch_normalization': True,
        'use_deconvolution': True,
    }
    network = unet_3d_decoder.UNet3DDecoder(**config)
    self.assertEqual(network.get_config(), dict(config))
    # Create another network object from the first object's config.
    new_network = unet_3d_decoder.UNet3DDecoder.from_config(
        network.get_config())
    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()
    # If serialization succeeded, the new config should match the old one.
    self.assertAllEqual(network.get_config(), new_network.get_config())
if __name__ == '__main__':
  # Runs all test cases defined in this module.
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory methods to build models."""
# Import libraries
import tensorflow as tf
from official.modeling import hyperparams
from official.vision.beta.modeling import segmentation_model
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.projects.volumetric_models.modeling.decoders import factory as decoder_factory
from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d
def build_segmentation_model_3d(
    input_specs: tf.keras.layers.InputSpec,
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None) -> tf.keras.Model:
  """Builds Segmentation model.

  Assembles a backbone, a decoder and a 3D segmentation head as configured
  in `model_config` into a single `SegmentationModel`.

  Args:
    input_specs: A `tf.keras.layers.InputSpec` describing the model input.
    model_config: The segmentation model config, holding backbone, decoder,
      head and normalization/activation sub-configs.
    l2_regularizer: Optional kernel regularizer shared by all components.

  Returns:
    A `tf.keras.Model` instance of the segmentation model.
  """
  norm_activation = model_config.norm_activation
  backbone = backbone_factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_activation,
      l2_regularizer=l2_regularizer)
  # The decoder consumes the backbone's {level: TensorShape} output specs.
  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)
  head_cfg = model_config.head
  head = segmentation_heads_3d.SegmentationHead3D(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      upsample_factor=head_cfg.upsample_factor,
      activation=norm_activation.activation,
      use_sync_bn=norm_activation.use_sync_bn,
      norm_momentum=norm_activation.norm_momentum,
      norm_epsilon=norm_activation.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      output_logits=head_cfg.output_logits)
  return segmentation_model.SegmentationModel(backbone, decoder, head)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment