Commit aca51294 authored by A. Unique TensorFlower, committed by saberkun

Internal change

PiperOrigin-RevId: 318129893
parent e0dade52
@@ -19,9 +19,10 @@ In the near future, we will add:
* State-of-the-art language understanding models:
  More members in Transformer family
* State-of-the-art image classification models:
  EfficientNet, MnasNet, and variants
* State-of-the-art object detection and instance segmentation models:
  RetinaNet, Mask R-CNN, SpineNet, and variants
## Table of Contents
@@ -52,6 +53,7 @@ In the near future, we will add:
| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
### Natural Language Processing
......
@@ -48,6 +48,22 @@ so the checkpoints are not compatible.
We will unify the implementation soon.
### Train a SpineNet-49 based RetinaNet.
```bash
TPU_NAME="<your GCP TPU name>"
MODEL_DIR="<path to the directory to store model files>"
TRAIN_FILE_PATTERN="<path to the TFRecord training data>"
EVAL_FILE_PATTERN="<path to the TFRecord validation data>"
VAL_JSON_FILE="<path to the validation annotation JSON file>"
python3 ~/models/official/vision/detection/main.py \
--strategy_type=tpu \
--tpu="${TPU_NAME?}" \
--model_dir="${MODEL_DIR?}" \
--mode=train \
--params_override="{ type: retinanet, architecture: {backbone: spinenet, multilevel_features: identity}, spinenet: {model_id: 49}, train: { train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
```
### Train a custom RetinaNet using the config file.
@@ -163,6 +179,24 @@ so the checkpoints are not compatible.
We will unify the implementation soon.
### Train a SpineNet-49 based Mask R-CNN.
```bash
TPU_NAME="<your GCP TPU name>"
MODEL_DIR="<path to the directory to store model files>"
TRAIN_FILE_PATTERN="<path to the TFRecord training data>"
EVAL_FILE_PATTERN="<path to the TFRecord validation data>"
VAL_JSON_FILE="<path to the validation annotation JSON file>"
python3 ~/models/official/vision/detection/main.py \
--strategy_type=tpu \
--tpu="${TPU_NAME?}" \
--model_dir="${MODEL_DIR?}" \
--mode=train \
--model=mask_rcnn \
--params_override="{architecture: {backbone: spinenet, multilevel_features: identity}, spinenet: {model_id: 49}, train: { train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
```
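Both commands above pass the training settings as a YAML fragment through `--params_override`, which is merged onto the default config. Below is a minimal sketch of that merge using the `official.modeling.hyperparams.params_dict` helpers this repository uses for config handling; the default values in the dict are illustrative only, not the real detection defaults.

```python
from official.modeling.hyperparams import params_dict

# Illustrative defaults only; the real detection defaults live in
# official/vision/detection/configs/.
base = params_dict.ParamsDict({
    'architecture': {'backbone': 'resnet', 'multilevel_features': 'fpn'},
    'spinenet': {'model_id': '49'},
})

# A YAML-formatted override string, like the one passed via --params_override.
override = '{architecture: {backbone: spinenet, multilevel_features: identity}}'
params = params_dict.override_params_dict(base, override, is_strict=False)

print(params.architecture.backbone)             # spinenet
print(params.architecture.multilevel_features)  # identity
```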
### Train a custom Mask R-CNN using the config file.
First, create a YAML config file, e.g. *my_maskrcnn.yaml*.
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,10 +17,12 @@
BACKBONES = [
    'resnet',
    'spinenet',
]
MULTILEVEL_FEATURES = [
    'fpn',
    'identity',
]
# pylint: disable=line-too-long
@@ -118,6 +120,9 @@ BASE_CFG = {
    'resnet': {
        'resnet_depth': 50,
    },
    'spinenet': {
        'model_id': '49',
    },
    'fpn': {
        'fpn_feat_dims': 256,
        'use_separable_conv': False,
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ from official.vision.detection.modeling.architecture import heads
from official.vision.detection.modeling.architecture import identity
from official.vision.detection.modeling.architecture import nn_ops
from official.vision.detection.modeling.architecture import resnet
from official.vision.detection.modeling.architecture import spinenet
def norm_activation_generator(params):
@@ -42,6 +43,9 @@ def backbone_generator(params):
        activation=params.norm_activation.activation,
        norm_activation=norm_activation_generator(
            params.norm_activation))
elif params.architecture.backbone == 'spinenet':
spinenet_params = params.spinenet
backbone_fn = spinenet.SpineNetBuilder(model_id=spinenet_params.model_id)
  else:
    raise ValueError('Backbone model `{}` is not supported.'
                     .format(params.architecture.backbone))
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for neural networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Vision')
class ResidualBlock(tf.keras.layers.Layer):
"""A residual block."""
def __init__(self,
filters,
strides,
use_projection=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A residual block with BN after convolutions.
Args:
      filters: `int` number of filters for the two convolutions.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(ResidualBlock, self).__init__(**kwargs)
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._use_sync_bn = use_sync_bn
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=1,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
super(ResidualBlock, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ResidualBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
return self._activation_fn(x + shortcut)
@tf.keras.utils.register_keras_serializable(package='Vision')
class BottleneckBlock(tf.keras.layers.Layer):
"""A standard bottleneck block."""
def __init__(self,
filters,
strides,
use_projection=False,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""A standard bottleneck block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(BottleneckBlock, self).__init__(**kwargs)
self._filters = filters
self._strides = strides
self._use_projection = use_projection
self._use_sync_bn = use_sync_bn
self._activation = activation
self._kernel_initializer = kernel_initializer
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
self._activation_fn = tf_utils.get_activation(activation)
def build(self, input_shape):
if self._use_projection:
self._shortcut = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=self._strides,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv1 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm1 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv2 = tf.keras.layers.Conv2D(
filters=self._filters,
kernel_size=3,
strides=self._strides,
padding='same',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm2 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
self._conv3 = tf.keras.layers.Conv2D(
filters=self._filters * 4,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)
self._norm3 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)
super(BottleneckBlock, self).build(input_shape)
def get_config(self):
config = {
'filters': self._filters,
'strides': self._strides,
'use_projection': self._use_projection,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(BottleneckBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
shortcut = inputs
if self._use_projection:
shortcut = self._shortcut(shortcut)
shortcut = self._norm0(shortcut)
x = self._conv1(inputs)
x = self._norm1(x)
x = self._activation_fn(x)
x = self._conv2(x)
x = self._norm2(x)
x = self._activation_fn(x)
x = self._conv3(x)
x = self._norm3(x)
return self._activation_fn(x + shortcut)
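The two blocks above are plain Keras layers, so they can be smoke-tested in isolation. A minimal sketch of the expected shapes, assuming the file lands at `official/vision/detection/modeling/architecture/nn_blocks.py` as the imports elsewhere in this change suggest: a `ResidualBlock` keeps the channel count at `filters`, while a `BottleneckBlock` expands it to `4 * filters`.

```python
import tensorflow as tf

from official.vision.detection.modeling.architecture import nn_blocks

# A 4D feature map: [batch, height, width, channels].
x = tf.random.normal([2, 32, 32, 64])

# ResidualBlock keeps the channel count equal to `filters`.
res = nn_blocks.ResidualBlock(filters=64, strides=1, use_projection=False)
print(res(x).shape)  # (2, 32, 32, 64)

# BottleneckBlock expands channels to 4 * filters, so a projection shortcut
# is needed whenever the input does not already have 4 * filters channels.
bottleneck = nn_blocks.BottleneckBlock(filters=64, strides=2, use_projection=True)
print(bottleneck(x).shape)  # (2, 16, 16, 256)
```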
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of SpineNet model.
X. Du, T-Y. Lin, P. Jin, G. Ghiasi, M. Tan, Y. Cui, Q. V. Le, X. Song
SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization
https://arxiv.org/abs/1912.05027
"""
import math
from absl import logging
import tensorflow as tf
from tensorflow.python.keras import backend
from official.modeling import tf_utils
from official.vision.detection.modeling.architecture import nn_blocks
layers = tf.keras.layers
FILTER_SIZE_MAP = {
1: 32,
2: 64,
3: 128,
4: 256,
5: 256,
6: 256,
7: 256,
}
# The fixed SpineNet architecture discovered by NAS.
# Each element represents a specification of a building block:
# (block_level, block_fn, (input_offset0, input_offset1), is_output).
SPINENET_BLOCK_SPECS = [
(2, 'bottleneck', (0, 1), False),
(4, 'residual', (0, 1), False),
(3, 'bottleneck', (2, 3), False),
(4, 'bottleneck', (2, 4), False),
(6, 'residual', (3, 5), False),
(4, 'bottleneck', (3, 5), False),
(5, 'residual', (6, 7), False),
(7, 'residual', (6, 8), False),
(5, 'bottleneck', (8, 9), False),
(5, 'bottleneck', (8, 10), False),
(4, 'bottleneck', (5, 10), True),
(3, 'bottleneck', (4, 10), True),
(5, 'bottleneck', (7, 12), True),
(7, 'bottleneck', (5, 14), True),
(6, 'bottleneck', (12, 14), True),
]
SCALING_MAP = {
'49S': {
'endpoints_num_filters': 128,
'filter_size_scale': 0.65,
'resample_alpha': 0.5,
'block_repeats': 1,
},
'49': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 0.5,
'block_repeats': 1,
},
'96': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 0.5,
'block_repeats': 2,
},
'143': {
'endpoints_num_filters': 256,
'filter_size_scale': 1.0,
'resample_alpha': 1.0,
'block_repeats': 3,
},
'190': {
'endpoints_num_filters': 512,
'filter_size_scale': 1.3,
'resample_alpha': 1.0,
'block_repeats': 4,
},
}
class BlockSpec(object):
"""A container class that specifies the block configuration for SpineNet."""
def __init__(self, level, block_fn, input_offsets, is_output):
self.level = level
self.block_fn = block_fn
self.input_offsets = input_offsets
self.is_output = is_output
def build_block_specs(block_specs=None):
"""Builds the list of BlockSpec objects for SpineNet."""
if not block_specs:
block_specs = SPINENET_BLOCK_SPECS
logging.info('Building SpineNet block specs: %s', block_specs)
return [BlockSpec(*b) for b in block_specs]
@tf.keras.utils.register_keras_serializable(package='Vision')
class SpineNet(tf.keras.Model):
"""Class to build SpineNet models."""
def __init__(self,
input_specs=tf.keras.layers.InputSpec(shape=[None, 640, 640, 3]),
min_level=3,
max_level=7,
block_specs=build_block_specs(),
endpoints_num_filters=256,
resample_alpha=0.5,
block_repeats=1,
filter_size_scale=1.0,
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
**kwargs):
"""SpineNet model."""
self._min_level = min_level
self._max_level = max_level
self._block_specs = block_specs
self._endpoints_num_filters = endpoints_num_filters
self._resample_alpha = resample_alpha
self._block_repeats = block_repeats
self._filter_size_scale = filter_size_scale
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if activation == 'relu':
self._activation = tf.nn.relu
elif activation == 'swish':
self._activation = tf.nn.swish
else:
raise ValueError('Activation {} not implemented.'.format(activation))
self._init_block_fn = 'bottleneck'
self._num_init_blocks = 2
if use_sync_bn:
self._norm = layers.experimental.SyncBatchNormalization
else:
self._norm = layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
# Build SpineNet.
inputs = tf.keras.Input(shape=input_specs.shape[1:])
net = self._build_stem(inputs=inputs)
net = self._build_scale_permuted_network(
net=net, input_width=input_specs.shape[1])
net = self._build_endpoints(net=net)
super(SpineNet, self).__init__(inputs=inputs, outputs=net)
def _block_group(self,
inputs,
filters,
strides,
block_fn_cand,
block_repeats=1,
name='block_group'):
"""Creates one group of blocks for the SpineNet model."""
block_fn_candidates = {
'bottleneck': nn_blocks.BottleneckBlock,
'residual': nn_blocks.ResidualBlock,
}
block_fn = block_fn_candidates[block_fn_cand]
_, _, _, num_filters = inputs.get_shape().as_list()
if block_fn_cand == 'bottleneck':
use_projection = not (num_filters == (filters * 4) and strides == 1)
else:
use_projection = not (num_filters == filters and strides == 1)
x = block_fn(
filters=filters,
strides=strides,
use_projection=use_projection,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
inputs)
for _ in range(1, block_repeats):
x = block_fn(
filters=filters,
strides=1,
use_projection=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
x)
return tf.identity(x, name=name)
def _build_stem(self, inputs):
"""Build SpineNet stem."""
x = layers.Conv2D(
filters=64,
kernel_size=7,
strides=2,
use_bias=False,
padding='same',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
net = []
# Build the initial level 2 blocks.
for i in range(self._num_init_blocks):
x = self._block_group(
inputs=x,
filters=int(FILTER_SIZE_MAP[2] * self._filter_size_scale),
strides=1,
block_fn_cand=self._init_block_fn,
block_repeats=self._block_repeats,
name='stem_block_{}'.format(i + 1))
net.append(x)
return net
def _build_scale_permuted_network(self,
net,
input_width,
weighted_fusion=False):
"""Build scale-permuted network."""
net_sizes = [int(math.ceil(input_width / 2**2))] * len(net)
net_block_fns = [self._init_block_fn] * len(net)
num_outgoing_connections = [0] * len(net)
endpoints = {}
for i, block_spec in enumerate(self._block_specs):
# Find out specs for the target block.
target_width = int(math.ceil(input_width / 2**block_spec.level))
target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
self._filter_size_scale)
target_block_fn = block_spec.block_fn
# Resample then merge input0 and input1.
parents = []
input0 = block_spec.input_offsets[0]
input1 = block_spec.input_offsets[1]
x0 = self._resample_with_alpha(
inputs=net[input0],
input_width=net_sizes[input0],
input_block_fn=net_block_fns[input0],
target_width=target_width,
target_num_filters=target_num_filters,
target_block_fn=target_block_fn,
alpha=self._resample_alpha)
parents.append(x0)
num_outgoing_connections[input0] += 1
x1 = self._resample_with_alpha(
inputs=net[input1],
input_width=net_sizes[input1],
input_block_fn=net_block_fns[input1],
target_width=target_width,
target_num_filters=target_num_filters,
target_block_fn=target_block_fn,
alpha=self._resample_alpha)
parents.append(x1)
num_outgoing_connections[input1] += 1
# Merge 0 outdegree blocks to the output block.
if block_spec.is_output:
for j, (j_feat,
j_connections) in enumerate(zip(net, num_outgoing_connections)):
if j_connections == 0 and (j_feat.shape[2] == target_width and
j_feat.shape[3] == x0.shape[3]):
parents.append(j_feat)
num_outgoing_connections[j] += 1
# pylint: disable=g-direct-tensorflow-import
if weighted_fusion:
dtype = parents[0].dtype
parent_weights = [
tf.nn.relu(tf.cast(tf.Variable(1.0, name='block{}_fusion{}'.format(
i, j)), dtype=dtype)) for j in range(len(parents))]
weights_sum = tf.add_n(parent_weights)
parents = [
parents[i] * parent_weights[i] / (weights_sum + 0.0001)
for i in range(len(parents))
]
# Fuse all parent nodes then build a new block.
x = tf_utils.get_activation(self._activation)(tf.add_n(parents))
x = self._block_group(
inputs=x,
filters=target_num_filters,
strides=1,
block_fn_cand=target_block_fn,
block_repeats=self._block_repeats,
name='scale_permuted_block_{}'.format(i + 1))
net.append(x)
net_sizes.append(target_width)
net_block_fns.append(target_block_fn)
num_outgoing_connections.append(0)
# Save output feats.
if block_spec.is_output:
if block_spec.level in endpoints:
raise ValueError('Duplicate feats found for output level {}.'.format(
block_spec.level))
if (block_spec.level < self._min_level or
block_spec.level > self._max_level):
raise ValueError('Output level is out of range [{}, {}]'.format(
self._min_level, self._max_level))
endpoints[block_spec.level] = x
return endpoints
def _build_endpoints(self, net):
"""Match filter size for endpoints before sharing conv layers."""
endpoints = {}
for level in range(self._min_level, self._max_level + 1):
x = layers.Conv2D(
filters=self._endpoints_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
net[level])
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
endpoints[level] = x
return endpoints
def _resample_with_alpha(self,
inputs,
input_width,
input_block_fn,
target_width,
target_num_filters,
target_block_fn,
alpha=0.5):
"""Match resolution and feature dimension."""
_, _, _, input_num_filters = inputs.get_shape().as_list()
if input_block_fn == 'bottleneck':
input_num_filters /= 4
new_num_filters = int(input_num_filters * alpha)
x = layers.Conv2D(
filters=new_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
inputs)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
# Spatial resampling.
if input_width > target_width:
x = layers.Conv2D(
filters=new_num_filters,
kernel_size=3,
strides=2,
padding='SAME',
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation)(x)
input_width /= 2
while input_width > target_width:
x = layers.MaxPool2D(pool_size=3, strides=2, padding='SAME')(x)
input_width /= 2
elif input_width < target_width:
scale = target_width // input_width
x = layers.UpSampling2D(size=(scale, scale))(x)
# Last 1x1 conv to match filter size.
if target_block_fn == 'bottleneck':
target_num_filters *= 4
x = layers.Conv2D(
filters=target_num_filters,
kernel_size=1,
strides=1,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer)(
x)
x = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
return x
class SpineNetBuilder(object):
"""SpineNet builder."""
def __init__(self,
model_id,
min_level=3,
max_level=7,
block_specs=build_block_specs(),
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001):
if model_id not in SCALING_MAP:
raise ValueError(
'SpineNet {} is not a valid architecture.'.format(model_id))
scaling_params = SCALING_MAP[model_id]
self._min_level = min_level
self._max_level = max_level
self._block_specs = block_specs
self._endpoints_num_filters = scaling_params['endpoints_num_filters']
self._resample_alpha = scaling_params['resample_alpha']
self._block_repeats = scaling_params['block_repeats']
self._filter_size_scale = scaling_params['filter_size_scale']
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._activation = activation
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
def __call__(self, inputs, is_training=None):
with backend.get_graph().as_default():
model = SpineNet(
min_level=self._min_level,
max_level=self._max_level,
block_specs=self._block_specs,
endpoints_num_filters=self._endpoints_num_filters,
resample_alpha=self._resample_alpha,
block_repeats=self._block_repeats,
filter_size_scale=self._filter_size_scale,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=self._activation,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)
return model(inputs)
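A hedged usage sketch of the builder above, mirroring how `factory.backbone_generator` wires it in: called on symbolic inputs, it returns a dict of endpoints keyed by level, one feature map per level in `[min_level, max_level]` (3 to 7 by default), each with `endpoints_num_filters` channels (256 for model_id '49'). The exact shapes below assume the default 640x640 input and TF 2.2-era Keras behavior for functional models with dict outputs.

```python
import tensorflow as tf

from official.vision.detection.modeling.architecture import spinenet

# Build a SpineNet-49 backbone, as factory.backbone_generator does when
# architecture.backbone is set to 'spinenet'.
backbone_fn = spinenet.SpineNetBuilder(model_id='49')

# Call on symbolic inputs, as the detection model functions do.
images = tf.keras.Input(shape=[640, 640, 3])
endpoints = backbone_fn(images, is_training=True)

for level in sorted(endpoints):
    # 3 -> (None, 80, 80, 256), ..., 7 -> (None, 5, 5, 256)
    print(level, endpoints[level].shape)
```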
# UNet 3D Model
This repository contains a TensorFlow 2.x implementation of the 3D UNet model
[[1]](#1), as well as instructions for producing the data used for training and
evaluation.
The implementation also supports spatial partitioning [[2]](#2) on TPUs, which
makes it possible to train on high-resolution images.
## Contents
* [Contents](#contents)
* [Prerequisites](#prerequisites)
* [Setup](#setup)
* [Data Preparation](#data-preparation)
* [Training](#training)
* [Train with Spatial Partition](#train-with-spatial-partition)
* [Evaluation](#evaluation)
* [References](#references)
## Prerequisites
To train on high-resolution image data, spatial partitioning should be used to
avoid out-of-memory issues. This is currently only supported on TPUs. To use
TPUs for training, run the following command in the Google Cloud console to
create a Cloud TPU VM.
```shell
ctpu up -name=[tpu_name] -tf-version=nightly -tpu-size=v3-8 -zone=us-central1-b
```
## Setup
Before running any binary, install the necessary packages on the cloud VM.
```shell
pip install -r requirements.txt
```
## Data Preparation
This software uses TFRecords as input. We provide example scripts to convert
Numpy (.npy) files or NIfTI-1 (.nii) files to TFRecords, using the Liver Tumor
Segmentation (LiTS) dataset (Christ et al.
https://competitions.codalab.org/competitions/17094). You can download the
dataset by registering on the competition website.
**Example**:
```shell
cd data_preprocess
# Change input_path and output_path in convert_lits_nii_to_npy.py
# Then run the script to convert nii to npy.
python convert_lits_nii_to_npy.py
# Convert npy files to TFRecords.
python convert_lits.py \
--image_file_pattern=Downloads/.../volume-{}.npy \
--label_file_pattern=Downloads/.../segmentation-{}.npy \
--output_path=Downloads/...
```
## Training
Working configurations on TPU v3-8:
+ TF 2.2, train_batch_size=16, use_batch_norm=true, dtype='bfloat16' or
'float16', spatial partition not used.
+ tf-nightly, train_batch_size=32, use_batch_norm=true, dtype='bfloat16',
spatial partition used.
The following example shows how to train a volumetric UNet on TPU v3-8. The loss
is *adaptive_dice32* and the training batch size is 32. For the detailed
configuration, refer to `unet_config.py` and the example config file shown below.
**Example**:
```shell
DATA_BUCKET=<GS bucket for data>
TRAIN_FILES="${DATA_BUCKET}/tfrecords/trainbox*.tfrecord"
VAL_FILES="${DATA_BUCKET}/tfrecords/validationbox*.tfrecord"
MODEL_BUCKET=<GS bucket for model checkpoints>
EXP_NAME=unet_20190610_dice_t1
python unet_main.py \
--distribution_strategy=<"mirrored" or "tpu"> \
--num_gpus=<'number of GPUs to use if using mirrored strategy'> \
--tpu=<TPU name> \
--model_dir="gs://${MODEL_BUCKET}/models/${EXP_NAME}" \
--training_file_pattern="${TRAIN_FILES}" \
--eval_file_pattern="${VAL_FILES}" \
--steps_per_loop=10 \
--mode=train \
--config_file="./configs/cloud/v3-8_128x128x128_ce.yaml"
```
The script in the [Evaluation](#evaluation) section below runs evaluation on TPU
v3-8. Configurations such as `train_batch_size`, `train_steps`, `eval_batch_size`
and `eval_item_count` are defined in the configuration file passed via the
`config_file` flag. The evaluation command differs from the training script above
by a single line: the `mode` flag is set to "eval".
### Train with Spatial Partition
The following example enables spatial partitioning through `input_partition_dims`
in the config file. For example, setting `input_partition_dims: [1, 16, 1, 1, 1]`
in the config file splits the image 16 ways along the first spatial (width)
dimension. The leading dimension (set to 1) is the batch dimension.
**Example: Train with 16-way spatial partition**:
```shell
DATA_BUCKET=<GS bucket for data>
TRAIN_FILES="${DATA_BUCKET}/tfrecords/trainbox*.tfrecord"
VAL_FILES="${DATA_BUCKET}/tfrecords/validationbox*.tfrecord"
MODEL_BUCKET=<GS bucket for model checkpoints>
EXP_NAME=unet_20190610_dice_t1
python unet_main.py \
--distribution_strategy=<"mirrored" or "tpu"> \
--num_gpus=<'number of GPUs to use if using mirrored strategy'> \
--tpu=<TPU name> \
--model_dir="gs://${MODEL_BUCKET}/models/${EXP_NAME}" \
--training_file_pattern="${TRAIN_FILES}" \
--eval_file_pattern="${VAL_FILES}" \
--steps_per_loop=10 \
--mode=train \
--config_file="./configs/cloud/v3-8_128x128x128_ce.yaml"
```
**Example: Config file with 16-way spatial partition**:
```
train_steps: 3000
loss: 'adaptive_dice32'
train_batch_size: 8
eval_batch_size: 8
use_index_label_in_train: false
input_partition_dims: [1,16,1,1,1]
input_image_size: [256,256,256]
dtype: 'bfloat16'
label_dtype: 'float32'
train_item_count: 5400
eval_item_count: 1674
```
## Evaluation
```shell
DATA_BUCKET=<GS bucket for data>
TRAIN_FILES="${DATA_BUCKET}/tfrecords/trainbox*.tfrecord"
VAL_FILES="${DATA_BUCKET}/tfrecords/validationbox*.tfrecord"
MODEL_BUCKET=<GS bucket for model checkpoints>
EXP_NAME=unet_20190610_dice_t1
python unet_main.py \
--distribution_strategy=<"mirrored" or "tpu"> \
--num_gpus=<'number of GPUs to use if using mirrored strategy'> \
--tpu=<TPU name> \
--model_dir="gs://${MODEL_BUCKET}/models/${EXP_NAME}" \
--training_file_pattern="${TRAIN_FILES}" \
--eval_file_pattern="${VAL_FILES}" \
--steps_per_loop=10 \
--mode="eval" \
--config_file="./configs/cloud/v3-8_128x128x128_ce.yaml"
```
## License
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
This project is licensed under the terms of the **Apache License 2.0**.
## References
<a id="1">[1]</a> Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp,
Thomas Brox, Olaf Ronneberger "3D U-Net: Learning Dense Volumetric Segmentation
from Sparse Annotation": https://arxiv.org/abs/1606.06650. (MICCAI 2016).
<a id="2">[2]</a> Le Hou, Youlong Cheng, Noam Shazeer, Niki Parmar, Yeqing Li,
Panagiotis Korfiatis, Travis M. Drucker, Daniel J. Blezek, Xiaodan Song "High
Resolution Medical Image Analysis with Spatial Partitioning":
https://arxiv.org/abs/1909.03108.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts raw LiTS numpy data to TFRecord.
The file is forked from:
https://github.com/tensorflow/tpu/blob/master/models/official/unet3d/data_preprocess/convert_lits.py
"""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import os
from absl import app
from absl import flags
from absl import logging
import numpy as np
from PIL import Image
from scipy import ndimage
import tensorflow.compat.v1 as tf
flags.DEFINE_string("image_file_pattern", None,
"path pattern to an input image npy file.")
flags.DEFINE_string("label_file_pattern", None,
"path pattern to an input label npy file.")
flags.DEFINE_string("output_path", None, "path to output TFRecords.")
flags.DEFINE_boolean("crop_liver_region", True,
"whether to crop liver region out.")
flags.DEFINE_boolean("apply_data_aug", False,
"whether to apply data augmentation.")
flags.DEFINE_integer("shard_start", 0,
"start with volume-${shard_start}.npy.")
flags.DEFINE_integer("shard_stride", 1,
"this process will convert "
"volume-${shard_start + n * shard_stride}.npy for all n.")
flags.DEFINE_integer("output_size", 128,
"output, cropped size along x, y, and z.")
flags.DEFINE_integer("resize_size", 192,
"size along x, y, and z before cropping.")
FLAGS = flags.FLAGS
def to_1hot(label):
per_class = []
for classes in range(3):
per_class.append((label == classes)[..., np.newaxis])
label = np.concatenate(per_class, axis=-1).astype(label.dtype)
return label
def save_to_tfrecord(image, label, idx, im_id, output_path,
convert_label_to_1hot):
"""Save to TFRecord."""
if convert_label_to_1hot:
label = to_1hot(label)
d_feature = {}
d_feature["image/ct_image"] = tf.train.Feature(
bytes_list=tf.train.BytesList(value=[image.reshape([-1]).tobytes()]))
d_feature["image/label"] = tf.train.Feature(
bytes_list=tf.train.BytesList(value=[label.reshape([-1]).tobytes()]))
example = tf.train.Example(features=tf.train.Features(feature=d_feature))
serialized = example.SerializeToString()
result_file = os.path.join(
output_path, "instance-{}-{}.tfrecords".format(im_id, idx))
options = tf.python_io.TFRecordOptions(
tf.python_io.TFRecordCompressionType.GZIP)
with tf.python_io.TFRecordWriter(result_file, options=options) as w:
w.write(serialized)
def intensity_change(im):
  """Intensity augmentation."""
if np.random.rand() < 0.1:
return im
  # Randomly scale the intensity.
sigma = 0.05
truncate_rad = 0.1
im *= np.clip(np.random.normal(1.0, sigma),
1.0 - truncate_rad, 1.0 + truncate_rad)
return im
def rand_crop_liver(image, label, res_s, out_s,
apply_data_aug, augment_times=54):
"""Crop image and label; Randomly change image intensity.
Randomly crop image and label around liver.
Args:
image: 3D numpy array.
label: 3D numpy array.
res_s: resized size of image and label.
out_s: output size of random crops.
apply_data_aug: whether to apply data augmentation.
augment_times: the number of times to randomly crop and augment data.
Yields:
    cropped and augmented image and label.
"""
if image.shape != (res_s, res_s, res_s) or \
label.shape != (res_s, res_s, res_s):
logging.info("Unexpected shapes. "
"image.shape: %s, label.shape: %s",
image.shape, label.shape)
return
rough_liver_label = 1
x, y, z = np.where(label == rough_liver_label)
bbox_center = [(x.min() + x.max()) // 2,
(y.min() + y.max()) // 2,
(z.min() + z.max()) // 2]
def in_range_check(c):
c = max(c, out_s // 2)
c = min(c, res_s - out_s // 2)
return c
for _ in range(augment_times):
rand_c = []
for c in bbox_center:
sigma = out_s // 6
truncate_rad = out_s // 4
c += np.clip(np.random.randn() * sigma, -truncate_rad, truncate_rad)
rand_c.append(int(in_range_check(c)))
image_aug = image[rand_c[0] - out_s // 2:rand_c[0] + out_s // 2,
rand_c[1] - out_s // 2:rand_c[1] + out_s // 2,
rand_c[2] - out_s // 2:rand_c[2] + out_s // 2].copy()
label_aug = label[rand_c[0] - out_s // 2:rand_c[0] + out_s // 2,
rand_c[1] - out_s // 2:rand_c[1] + out_s // 2,
rand_c[2] - out_s // 2:rand_c[2] + out_s // 2].copy()
if apply_data_aug:
image_aug = intensity_change(image_aug)
yield image_aug, label_aug
def rand_crop_whole_ct(image, label, res_s, out_s,
apply_data_aug, augment_times=2):
"""Crop image and label; Randomly change image intensity.
Randomly crop image and label.
Args:
image: 3D numpy array.
label: 3D numpy array.
res_s: resized size of image and label.
out_s: output size of random crops.
apply_data_aug: whether to apply data augmentation.
augment_times: the number of times to randomly crop and augment data.
Yields:
    cropped and augmented image and label.
"""
if image.shape != (res_s, res_s, res_s) or \
label.shape != (res_s, res_s, res_s):
logging.info("Unexpected shapes. "
"image.shape: %s, label.shape: %s",
image.shape, label.shape)
return
if not apply_data_aug:
# Do not augment data.
idx = (res_s - out_s) // 2
image = image[idx:idx + out_s, idx:idx + out_s, idx:idx + out_s]
label = label[idx:idx + out_s, idx:idx + out_s, idx:idx + out_s]
yield image, label
else:
cut = res_s - out_s
for _ in range(augment_times):
for i in [0, cut // 2, cut]:
for j in [0, cut // 2, cut]:
for k in [0, cut // 2, cut]:
image_aug = image[i:i + out_s, j:j + out_s, k:k + out_s].copy()
label_aug = label[i:i + out_s, j:j + out_s, k:k + out_s].copy()
image_aug = intensity_change(image_aug)
yield image_aug, label_aug
def resize_3d_image_nearest_interpolation(im, res_s):
"""Resize 3D image, but with nearest interpolation."""
new_shape = [res_s, im.shape[1], im.shape[2]]
ret0 = np.zeros(new_shape, dtype=im.dtype)
for i in range(im.shape[2]):
im_slice = np.array(Image.fromarray(im[..., i]).resize(
(im.shape[1], res_s), resample=Image.NEAREST))
ret0[..., i] = im_slice
new_shape = [res_s, res_s, res_s]
ret = np.zeros(new_shape, dtype=im.dtype)
for i in range(res_s):
im_slice = np.array(Image.fromarray(ret0[i, ...]).resize(
(res_s, res_s), resample=Image.NEAREST))
ret[i, ...] = im_slice
return ret
def process_one_file(image_path, label_path, im_id,
output_path, res_s, out_s,
crop_liver_region, apply_data_aug):
"""Convert one npy file."""
with tf.gfile.Open(image_path, "rb") as f:
image = np.load(f)
with tf.gfile.Open(label_path, "rb") as f:
label = np.load(f)
image = ndimage.zoom(image, [float(res_s) / image.shape[0],
float(res_s) / image.shape[1],
float(res_s) / image.shape[2]])
label = resize_3d_image_nearest_interpolation(label.astype(np.uint8),
res_s).astype(np.float32)
if crop_liver_region:
for idx, (image_aug, label_aug) in enumerate(rand_crop_liver(
image, label, res_s, out_s, apply_data_aug)):
save_to_tfrecord(image_aug, label_aug, idx, im_id, output_path,
convert_label_to_1hot=True)
else: # not crop_liver_region
# If we output the entire CT scan (crop_liver_region=False),
# do not convert_label_to_1hot to save storage.
for idx, (image_aug, label_aug) in enumerate(rand_crop_whole_ct(
image, label, res_s, out_s, apply_data_aug)):
save_to_tfrecord(image_aug, label_aug, idx, im_id, output_path,
convert_label_to_1hot=False)
def main(argv):
del argv
output_path = FLAGS.output_path
res_s = FLAGS.resize_size
out_s = FLAGS.output_size
crop_liver_region = FLAGS.crop_liver_region
apply_data_aug = FLAGS.apply_data_aug
for im_id in range(FLAGS.shard_start, 1000000, FLAGS.shard_stride):
image_path = FLAGS.image_file_pattern.format(im_id)
label_path = FLAGS.label_file_pattern.format(im_id)
if not tf.gfile.Exists(image_path):
      logging.info("Reached the end. Image does not exist: %s. "
                   "Process finished.", image_path)
break
process_one_file(image_path, label_path, im_id,
output_path, res_s, out_s,
crop_liver_region, apply_data_aug)
if __name__ == "__main__":
app.run(main)
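A short sketch for sanity-checking records written by this script. It assumes the default `output_size` of 128, the 3-class one-hot labels produced when `crop_liver_region` is true, and the GZIP compression and feature keys used by `save_to_tfrecord` above; the record path is a hypothetical example.

```python
import tensorflow as tf

# Path to one of the records written by this script (hypothetical example path).
record_path = "Downloads/LiTS/tfrecords/instance-0-0.tfrecords"

dataset = tf.data.TFRecordDataset(record_path, compression_type="GZIP")
for serialized in dataset.take(1):
    parsed = tf.io.parse_single_example(
        serialized,
        features={
            "image/ct_image": tf.io.FixedLenFeature([], tf.string),
            "image/label": tf.io.FixedLenFeature([], tf.string),
        })
    image = tf.reshape(
        tf.io.decode_raw(parsed["image/ct_image"], tf.float32), [128, 128, 128])
    # Labels are one-hot with 3 classes when crop_liver_region=True.
    label = tf.reshape(
        tf.io.decode_raw(parsed["image/label"], tf.float32), [128, 128, 128, 3])
    print(image.shape, label.shape)
```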
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts .nii files in LiTS dataset to .npy files.
This script should be run just once before running convert_lits.py.
The file is forked from:
https://github.com/tensorflow/tpu/blob/master/models/official/unet3d/data_preprocess/convert_lits_nii_to_npy.py
"""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import glob
import multiprocessing
import os
import nibabel as nib
import numpy as np
num_processes = 2
input_path = "Downloads/LiTS/Train/" # where the .nii files are.
output_path = "Downloads/LiTS/Train_np/" # where you want to put the npy files.
def process_one_file(image_path):
"""Convert one nii file to npy."""
im_id = os.path.basename(image_path).split("volume-")[1].split(".nii")[0]
label_path = image_path.replace("volume-", "segmentation-")
image = nib.load(image_path).get_data().astype(np.float32)
label = nib.load(label_path).get_data().astype(np.float32)
print("image shape: {}, dtype: {}".format(image.shape, image.dtype))
print("label shape: {}, dtype: {}".format(label.shape, label.dtype))
np.save(os.path.join(output_path, "volume-{}.npy".format(im_id)), image)
np.save(os.path.join(output_path, "segmentation-{}.npy".format(im_id)), label)
nii_dir = os.path.join(input_path, "volume-*")
p = multiprocessing.Pool(num_processes)
p.map(process_one_file, glob.glob(nii_dir))
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config to train UNet."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
UNET_CONFIG = {
# Place holder for tpu configs.
'tpu_config': {},
'model_dir': '',
'training_file_pattern': None,
'eval_file_pattern': None,
# The input files are GZip compressed and need decompression.
'compressed_input': True,
'dtype': 'bfloat16',
'label_dtype': 'float32',
'train_batch_size': 8,
'eval_batch_size': 8,
'predict_batch_size': 8,
'train_epochs': 20,
'train_steps': 1000,
'eval_steps': 10,
'num_steps_per_eval': 100,
'min_eval_interval': 180,
'eval_timeout': None,
'optimizer': 'adam',
'momentum': 0.9,
# Spatial dimension of input image.
'input_image_size': [128, 128, 128],
# Number of channels of the input image.
'num_channels': 1,
# Spatial partition dimensions.
'input_partition_dims': None,
    # Use deconvolution to upsample; otherwise, use upsampling layers.
'deconvolution': True,
    # Number of classes to segment.
'num_classes': 3,
# Number of filters used by the architecture
'num_base_filters': 32,
# Depth of the network
'depth': 4,
# Dropout values to use across the network
'dropout_rate': 0.5,
# Number of levels that contribute to the output.
'num_segmentation_levels': 2,
# Use batch norm.
'use_batch_norm': True,
'init_learning_rate': 0.1,
# learning rate decay steps.
'lr_decay_steps': 100,
# learning rate decay rate.
'lr_decay_rate': 0.5,
    # Data format, either 'channels_last' or 'channels_first'.
'data_format': 'channels_last',
# Use class index for training. Otherwise, use one-hot encoding.
'use_index_label_in_train': False,
# e.g. softmax cross entropy, adaptive_dice32
'loss': 'adaptive_dice32',
}
UNET_RESTRICTIONS = []
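These defaults are wrapped in a `ParamsDict` and overridden by the YAML file passed via `--config_file` (see `extract_params` in `unet_main.py`). A minimal sketch of that flow, assuming a config file like the ones referenced in the README:

```python
from official.modeling.hyperparams import params_dict
from official.vision.segmentation import unet_config

params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                unet_config.UNET_RESTRICTIONS)
# Override defaults from a YAML config file, e.g. the one passed as
# --config_file; is_strict=False allows keys not present in UNET_CONFIG
# (such as train_item_count / eval_item_count).
params = params_dict.override_params_dict(
    params, './configs/cloud/v3-8_128x128x128_ce.yaml', is_strict=False)

print(params.train_batch_size, params.input_image_size, params.loss)
```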
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Defines input_fn of TF2 UNet-3D model."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
import tensorflow as tf
class BaseInput(object):
"""Input function for 3D Unet model."""
def __init__(self, file_pattern, params, is_training):
self._params = params
self._file_pattern = file_pattern
self._is_training = is_training
self._parser_fn = self.create_parser_fn(params)
if params.compressed_input:
self._dataset_fn = functools.partial(
tf.data.TFRecordDataset, compression_type='GZIP')
else:
self._dataset_fn = tf.data.TFRecordDataset
def create_parser_fn(self, params):
"""Create parse fn to extract tensors from tf.Example."""
def _parser(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
features = tf.io.parse_example(
serialized=[serialized_example],
features={
'image/encoded': tf.io.VarLenFeature(dtype=tf.float32),
'image/segmentation/mask': tf.io.VarLenFeature(dtype=tf.float32),
})
image = features['image/encoded']
if isinstance(image, tf.SparseTensor):
image = tf.sparse.to_dense(image)
gt_mask = features['image/segmentation/mask']
if isinstance(gt_mask, tf.SparseTensor):
gt_mask = tf.sparse.to_dense(gt_mask)
image_size, label_size = self.get_input_shapes(params)
image = tf.reshape(image, image_size)
gt_mask = tf.reshape(gt_mask, label_size)
image = tf.cast(image, dtype=params.dtype)
gt_mask = tf.cast(gt_mask, dtype=params.dtype)
return image, gt_mask
return _parser
def get_input_shapes(self, params):
image_size = params.input_image_size + [params.num_channels]
label_size = params.input_image_size + [params.num_classes]
return image_size, label_size
def __call__(self, input_pipeline_context=None):
"""Generates features and labels for training or evaluation.
    This uses a file-based input pipeline to read data so that the entire
    dataset does not need to be loaded in memory.
Args:
input_pipeline_context: Context used by distribution strategy to
shard dataset across workers.
Returns:
tf.data.Dataset
"""
params = self._params
batch_size = (
params.train_batch_size
if self._is_training else params.eval_batch_size)
dataset = tf.data.Dataset.list_files(
self._file_pattern, shuffle=self._is_training)
# Shard dataset when there are more than 1 workers in training.
if input_pipeline_context:
batch_size = input_pipeline_context.get_per_replica_batch_size(batch_size)
if input_pipeline_context.num_input_pipelines > 1:
dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
input_pipeline_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
dataset = dataset.apply(
tf.data.experimental.parallel_interleave(
lambda file_name: self._dataset_fn(file_name).prefetch(1),
cycle_length=32,
sloppy=self._is_training))
if self._is_training:
dataset = dataset.shuffle(64)
# Parses the fetched records to input tensors for model function.
dataset = dataset.map(self._parser_fn, tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
class LiverInput(BaseInput):
"""Input function of Liver Segmentation data set."""
def create_parser_fn(self, params):
"""Create parse fn to extract tensors from tf.Example."""
def _decode_liver_example(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
features = {}
features['image/ct_image'] = tf.io.FixedLenFeature([], tf.string)
features['image/label'] = tf.io.FixedLenFeature([], tf.string)
parsed = tf.io.parse_single_example(
serialized=serialized_example, features=features)
      # This assumes the `image` is normalized to [0, 1] of type float32 and
# the `label` is a binary matrix, whose last dimension is one_hot encoded
# labels.
# The dtype of `label` can be either float32 or int64.
image = tf.io.decode_raw(parsed['image/ct_image'],
tf.as_dtype(tf.float32))
label = tf.io.decode_raw(parsed['image/label'],
tf.as_dtype(params.label_dtype))
image_size = params.input_image_size + [params.num_channels]
image = tf.reshape(image, image_size)
label_size = params.input_image_size + [params.num_classes]
label = tf.reshape(label, label_size)
if self._is_training and params.use_index_label_in_train:
# Use class index for labels and remove the channel dim (#channels=1).
channel_dim = -1
label = tf.argmax(input=label, axis=channel_dim, output_type=tf.int32)
image = tf.cast(image, dtype=params.dtype)
label = tf.cast(label, dtype=params.dtype)
# TPU doesn't support tf.int64 well, use tf.int32 directly.
if label.dtype == tf.int64:
label = tf.cast(label, dtype=tf.int32)
return image, label
return _decode_liver_example
def get_input_shapes(self, params):
image_size = params.input_image_size + [params.num_channels]
if self._is_training and params.use_index_label_in_train:
label_size = params.input_image_size
else:
label_size = params.input_image_size + [params.num_classes]
return image_size, label_size
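A hedged sketch of driving `LiverInput` directly, outside the trainer. It assumes a `ParamsDict` built as in `unet_main.py` and LiTS TFRecords produced by `data_preprocess/convert_lits.py`; the file pattern is a hypothetical example.

```python
import tensorflow as tf

from official.modeling.hyperparams import params_dict
from official.vision.segmentation import unet_config
from official.vision.segmentation import unet_data

params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
                                unet_config.UNET_RESTRICTIONS)

# Hypothetical file pattern; point it at the TFRecords written by
# data_preprocess/convert_lits.py (GZIP-compressed by default).
file_pattern = 'gs://my-bucket/tfrecords/instance-*.tfrecords'

train_input = unet_data.LiverInput(file_pattern, params, is_training=True)
dataset = train_input()  # a tf.data.Dataset of (image, label) batches

for image, label in dataset.take(1):
    # With the defaults: image [8, 128, 128, 128, 1], label [8, 128, 128, 128, 3].
    print(image.shape, label.shape)
```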
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Training script for UNet-3D."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
import os
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
from official.modeling.hyperparams import params_dict
from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.segmentation import unet_config
from official.vision.segmentation import unet_data
from official.vision.segmentation import unet_metrics
from official.vision.segmentation import unet_model as unet_model_lib
def define_unet3d_flags():
"""Defines flags for training 3D Unet."""
hyperparams_flags.initialize_common_flags()
flags.DEFINE_enum(
'distribution_strategy', 'tpu', ['tpu', 'mirrored'],
'Distribution Strategy type to use for training. `tpu` uses TPUStrategy '
'for running on TPUs, `mirrored` uses GPUs with single host.')
flags.DEFINE_integer(
'steps_per_loop', 50,
'Number of steps to execute in a loop for performance optimization.')
flags.DEFINE_integer('checkpoint_interval', 100,
'Minimum step interval between two checkpoints.')
flags.DEFINE_integer('epochs', 10, 'Number of epochs to run training.')
flags.DEFINE_string(
'gcp_project',
default=None,
help='Project name for the Cloud TPU-enabled project. If not specified, we '
'will attempt to automatically detect the GCE project from metadata.')
flags.DEFINE_string(
'eval_checkpoint_dir',
default=None,
help='Directory for reading checkpoint file when `mode` == `eval`.')
flags.DEFINE_multi_integer(
'input_partition_dims', [1],
'A list that describes the partition dims for all the tensors.')
flags.DEFINE_string(
'mode', 'train', 'Mode to run: train or eval (default: train).')
flags.DEFINE_string('training_file_pattern', None,
'Location of the train data.')
flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data.')
flags.DEFINE_float('lr_init_value', 0.0001, 'Initial learning rate.')
flags.DEFINE_float('lr_decay_rate', 0.9, 'Learning rate decay rate.')
flags.DEFINE_integer('lr_decay_steps', 100, 'Learning rate decay steps.')
def save_params(params):
"""Save parameters to config files if model_dir is defined."""
model_dir = params.model_dir
assert model_dir is not None
if not tf.io.gfile.exists(model_dir):
tf.io.gfile.makedirs(model_dir)
file_name = os.path.join(model_dir, 'params.yaml')
params_dict.save_params_dict_to_yaml(params, file_name)
def extract_params(flags_obj):
"""Extract configuration parameters for training and evaluation."""
params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
unet_config.UNET_RESTRICTIONS)
params = params_dict.override_params_dict(
params, flags_obj.config_file, is_strict=False)
if flags_obj.training_file_pattern:
params.override({'training_file_pattern': flags_obj.training_file_pattern},
is_strict=True)
if flags_obj.eval_file_pattern:
params.override({'eval_file_pattern': flags_obj.eval_file_pattern},
is_strict=True)
train_epoch_steps = params.train_item_count // params.train_batch_size
eval_epoch_steps = params.eval_item_count // params.eval_batch_size
params.override(
{
'model_dir': flags_obj.model_dir,
'eval_checkpoint_dir': flags_obj.eval_checkpoint_dir,
'mode': flags_obj.mode,
'distribution_strategy': flags_obj.distribution_strategy,
'tpu': flags_obj.tpu,
'num_gpus': flags_obj.num_gpus,
'init_learning_rate': flags_obj.lr_init_value,
'lr_decay_rate': flags_obj.lr_decay_rate,
'lr_decay_steps': train_epoch_steps,
'train_epoch_steps': train_epoch_steps,
'eval_epoch_steps': eval_epoch_steps,
'steps_per_loop': flags_obj.steps_per_loop,
'epochs': flags_obj.epochs,
'checkpoint_interval': flags_obj.checkpoint_interval,
},
is_strict=False)
params.validate()
params.lock()
return params
def unet3d_callbacks(params, checkpoint_manager=None):
"""Custom callbacks during training."""
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=params.model_dir)
if checkpoint_manager:
checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
return [tensorboard_callback, checkpoint_callback]
else:
return [tensorboard_callback]
def get_computation_shape_for_model_parallelism(input_partition_dims):
"""Return computation shape to be used for TPUStrategy spatial partition."""
num_logical_devices = np.prod(input_partition_dims)
if num_logical_devices == 1:
return [1, 1, 1, 1]
if num_logical_devices == 2:
return [1, 1, 1, 2]
if num_logical_devices == 4:
return [1, 2, 1, 2]
if num_logical_devices == 8:
return [2, 2, 1, 2]
if num_logical_devices == 16:
return [4, 2, 1, 2]
raise ValueError('Unsupported spatial partition configuration: '
'np.prod(input_partition_dims) must be 1, 2, 4, 8, or 16.')
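# For example (a sketch of the mapping above, not an exhaustive guide):
# input_partition_dims = [1, 2, 1, 1, 1] splits each input across
# np.prod([1, 2, 1, 1, 1]) = 2 logical devices, so the computation shape used
# for the TPU device assignment is [1, 1, 1, 2]; a product of 8 maps to
# [2, 2, 1, 2], and any other product raises a ValueError.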
def create_distribution_strategy(params):
"""Creates distribution strategy to use for computation."""
if params.input_partition_dims is not None:
if params.distribution_strategy != 'tpu':
raise ValueError('Spatial partitioning is only supported '
'for TPUStrategy.')
# When `input_partition_dims` is specified create custom TPUStrategy
# instance with computation shape for model parallelism.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=params.tpu)
if params.tpu not in ('', 'local'):
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
params.input_partition_dims)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology,
num_replicas=num_replicas,
computation_shape=get_computation_shape_for_model_parallelism(
params.input_partition_dims))
return tf.distribute.experimental.TPUStrategy(
resolver, device_assignment=device_assignment)
return distribution_utils.get_distribution_strategy(
distribution_strategy=params.distribution_strategy,
tpu_address=params.tpu,
num_gpus=params.num_gpus)
def get_train_dataset(params, ctx=None):
"""Returns training dataset."""
return unet_data.LiverInput(
params.training_file_pattern, params, is_training=True)(
ctx)
def get_eval_dataset(params, ctx=None):
"""Returns evaluation dataset."""
return unet_data.LiverInput(
params.eval_file_pattern, params, is_training=False)(
ctx)
def expand_1d(data):
"""Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s."""
def _expand_single_1d_tensor(t):
if (isinstance(t, tf.Tensor) and isinstance(t.shape, tf.TensorShape) and
t.shape.rank == 1):
return tf.expand_dims(t, axis=-1)
return t
return tf.nest.map_structure(_expand_single_1d_tensor, data)
def train_step(train_fn, input_partition_dims, data):
"""The logic for one training step with spatial partitioning."""
# Keras expects inputs of rank >= 2, so expand rank-1 tensors.
data = expand_1d(data)
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
if input_partition_dims:
strategy = tf.distribute.get_strategy()
x = strategy.experimental_split_to_logical_devices(x, input_partition_dims)
y = strategy.experimental_split_to_logical_devices(y, input_partition_dims)
partitioned_data = tf.keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
return train_fn(partitioned_data)
def test_step(test_fn, input_partition_dims, data):
"""The logic for one testing step with spatial partitioning."""
# Keras expects inputs of rank >= 2, so expand rank-1 tensors.
data = expand_1d(data)
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
if input_partition_dims:
strategy = tf.distribute.get_strategy()
x = strategy.experimental_split_to_logical_devices(x, input_partition_dims)
y = strategy.experimental_split_to_logical_devices(y, input_partition_dims)
partitioned_data = tf.keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
return test_fn(partitioned_data)
def train(params, strategy, unet_model, train_input_fn, eval_input_fn):
"""Trains 3D Unet model."""
assert tf.distribute.has_strategy()
# Override Keras Model's train_step() and test_step() function so
# that inputs are spatially partitioned.
# Note that if the `predict()` API is used, then `predict_step()` should
# also be overridden.
unet_model.train_step = functools.partial(train_step, unet_model.train_step,
params.input_partition_dims)
unet_model.test_step = functools.partial(test_step, unet_model.test_step,
params.input_partition_dims)
optimizer = unet_model_lib.create_optimizer(params.init_learning_rate, params)
loss_fn = unet_metrics.get_loss_fn(params.mode, params)
unet_model.compile(
loss=loss_fn,
optimizer=optimizer,
metrics=[unet_metrics.metric_accuracy],
experimental_steps_per_execution=params.steps_per_loop)
train_ds = strategy.experimental_distribute_datasets_from_function(
train_input_fn)
eval_ds = strategy.experimental_distribute_datasets_from_function(
eval_input_fn)
checkpoint = tf.train.Checkpoint(model=unet_model)
train_epoch_steps = params.train_item_count // params.train_batch_size
eval_epoch_steps = params.eval_item_count // params.eval_batch_size
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=params.model_dir,
max_to_keep=10,
step_counter=unet_model.optimizer.iterations,
checkpoint_interval=params.checkpoint_interval)
checkpoint_manager.restore_or_initialize()
train_result = unet_model.fit(
x=train_ds,
epochs=params.epochs,
steps_per_epoch=train_epoch_steps,
validation_data=eval_ds,
validation_steps=eval_epoch_steps,
callbacks=unet3d_callbacks(params, checkpoint_manager))
return train_result
def evaluate(params, strategy, unet_model, input_fn):
"""Reads from checkpoint and evaluate 3D Unet model."""
assert tf.distribute.has_strategy()
unet_model.compile(
metrics=[unet_metrics.metric_accuracy],
experimental_steps_per_execution=params.steps_per_loop)
# Override test_step() function so that inputs are spatially partitioned.
unet_model.test_step = functools.partial(test_step, unet_model.test_step,
params.input_partition_dims)
# Load checkpoint for evaluation.
checkpoint = tf.train.Checkpoint(model=unet_model)
checkpoint_path = tf.train.latest_checkpoint(params.eval_checkpoint_dir)
status = checkpoint.restore(checkpoint_path)
status.assert_existing_objects_matched()
eval_ds = strategy.experimental_distribute_datasets_from_function(input_fn)
eval_epoch_steps = params.eval_item_count // params.eval_batch_size
eval_result = unet_model.evaluate(
x=eval_ds, steps=eval_epoch_steps, callbacks=unet3d_callbacks(params))
return eval_result
def main(_):
params = extract_params(flags.FLAGS)
assert params.mode in {'train', 'eval'}, 'only support train and eval'
save_params(params)
input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)
strategy = create_distribution_strategy(params)
with strategy.scope():
unet_model = unet_model_lib.build_unet_model(params)
if params.mode == 'train':
train(params, strategy, unet_model,
functools.partial(get_train_dataset, params),
functools.partial(get_eval_dataset, params))
elif params.mode == 'eval':
evaluate(params, strategy, unet_model,
functools.partial(get_eval_dataset, params))
else:
raise Exception('Only `train` mode and `eval` mode are supported.')
if __name__ == '__main__':
define_unet3d_flags()
app.run(main)
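# Example invocation (a minimal sketch; every value below is a placeholder and
# the config file contents are an assumption, not part of this change):
#
#   python3 ~/models/official/vision/segmentation/unet_main.py \
#     --mode=train \
#     --distribution_strategy=tpu \
#     --tpu="<your GCP TPU name>" \
#     --model_dir="<path to the directory to store model files>" \
#     --config_file="<path to a YAML file overriding UNET_CONFIG>" \
#     --training_file_pattern="<path to the TFRecord training data>" \
#     --eval_file_pattern="<path to the TFRecord validation data>" \
#     --epochs=10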
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver
from tensorflow.contrib.tpu.python.tpu import device_assignment as device_lib
from tensorflow.python.distribute import tpu_strategy as tpu_strategy_lib
from tensorflow.python.tpu import tpu_strategy_util
from official.modeling.hyperparams import params_dict
from official.vision.segmentation import unet_config
from official.vision.segmentation import unet_main as unet_main_lib
from official.vision.segmentation import unet_metrics
from official.vision.segmentation import unet_model as unet_model_lib
FLAGS = flags.FLAGS
def create_fake_input_fn(params,
features_size,
labels_size,
use_bfloat16=False):
"""Returns fake input function for testing."""
def fake_data_input_fn(unused_ctx=None):
"""An input function for generating fake data."""
batch_size = params.train_batch_size
features = np.random.rand(64, *features_size)
labels = np.random.randint(2, size=[64] + labels_size)
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
def _assign_dtype(features, labels):
if use_bfloat16:
features = tf.cast(features, tf.bfloat16)
labels = tf.cast(labels, tf.bfloat16)
else:
features = tf.cast(features, tf.float32)
labels = tf.cast(labels, tf.float32)
return features, labels
# Shuffle, repeat, and batch the examples.
dataset = dataset.map(_assign_dtype)
dataset = dataset.shuffle(64).repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
# Return the dataset.
return dataset
return fake_data_input_fn
class UnetMainTest(parameterized.TestCase, tf.test.TestCase):
def setUp(self):
super(UnetMainTest, self).setUp()
self._model_dir = os.path.join(tempfile.mkdtemp(), 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
def tearDown(self):
tf.io.gfile.rmtree(self._model_dir)
super(UnetMainTest, self).tearDown()
@flagsaver.flagsaver
def testUnet3DModel(self):
FLAGS.tpu = ''
FLAGS.mode = 'train'
params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
unet_config.UNET_RESTRICTIONS)
params.override(
{
'input_image_size': [64, 64, 64],
'train_item_count': 4,
'eval_item_count': 4,
'train_batch_size': 2,
'eval_batch_size': 2,
'batch_size': 2,
'num_base_filters': 16,
'dtype': 'bfloat16',
'depth': 1,
'train_steps': 2,
'eval_steps': 2,
'mode': FLAGS.mode,
'tpu': FLAGS.tpu,
'num_gpus': 0,
'checkpoint_interval': 1,
'use_tpu': True,
'input_partition_dims': None,
},
is_strict=False)
params.validate()
params.lock()
image_size = params.input_image_size + [params.num_channels]
label_size = params.input_image_size + [params.num_classes]
input_fn = create_fake_input_fn(
params, features_size=image_size, labels_size=label_size)
resolver = contrib_cluster_resolver.TPUClusterResolver(tpu=params.tpu)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
device_assignment = None
if params.input_partition_dims is not None:
assert np.prod(
params.input_partition_dims) == 2, 'invalid unit test configuration'
computation_shape = [1, 1, 1, 2]
partition_dimension = params.input_partition_dims
num_replicas = resolver.get_tpu_system_metadata().num_cores // np.prod(
partition_dimension)
device_assignment = device_lib.device_assignment(
topology,
computation_shape=computation_shape,
num_replicas=num_replicas)
strategy = tpu_strategy_lib.TPUStrategy(
resolver, device_assignment=device_assignment)
with strategy.scope():
model = unet_model_lib.build_unet_model(params)
optimizer = unet_model_lib.create_optimizer(params.init_learning_rate,
params)
loss_fn = unet_metrics.get_loss_fn(params.mode, params)
model.compile(loss=loss_fn, optimizer=optimizer, metrics=[loss_fn])
eval_ds = input_fn()
iterator = iter(eval_ds)
image, _ = next(iterator)
logits = model(image, training=False)
self.assertEqual(logits.shape[1:], params.input_image_size + [3])
@parameterized.parameters(
{
'use_mlir': True,
'dtype': 'bfloat16',
'input_partition_dims': None,
}, {
'use_mlir': False,
'dtype': 'bfloat16',
'input_partition_dims': None,
}, {
'use_mlir': True,
'dtype': 'bfloat16',
'input_partition_dims': None,
}, {
'use_mlir': False,
'dtype': 'bfloat16',
'input_partition_dims': None,
}, {
'use_mlir': True,
'dtype': 'bfloat16',
'input_partition_dims': [1, 2, 1, 1, 1],
}, {
'use_mlir': False,
'dtype': 'bfloat16',
'input_partition_dims': [1, 2, 1, 1, 1],
}, {
'use_mlir': True,
'dtype': 'bfloat16',
'input_partition_dims': [1, 2, 1, 1, 1],
}, {
'use_mlir': False,
'dtype': 'bfloat16',
'input_partition_dims': [1, 2, 1, 1, 1]
})
@flagsaver.flagsaver
def testUnetTrain(self, use_mlir, dtype, input_partition_dims):
FLAGS.tpu = ''
FLAGS.mode = 'train'
if use_mlir:
tf.config.experimental.enable_mlir_bridge()
params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
unet_config.UNET_RESTRICTIONS)
params.override(
{
'model_dir': self._model_dir,
'input_image_size': [8, 8, 8],
'train_item_count': 2,
'eval_item_count': 2,
'train_batch_size': 2,
'eval_batch_size': 2,
'batch_size': 2,
'num_base_filters': 1,
'dtype': dtype,
'depth': 1,
'epochs': 1,
'checkpoint_interval': 1,
'train_steps': 1,
'eval_steps': 1,
'mode': FLAGS.mode,
'tpu': FLAGS.tpu,
'use_tpu': True,
'num_gpus': 0,
'distribution_strategy': 'tpu',
'steps_per_loop': 1,
'input_partition_dims': input_partition_dims,
},
is_strict=False)
params.validate()
params.lock()
image_size = params.input_image_size + [params.num_channels]
label_size = params.input_image_size + [params.num_classes]
input_fn = create_fake_input_fn(
params, features_size=image_size, labels_size=label_size)
input_dtype = params.dtype
if input_dtype == 'float16' or input_dtype == 'bfloat16':
policy = tf.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16' if input_dtype == 'bfloat16' else 'mixed_float16')
tf.keras.mixed_precision.experimental.set_policy(policy)
strategy = unet_main_lib.create_distribution_strategy(params)
with strategy.scope():
unet_model = unet_model_lib.build_unet_model(params)
unet_main_lib.train(params, strategy, unet_model, input_fn, input_fn)
if __name__ == '__main__':
unet_main_lib.define_unet3d_flags()
tf.test.main()
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define metrics for the UNet 3D Model."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
def dice(y_true, y_pred, axis=(1, 2, 3, 4)):
"""DICE coefficient.
Taha AA, Hanbury A. Metrics for evaluating 3D medical image segmentation:
analysis, selection, and tool. BMC Med Imaging. 2015;15:29.
Published 2015 Aug 12. doi:10.1186/s12880-015-0068-x
Implemented according to
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4533825/#Equ6
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
axis: axes of the feature dimensions.
Returns:
DICE coefficient.
"""
y_true = tf.cast(y_true, y_pred.dtype)
eps = tf.keras.backend.epsilon()
intersection = tf.reduce_sum(input_tensor=y_true * y_pred, axis=axis)
summation = tf.reduce_sum(
input_tensor=y_true, axis=axis) + tf.reduce_sum(
input_tensor=y_pred, axis=axis)
return (2 * intersection + eps) / (summation + eps)
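# Worked example (a rough sketch, ignoring the epsilon smoothing): if y_true
# marks 4 positive voxels for a class and y_pred assigns probability 1.0 to 3
# of them and 0.0 elsewhere, then intersection = 3 and summation = 4 + 3 = 7,
# so dice = 2 * 3 / 7 ≈ 0.857.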
def generalized_dice(y_true, y_pred, axis=(1, 2, 3)):
"""Generalized Dice coefficient, for multi-class predictions.
For output of a multi-class model, where the shape of the output is
(batch, x, y, z, n_classes), the axis argument should be (1, 2, 3).
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
axis: axes of the feature dimensions.
Returns:
DICE coefficient.
"""
y_true = tf.cast(y_true, y_pred.dtype)
if y_true.get_shape().ndims < 2 or y_pred.get_shape().ndims < 2:
raise ValueError('y_true and y_pred must be at least rank 2.')
epsilon = tf.keras.backend.epsilon()
w = tf.math.reciprocal(tf.square(tf.reduce_sum(y_true, axis=axis)) + epsilon)
num = 2 * tf.reduce_sum(
w * tf.reduce_sum(y_true * y_pred, axis=axis), axis=-1)
den = tf.reduce_sum(w * tf.reduce_sum(y_true + y_pred, axis=axis), axis=-1)
return (num + epsilon) / (den + epsilon)
def hamming(y_true, y_pred, axis=(1, 2, 3)):
"""Hamming distance.
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z].
y_pred: the prediction matrix. Shape [batch_size, x, y, z].
axis: a list, axes of the feature dimensions.
Returns:
Hamming distance value.
"""
y_true = tf.cast(y_true, y_pred.dtype)
return tf.reduce_mean(
input_tensor=tf.cast(tf.not_equal(y_pred, y_true), y_pred.dtype), axis=axis)
def jaccard(y_true, y_pred, axis=(1, 2, 3, 4)):
"""Jaccard Similarity.
Taha AA, Hanbury A. Metrics for evaluating 3D medical image segmentation:
analysis, selection, and tool. BMC Med Imaging. 2015;15:29.
Published 2015 Aug 12. doi:10.1186/s12880-015-0068-x
Implemented according to
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4533825/#Equ7
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
axis: axes of the feature dimensions.
Returns:
Jaccard similarity.
"""
y_true = tf.cast(y_true, y_pred.dtype)
eps = tf.keras.backend.epsilon()
intersection = tf.reduce_sum(input_tensor=y_true * y_pred, axis=axis)
union = tf.reduce_sum(y_true, axis=axis) + tf.reduce_sum(y_pred, axis=axis)
return (intersection + eps) / (union - intersection + eps)
def tversky(y_true, y_pred, axis=(1, 2, 3), alpha=0.3, beta=0.7):
"""Tversky similarity.
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
axis: axes of the spatial dimensions.
alpha: weight of the prediction.
beta: weight of the groundtruth.
Returns:
Tversky similarity coefficient.
"""
y_true = tf.cast(y_true, y_pred.dtype)
if y_true.get_shape().ndims < 2 or y_pred.get_shape().ndims < 2:
raise ValueError('y_true and y_pred must be at least rank 2.')
eps = tf.keras.backend.epsilon()
num = tf.reduce_sum(input_tensor=y_pred * y_true, axis=axis)
den = (
num + alpha * tf.reduce_sum(y_pred * (1 - y_true), axis=axis) +
beta * tf.reduce_sum((1 - y_pred) * y_true, axis=axis))
# Sum over classes.
return tf.reduce_sum(input_tensor=(num + eps) / (den + eps), axis=-1)
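# Note: with alpha = beta = 0.5 the per-class Tversky term reduces to the Dice
# coefficient, and with alpha = beta = 1.0 it reduces to the Jaccard index;
# the defaults above (alpha=0.3, beta=0.7) penalize false negatives more
# heavily than false positives.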
def adaptive_dice32(y_true, y_pred, data_format='channels_last'):
"""Adaptive dice metric.
Args:
y_true: the ground truth matrix. Shape [batch_size, x, y, z, num_classes].
y_pred: the prediction matrix. Shape [batch_size, x, y, z, num_classes].
data_format: 'channels_last' or 'channels_first'.
Returns:
Adaptive dice value.
"""
epsilon = 10**-7
y_true = tf.cast(y_true, dtype=y_pred.dtype)
# Determine axes to pass to tf.reduce_sum
if data_format == 'channels_last':
ndim = len(y_pred.shape)
reduction_axes = list(range(ndim - 1))
else:
reduction_axes = 1
# Calculate intersections and unions per class
intersections = tf.reduce_sum(y_true * y_pred, axis=reduction_axes)
unions = tf.reduce_sum(y_true + y_pred, axis=reduction_axes)
# Calculate Dice scores per class
dice_scores = 2.0 * (intersections + epsilon) / (unions + epsilon)
# Calculate weights based on Dice scores
weights = tf.exp(-1.0 * dice_scores)
# Multiply weights by corresponding scores and get sum
weighted_dice = tf.reduce_sum(weights * dice_scores)
# Calculate normalization factor
norm_factor = tf.size(input=dice_scores, out_type=tf.float32) * tf.exp(-1.0)
weighted_dice = tf.cast(weighted_dice, dtype=tf.float32)
# Return 1 - adaptive Dice score
return 1 - (weighted_dice / norm_factor)
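# Note that, unlike the similarity metrics above, this function returns 1
# minus a weighted Dice score, so lower is better and it can be used directly
# as a training loss (see get_loss_fn below).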
def assert_shape_equal(pred_shape, label_shape):
"""Asserts that `pred_shape` and `label_shape` is equal."""
assert (label_shape == pred_shape
), 'pred. shape {} is not equal to label shape {}'.format(
label_shape, pred_shape)
def get_loss_fn(mode, params):
"""Return loss_fn for unet training.
Args:
mode: training or eval. This is a legacy parameter from TF1.
params: unet configuration parameter.
Returns:
loss_fn.
"""
def loss_fn(y_true, y_pred):
"""Returns scalar loss from labels and netowrk outputs."""
loss = None
label_shape = y_true.get_shape().as_list()
pred_shape = y_pred.get_shape().as_list()
assert_shape_equal(pred_shape, label_shape)
if params.loss == 'adaptive_dice32':
loss = adaptive_dice32(y_true, y_pred)
elif params.loss == 'cross_entropy':
if mode == tf.estimator.ModeKeys.TRAIN and params.use_index_label_in_train:
labels_idx = tf.cast(y_true, dtype=tf.int32)
else:
# Use one-hot label representation, convert to label index.
labels_idx = tf.argmax(input=y_true, axis=-1, output_type=tf.int32)
y_pred = tf.cast(y_pred, dtype=tf.float32)
loss = tf.keras.losses.sparse_categorical_crossentropy(
labels_idx, y_pred, from_logits=False)
else:
raise ValueError('Unexpected loss type: {}'.format(params.loss))
return loss
return loss_fn
def metric_accuracy(labels, predictions):
"""Returns accuracy metric of model outputs.
Args:
labels: ground truth tensor (labels).
predictions: network output (logits)
Returns:
metric_fn.
"""
if labels.dtype == tf.bfloat16:
labels = tf.cast(labels, tf.float32)
if predictions.dtype == tf.bfloat16:
predictions = tf.cast(predictions, tf.float32)
return tf.keras.backend.mean(
tf.keras.backend.equal(
tf.argmax(input=labels, axis=-1),
tf.argmax(input=predictions, axis=-1)))
def metric_ce(labels, predictions):
"""Returns categorical crossentropy given outputs and labels.
Args:
labels: ground truth tensor (labels).
predictions: network output (logits)
Returns:
metric_fn.
"""
if labels.dtype == tf.bfloat16:
labels = tf.cast(labels, tf.float32)
if predictions.dtype == tf.bfloat16:
predictions = tf.cast(predictions, tf.float32)
return tf.keras.losses.categorical_crossentropy(
labels, predictions, from_logits=False)
def metric_dice(labels, predictions):
"""Returns adaptive dice coefficient."""
if labels.dtype == tf.bfloat16:
labels = tf.cast(labels, tf.float32)
if predictions.dtype == tf.bfloat16:
predictions = tf.cast(predictions, tf.float32)
return adaptive_dice32(labels, predictions)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model definition for the TF2 Keras UNet 3D Model."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
def create_optimizer(init_learning_rate, params):
"""Creates optimizer for training."""
learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=init_learning_rate,
decay_steps=params.lr_decay_steps,
decay_rate=params.lr_decay_rate)
# TODO(hongjunchoi): Provide alternative optimizer options depending on model
# config parameters.
optimizer = tf.keras.optimizers.Adam(learning_rate)
return optimizer
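# With the default ExponentialDecay settings (staircase=False), the learning
# rate at a given step is roughly
#   init_learning_rate * lr_decay_rate ** (step / lr_decay_steps);
# e.g. the flag defaults lr_init_value=0.0001 and lr_decay_rate=0.9 give a
# rate of about 0.00009 after one full decay period of lr_decay_steps steps.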
def create_convolution_block(input_layer,
n_filters,
batch_normalization=False,
kernel=(3, 3, 3),
activation=tf.nn.relu,
padding='SAME',
strides=(1, 1, 1),
data_format='channels_last',
instance_normalization=False):
"""UNet convolution block.
Args:
input_layer: tf.Tensor, the input tensor.
n_filters: integer, the number of the output channels of the convolution.
batch_normalization: boolean, use batch normalization after the convolution.
kernel: kernel size of the convolution.
activation: TensorFlow activation function to use (default: tf.nn.relu).
padding: padding type of the convolution.
strides: strides of the convolution.
data_format: data format of the convolution. One of 'channels_first' or
'channels_last'.
instance_normalization: use instance normalization. Mutually exclusive with
batch normalization; not supported in this TF 2 implementation.
Returns:
The Tensor after apply the convolution block to the input.
"""
assert instance_normalization == 0, 'TF 2.0 does not support inst. norm.'
layer = tf.keras.layers.Conv3D(
filters=n_filters,
kernel_size=kernel,
strides=strides,
padding=padding,
data_format=data_format,
activation=None,
)(
inputs=input_layer)
if batch_normalization:
channel_axis = 1 if data_format == 'channels_first' else -1
layer = tf.keras.layers.BatchNormalization(axis=channel_axis)(inputs=layer)
return activation(layer)
def apply_up_convolution(inputs,
num_filters,
pool_size,
kernel_size=(2, 2, 2),
strides=(2, 2, 2),
deconvolution=False):
"""Apply up convolution on inputs.
Args:
inputs: input feature tensor.
num_filters: number of deconvolution output feature channels.
pool_size: pool size of the up-scaling.
kernel_size: kernel size of the deconvolution.
strides: strides of the deconvolution.
deconvolution: Use deconvolution or upsampling.
Returns:
The tensor of the up-scaled features.
"""
if deconvolution:
return tf.keras.layers.Conv3DTranspose(
filters=num_filters, kernel_size=kernel_size, strides=strides)(
inputs=inputs)
else:
return tf.keras.layers.UpSampling3D(size=pool_size)(inputs)
def unet3d_base(input_layer,
pool_size=(2, 2, 2),
n_labels=1,
deconvolution=False,
depth=4,
n_base_filters=32,
batch_normalization=False,
data_format='channels_last'):
"""Builds the 3D UNet Tensorflow model and return the last layer logits.
Args:
input_layer: the input Tensor.
pool_size: Pool size for the max pooling operations.
n_labels: Number of binary labels that the model is learning.
deconvolution: If set to True, will use transpose convolution (deconvolution)
instead of up-sampling. This increases the amount of memory required during
training.
depth: indicates the depth of the U-shape for the model. The greater the
depth, the more max pooling layers will be added to the model. Lowering
the depth may reduce the amount of memory required for training.
n_base_filters: The number of filters that the first layer in the
convolution network will have. Following layers will contain a multiple of
this number. Lowering this number will likely reduce the amount of memory
required to train the model.
batch_normalization: boolean. True for use batch normalization after
convolution and before activation.
data_format: string, 'channels_last' (default) or 'channels_first'.
Returns:
The last layer logits of 3D UNet.
"""
levels = []
current_layer = input_layer
if data_format == 'channels_last':
channel_dim = -1
else:
channel_dim = 1
# add levels with max pooling
for layer_depth in range(depth):
layer1 = create_convolution_block(
input_layer=current_layer,
n_filters=n_base_filters * (2**layer_depth),
batch_normalization=batch_normalization,
kernel=(3, 3, 3),
activation=tf.nn.relu,
padding='SAME',
strides=(1, 1, 1),
data_format=data_format,
instance_normalization=False)
layer2 = create_convolution_block(
input_layer=layer1,
n_filters=n_base_filters * (2**layer_depth) * 2,
batch_normalization=batch_normalization,
kernel=(3, 3, 3),
activation=tf.nn.relu,
padding='SAME',
strides=(1, 1, 1),
data_format=data_format,
instance_normalization=False)
if layer_depth < depth - 1:
current_layer = tf.keras.layers.MaxPool3D(
pool_size=pool_size,
strides=(2, 2, 2),
padding='VALID',
data_format=data_format)(
inputs=layer2)
levels.append([layer1, layer2, current_layer])
else:
current_layer = layer2
levels.append([layer1, layer2])
# add levels with up-convolution or up-sampling
for layer_depth in range(depth - 2, -1, -1):
up_convolution = apply_up_convolution(
current_layer,
pool_size=pool_size,
deconvolution=deconvolution,
num_filters=current_layer.get_shape().as_list()[channel_dim])
concat = tf.concat([up_convolution, levels[layer_depth][1]],
axis=channel_dim)
current_layer = create_convolution_block(
n_filters=levels[layer_depth][1].get_shape().as_list()[channel_dim],
input_layer=concat,
batch_normalization=batch_normalization,
kernel=(3, 3, 3),
activation=tf.nn.relu,
padding='SAME',
strides=(1, 1, 1),
data_format=data_format,
instance_normalization=False)
current_layer = create_convolution_block(
n_filters=levels[layer_depth][1].get_shape().as_list()[channel_dim],
input_layer=current_layer,
batch_normalization=batch_normalization,
kernel=(3, 3, 3),
activation=tf.nn.relu,
padding='SAME',
strides=(1, 1, 1),
data_format=data_format,
instance_normalization=False)
final_convolution = tf.keras.layers.Conv3D(
filters=n_labels,
kernel_size=(1, 1, 1),
padding='VALID',
data_format=data_format,
activation=None)(
current_layer)
return final_convolution
def build_unet_model(params):
"""Builds the unet model, optimizer included."""
input_shape = params.input_image_size + [1]
input_layer = tf.keras.layers.Input(shape=input_shape)
logits = unet3d_base(
input_layer,
pool_size=(2, 2, 2),
n_labels=params.num_classes,
deconvolution=params.deconvolution,
depth=params.depth,
n_base_filters=params.num_base_filters,
batch_normalization=params.use_batch_norm,
data_format=params.data_format)
# Set output of softmax to float32 to avoid potential numerical overflow.
predictions = tf.keras.layers.Softmax(dtype='float32')(logits)
model = tf.keras.models.Model(inputs=input_layer, outputs=predictions)
model.optimizer = create_optimizer(params.init_learning_rate, params)
return model
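# Minimal usage sketch (the override values below are illustrative
# assumptions, not defaults from unet_config):
#
#   params = params_dict.ParamsDict(unet_config.UNET_CONFIG,
#                                   unet_config.UNET_RESTRICTIONS)
#   params.override({'input_image_size': [32, 32, 32], 'num_classes': 3,
#                    'depth': 2, 'num_base_filters': 8}, is_strict=False)
#   model = build_unet_model(params)
#   # Calling model on a [batch, 32, 32, 32, 1] volume returns per-voxel
#   # softmax probabilities of shape [batch, 32, 32, 32, 3].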