ModelZoo / ResNet50_tensorflow · Commits

Commit c8a91782
Authored Jun 25, 2021 by Fan Yang; committed by A. Unique TensorFlower, Jun 25, 2021

Internal change

PiperOrigin-RevId: 381516130
Parent: 6e5cbee1

Changes: 33 files in total. Showing 20 changed files with 1525 additions and 0 deletions (+1525, −0).
Changed files:

  official/vision/beta/projects/volumetric_models/README.md                                   +31  −0
  official/vision/beta/projects/volumetric_models/configs/backbones.py                        +48  −0
  official/vision/beta/projects/volumetric_models/configs/decoders.py                         +43  −0
  official/vision/beta/projects/volumetric_models/configs/semantic_segmentation_3d.py         +163 −0
  official/vision/beta/projects/volumetric_models/configs/semantic_segmentation_3d_test.py    +44  −0
  official/vision/beta/projects/volumetric_models/dataloaders/segmentation_input_3d.py        +106 −0
  official/vision/beta/projects/volumetric_models/dataloaders/segmentation_input_3d_test.py   +76  −0
  official/vision/beta/projects/volumetric_models/evaluation/segmentation_metrics.py          +128 −0
  official/vision/beta/projects/volumetric_models/evaluation/segmentation_metrics_test.py     +57  −0
  official/vision/beta/projects/volumetric_models/losses/segmentation_losses.py               +105 −0
  official/vision/beta/projects/volumetric_models/losses/segmentation_losses_test.py          +37  −0
  official/vision/beta/projects/volumetric_models/modeling/backbones/__init__.py              +18  −0
  official/vision/beta/projects/volumetric_models/modeling/backbones/unet_3d.py               +176 −0
  official/vision/beta/projects/volumetric_models/modeling/backbones/unet_3d_test.py          +74  −0
  official/vision/beta/projects/volumetric_models/modeling/decoders/__init__.py               +18  −0
  official/vision/beta/projects/volumetric_models/modeling/decoders/factory.py                +60  −0
  official/vision/beta/projects/volumetric_models/modeling/decoders/unet_3d_decoder.py        +154 −0
  official/vision/beta/projects/volumetric_models/modeling/decoders/unet_3d_decoder_test.py   +81  −0
  official/vision/beta/projects/volumetric_models/modeling/factory.py                         +61  −0
  official/vision/beta/projects/volumetric_models/modeling/factory_test.py                    +45  −0
official/vision/beta/projects/volumetric_models/README.md  (new file, mode 100644)

# Volumetric Models

**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.

This folder contains an implementation of volumetric models, i.e., a UNet 3D
model, for 3D semantic segmentation.

## Modeling

Following the style of TF-Vision, a UNet 3D model is implemented as a backbone
and a decoder.

## Backbone

The backbone is the left U-shape of the complete UNet model. It takes a batch of
images as input and outputs a dictionary in the form of `{level: features}`,
where `features` is a tensor of feature maps.

## Decoder

The decoder is the right U-shape of the complete UNet model. It takes the output
dictionary from the backbone and connects the feature maps from each level to
the decoder's decoding branches. The final output is the raw segmentation
predictions.

An additional head is attached to the output of the decoder to optionally
perform more operations and then generate the prediction map of logits.

The `factory.py` file builds and connects the backbone, decoder and head
together to form the complete UNet model.
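To make that data flow concrete, here is a minimal, illustrative sketch (not part of this commit) that assembles the backbone, decoder and head through the factory added in this change; the input shape and class count mirror `modeling/factory_test.py` further below.

```python
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.modeling import factory

# 5D input spec: [batch, height, width, volume, channels].
input_specs = tf.keras.layers.InputSpec(shape=[None, 64, 64, 64, 3])

# The default model config selects the `unet_3d` backbone and the
# `unet_3d_decoder` decoder defined in this commit.
model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=3)

# Backbone -> decoder -> head, wired together by the factory.
model = factory.build_segmentation_model_3d(
    input_specs=input_specs, model_config=model_config)
```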
official/vision/beta/projects/volumetric_models/configs/backbones.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Backbones configurations."""
from typing import Optional, Sequence

import dataclasses
from official.modeling import hyperparams


@dataclasses.dataclass
class UNet3D(hyperparams.Config):
  """UNet3D config."""
  model_id: int = 4
  pool_size: Sequence[int] = (2, 2, 2)
  kernel_size: Sequence[int] = (3, 3, 3)
  base_filters: int = 32
  use_batch_normalization: bool = True


@dataclasses.dataclass
class Backbone(hyperparams.OneOfConfig):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    unet_3d: UNet3D backbone config.
  """
  type: Optional[str] = None
  unet_3d: UNet3D = UNet3D()
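As an illustrative aside (not part of the commit), `Backbone` is a OneOfConfig whose `type` field names the active sub-config. The `model_id=2` override below mirrors the `seg_unet3d_test` experiment defined later in this change; the `base_filters` value is just an example.

```python
from official.vision.beta.projects.volumetric_models.configs import backbones

# `type` selects which sibling field is active; here a shallower UNet3D
# backbone with fewer base filters than the defaults.
backbone_config = backbones.Backbone(
    type='unet_3d',
    unet_3d=backbones.UNet3D(model_id=2, base_filters=16))
```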
official/vision/beta/projects/volumetric_models/configs/decoders.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

# Lint as: python3
"""Decoders configurations."""
from typing import Optional, Sequence

import dataclasses
from official.modeling import hyperparams


@dataclasses.dataclass
class UNet3DDecoder(hyperparams.Config):
  """UNet3D decoder config."""
  model_id: int = 4
  pool_size: Sequence[int] = (2, 2, 2)
  kernel_size: Sequence[int] = (3, 3, 3)
  use_batch_normalization: bool = True
  use_deconvolution: bool = True


@dataclasses.dataclass
class Decoder(hyperparams.OneOfConfig):
  """Configuration for decoders.

  Attributes:
    type: 'str', type of decoder to be used, one of the fields below.
    unet_3d_decoder: UNet3D decoder config.
  """
  type: Optional[str] = None
  unet_3d_decoder: UNet3DDecoder = UNet3DDecoder()
official/vision/beta/projects/volumetric_models/configs/semantic_segmentation_3d.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

# Lint as: python3
"""Semantic segmentation configuration definition."""
from typing import List, Optional, Union

import dataclasses
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.configs import common
from official.vision.beta.projects.volumetric_models.configs import backbones
from official.vision.beta.projects.volumetric_models.configs import decoders


@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Input config for training."""
  output_size: List[int] = dataclasses.field(default_factory=list)
  input_size: List[int] = dataclasses.field(default_factory=list)
  num_classes: int = 0
  num_channels: int = 1
  input_path: str = ''
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  label_dtype: str = 'float32'
  image_field_key: str = 'image/encoded'
  label_field_key: str = 'image/class/label'
  shuffle_buffer_size: int = 1000
  cycle_length: int = 10
  drop_remainder: bool = False
  file_type: str = 'tfrecord'


@dataclasses.dataclass
class SegmentationHead3D(hyperparams.Config):
  """Segmentation head config."""
  num_classes: int = 0
  level: int = 1
  num_convs: int = 0
  num_filters: int = 256
  upsample_factor: int = 1
  output_logits: bool = True


@dataclasses.dataclass
class SemanticSegmentationModel3D(hyperparams.Config):
  """Semantic segmentation model config."""
  num_classes: int = 0
  num_channels: int = 1
  input_size: List[int] = dataclasses.field(default_factory=list)
  min_level: int = 3
  max_level: int = 6
  head: SegmentationHead3D = SegmentationHead3D()
  backbone: backbones.Backbone = backbones.Backbone(
      type='unet_3d', unet_3d=backbones.UNet3D())
  decoder: decoders.Decoder = decoders.Decoder(
      type='unet_3d_decoder', unet_3d_decoder=decoders.UNet3DDecoder())
  norm_activation: common.NormActivation = common.NormActivation()


@dataclasses.dataclass
class Losses(hyperparams.Config):
  # Supported `loss_type` are `adaptive` and `generalized`.
  loss_type: str = 'adaptive'
  l2_weight_decay: float = 0.0


@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  report_per_class_metric: bool = False  # Whether to report per-class metrics.


@dataclasses.dataclass
class SemanticSegmentation3DTask(cfg.TaskConfig):
  """The model config."""
  model: SemanticSegmentationModel3D = SemanticSegmentationModel3D()
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False)
  losses: Losses = Losses()
  evaluation: Evaluation = Evaluation()
  train_input_partition_dims: List[int] = dataclasses.field(
      default_factory=list)
  eval_input_partition_dims: List[int] = dataclasses.field(
      default_factory=list)
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: Union[str, List[str]] = 'all'  # all, backbone, and/or decoder


@exp_factory.register_config_factory('seg_unet3d_test')
def seg_unet3d_test() -> cfg.ExperimentConfig:
  """Image segmentation on a dummy dataset with 3D UNet for testing purpose."""
  train_batch_size = 2
  eval_batch_size = 2
  steps_per_epoch = 10
  config = cfg.ExperimentConfig(
      task=SemanticSegmentation3DTask(
          model=SemanticSegmentationModel3D(
              num_classes=2,
              input_size=[32, 32, 32],
              num_channels=2,
              backbone=backbones.Backbone(
                  type='unet_3d', unet_3d=backbones.UNet3D(model_id=2)),
              decoder=decoders.Decoder(
                  type='unet_3d_decoder',
                  unet_3d_decoder=decoders.UNet3DDecoder(model_id=2)),
              head=SegmentationHead3D(num_convs=0, num_classes=2),
              norm_activation=common.NormActivation(
                  activation='relu', use_sync_bn=False)),
          train_data=DataConfig(
              input_path='train.tfrecord',
              num_classes=2,
              input_size=[32, 32, 32],
              num_channels=2,
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path='val.tfrecord',
              num_classes=2,
              input_size=[32, 32, 32],
              num_channels=2,
              is_training=False,
              global_batch_size=eval_batch_size),
          losses=Losses(loss_type='adaptive')),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=10,
          validation_steps=10,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
              },
              'learning_rate': {
                  'type': 'constant',
                  'constant': {
                      'learning_rate': 0.000001
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
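A brief, illustrative usage note (not part of the commit): the experiment registered above can be retrieved by name and overridden before validation, much like the unit test that follows does.

```python
from official.core import exp_factory

# Retrieve the experiment registered above by its factory name.
config = exp_factory.get_exp_config('seg_unet3d_test')

# Switch to the generalized dice loss, the other supported `loss_type`.
config.task.losses.loss_type = 'generalized'
config.validate()
```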
official/vision/beta/projects/volumetric_models/configs/semantic_segmentation_3d_test.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

# Lint as: python3
"""Tests for semantic_segmentation."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf

from official.core import exp_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg


class ImageSegmentationConfigTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('seg_unet3d_test',),)
  def test_semantic_segmentation_configs(self, config_name):
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    self.assertIsInstance(config.task, exp_cfg.SemanticSegmentation3DTask)
    self.assertIsInstance(config.task.model,
                          exp_cfg.SemanticSegmentationModel3D)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/volumetric_models/dataloaders/segmentation_input_3d.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Data parser and processing for 3D segmentation datasets."""
from typing import Any, Dict, Sequence, Tuple

import tensorflow as tf

from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser


class Decoder(decoder.Decoder):
  """A tf.Example decoder for the segmentation task."""

  def __init__(self,
               image_field_key: str = 'image/encoded',
               label_field_key: str = 'image/class/label'):
    self._keys_to_features = {
        image_field_key:
            tf.io.FixedLenFeature([], tf.string, default_value=''),
        label_field_key:
            tf.io.FixedLenFeature([], tf.string, default_value='')
    }

  def decode(self, serialized_example: tf.string) -> Dict[str, tf.Tensor]:
    return tf.io.parse_single_example(serialized_example,
                                      self._keys_to_features)


class Parser(parser.Parser):
  """Parser to parse an image and its annotations into a dictionary of tensors."""

  def __init__(self,
               input_size: Sequence[int],
               num_classes: int,
               num_channels: int = 3,
               image_field_key: str = 'image/encoded',
               label_field_key: str = 'image/class/label',
               dtype: str = 'float32',
               label_dtype: str = 'float32'):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      input_size: The input tensor size of [height, width, volume] of the
        input image.
      num_classes: The number of classes to be segmented.
      num_channels: The number of channels of the input images.
      image_field_key: A `str` of the key name to the encoded image in the
        TFExample.
      label_field_key: A `str` of the key name to the label in the TFExample.
      dtype: The data type. One of {`bfloat16`, `float32`, `float16`}.
      label_dtype: The data type of the input label.
    """
    self._input_size = input_size
    self._num_classes = num_classes
    self._num_channels = num_channels
    self._image_field_key = image_field_key
    self._label_field_key = label_field_key
    self._dtype = dtype
    self._label_dtype = label_dtype

  def _prepare_image_and_label(
      self, data: Dict[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Prepares normalized image and label."""
    image = tf.io.decode_raw(data[self._image_field_key],
                             tf.as_dtype(tf.float32))
    label = tf.io.decode_raw(data[self._label_field_key],
                             tf.as_dtype(self._label_dtype))
    image_size = list(self._input_size) + [self._num_channels]
    image = tf.reshape(image, image_size)
    label_size = list(self._input_size) + [self._num_classes]
    label = tf.reshape(label, label_size)
    image = tf.cast(image, dtype=self._dtype)
    label = tf.cast(label, dtype=self._dtype)

    # TPU doesn't support tf.int64 well, use tf.int32 directly.
    if label.dtype == tf.int64:
      label = tf.cast(label, dtype=tf.int32)
    return image, label

  def _parse_train_data(self,
                        data: Dict[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for training and evaluation."""
    image, labels = self._prepare_image_and_label(data)
    # Cast image as self._dtype.
    image = tf.cast(image, dtype=self._dtype)
    return image, labels

  def _parse_eval_data(self,
                       data: Dict[str, Any]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for training and evaluation."""
    image, labels = self._prepare_image_and_label(data)
    # Cast image as self._dtype.
    image = tf.cast(image, dtype=self._dtype)
    return image, labels
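For illustration only (not part of the commit), a `tf.train.Example` matching the decoder's two raw-bytes features might be written and parsed as below; the shapes are assumptions that must match the `Parser` arguments.

```python
import numpy as np
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.dataloaders import segmentation_input_3d

input_size, num_channels, num_classes = [32, 32, 32], 2, 2

# Raw float32 bytes for the volume and its one-hot label map.
image = np.random.rand(*input_size, num_channels).astype(np.float32)
label = np.random.randint(0, 2, size=input_size + [num_classes]).astype(np.float32)

example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image.tobytes()])),
    'image/class/label': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[label.tobytes()])),
}))

decoded = segmentation_input_3d.Decoder().decode(example.SerializeToString())
parser = segmentation_input_3d.Parser(
    input_size=input_size, num_classes=num_classes, num_channels=num_channels)
image_t, label_t = parser.parse_fn(False)(decoded)  # each [32, 32, 32, 2]
```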
official/vision/beta/projects/volumetric_models/dataloaders/segmentation_input_3d_test.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Tests for segmentation_input_3d.py."""
import os

from absl.testing import parameterized
import tensorflow as tf

from official.common import dataset_fn
from official.core import config_definitions as cfg
from official.core import input_reader
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.projects.volumetric_models.dataloaders import segmentation_input_3d


class InputReaderTest(parameterized.TestCase, tf.test.TestCase):

  def setUp(self):
    super().setUp()
    data_dir = os.path.join(self.get_temp_dir(), 'data')
    tf.io.gfile.makedirs(data_dir)
    self._data_path = os.path.join(data_dir, 'data.tfrecord')
    # pylint: disable=g-complex-comprehension
    examples = [
        tfexample_utils.create_3d_image_test_example(
            image_height=32, image_width=32, image_volume=32, image_channel=2)
        for _ in range(20)
    ]
    # pylint: enable=g-complex-comprehension
    tfexample_utils.dump_to_tfrecord(self._data_path, tf_examples=examples)

  @parameterized.parameters(([32, 32, 32], 2, 2))
  def testSegmentationInputReader(self, input_size, num_classes, num_channels):
    params = cfg.DataConfig(
        input_path=self._data_path, global_batch_size=2, is_training=False)

    decoder = segmentation_input_3d.Decoder()
    parser = segmentation_input_3d.Parser(
        input_size=input_size,
        num_classes=num_classes,
        num_channels=num_channels)

    reader = input_reader.InputReader(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn('tfrecord'),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read()
    iterator = iter(dataset)
    image, labels = next(iterator)

    # Checks image shape.
    self.assertEqual(
        list(image.numpy().shape),
        [2, input_size[0], input_size[1], input_size[2], num_channels])
    self.assertEqual(
        list(labels.numpy().shape),
        [2, input_size[0], input_size[1], input_size[2], num_classes])


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/volumetric_models/evaluation/segmentation_metrics.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Metrics for segmentation."""
from typing import Optional

import tensorflow as tf

from official.vision.beta.projects.volumetric_models.losses import segmentation_losses


class DiceScore:
  """Dice score metric for semantic segmentation.

  This class follows the same function interface as tf.keras.metrics.Metric but
  does not derive from tf.keras.metrics.Metric or utilize its functions. The
  reason is that a tf.keras.metrics.Metric object created on GPU does not run
  well on CPU when running with MirroredStrategy. Keeping the same interface
  allows for minimal change to the upstream tasks.

  Attributes:
    name: The name of the metric.
    dtype: The dtype of the metric, for example, tf.float32.
  """

  def __init__(self,
               num_classes: int,
               metric_type: Optional[str] = None,
               per_class_metric: bool = False,
               name: Optional[str] = None,
               dtype: Optional[str] = None):
    """Constructs segmentation evaluator class.

    Args:
      num_classes: The number of classes.
      metric_type: An optional `str` of the type of dice score.
      per_class_metric: Whether to report per-class metrics.
      name: A `str`, name of the metric instance.
      dtype: The data type of the metric result.
    """
    self._num_classes = num_classes
    self._per_class_metric = per_class_metric
    self._dice_op_overall = segmentation_losses.SegmentationLossDiceScore(
        metric_type=metric_type)
    self._dice_scores_overall = tf.Variable(0.0)
    self._count = tf.Variable(0.0)
    if self._per_class_metric:
      # Always use raw dice score for per-class metrics, so metric_type is None
      # by default.
      self._dice_op_per_class = segmentation_losses.SegmentationLossDiceScore()
      self._dice_scores_per_class = [
          tf.Variable(0.0) for _ in range(num_classes)
      ]
    self.name = name
    self.dtype = dtype

  def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor):
    """Updates metric state.

    Args:
      y_true: The true labels of size [batch, width, height, volume,
        num_classes].
      y_pred: The prediction of size [batch, width, height, volume,
        num_classes].

    Raises:
      ValueError: If the number of classes from the groundtruth label does not
        equal `num_classes`.
    """
    if self._num_classes != y_true.get_shape()[-1]:
      raise ValueError(
          'The number of classes from groundtruth labels and `num_classes` '
          'should equal, but they are {0} and {1}.'.format(
              self._num_classes, y_true.get_shape()[-1]))

    self._count.assign_add(1.)
    self._dice_scores_overall.assign_add(
        1 - self._dice_op_overall(y_pred, y_true))
    if self._per_class_metric:
      for class_id in range(self._num_classes):
        self._dice_scores_per_class[class_id].assign_add(
            1 - self._dice_op_per_class(y_pred[..., class_id],
                                        y_true[..., class_id]))

  def result(self) -> tf.Tensor:
    """Computes and returns the metric.

    The first element is the `generalized` or `adaptive` overall dice score,
    depending on `metric_type`. If `per_class_metric` is True, `num_classes`
    elements are also appended to the overall metric, as the per-class raw
    dice scores.

    Returns:
      The resulting dice scores.
    """
    if self._per_class_metric:
      dice_scores = [
          tf.math.divide_no_nan(self._dice_scores_overall, self._count)
      ]
      for class_id in range(self._num_classes):
        dice_scores.append(
            tf.math.divide_no_nan(self._dice_scores_per_class[class_id],
                                  self._count))
      return tf.stack(dice_scores)
    else:
      return tf.math.divide_no_nan(self._dice_scores_overall, self._count)

  def reset_states(self):
    """Resets the metrics to the initial state."""
    self._count = tf.Variable(0.0)
    self._dice_scores_overall = tf.Variable(0.0)
    if self._per_class_metric:
      for class_id in range(self._num_classes):
        self._dice_scores_per_class[class_id] = tf.Variable(0.0)
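A short, illustrative usage sketch (not part of the commit), following the interface exercised by the metric test below; the batch and spatial sizes are arbitrary.

```python
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics

metric = segmentation_metrics.DiceScore(
    num_classes=2, metric_type='generalized', per_class_metric=True)

# [batch, height, width, volume, num_classes] groundtruth and predictions.
y_true = tf.ones([2, 32, 32, 32, 2], dtype=tf.float32)
y_pred = tf.constant(0.5, shape=[2, 32, 32, 32, 2], dtype=tf.float32)

metric.update_state(y_true=y_true, y_pred=y_pred)  # called once per batch
scores = metric.result()  # [overall dice, class-0 dice, class-1 dice]
metric.reset_states()
```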
official/vision/beta/projects/volumetric_models/evaluation/segmentation_metrics_test.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Tests for segmentation_metrics.py."""
from absl.testing import parameterized
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.evaluation import segmentation_metrics


class SegmentationMetricsTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters((1, 'generalized', 0.5, [0.74, 0.67]),
                            (1, 'adaptive', 0.5, [0.93, 0.67]),
                            (2, None, 0.5, [0.67, 0.67, 0.67]),
                            (3, 'generalized', 0.5, [0.7, 0.67, 0.67, 0.67]))
  def test_forward_dice_score(self, num_classes, metric_type, output,
                              expected_score):
    metric = segmentation_metrics.DiceScore(
        num_classes=num_classes,
        metric_type=metric_type,
        per_class_metric=True)
    y_pred = tf.constant(
        output, shape=[2, 128, 128, 128, num_classes], dtype=tf.float32)
    y_true = tf.ones(shape=[2, 128, 128, 128, num_classes], dtype=tf.float32)
    metric.update_state(y_true=y_true, y_pred=y_pred)
    actual_score = metric.result().numpy()
    self.assertAllClose(
        actual_score,
        expected_score,
        atol=1e-2,
        msg='Output metric {} does not match expected metric {}.'.format(
            actual_score, expected_score))

  def test_num_classes_not_equal(self):
    metric = segmentation_metrics.DiceScore(num_classes=4)
    y_pred = tf.constant(0.5, shape=[2, 128, 128, 128, 2], dtype=tf.float32)
    y_true = tf.ones(shape=[2, 128, 128, 128, 2], dtype=tf.float32)
    with self.assertRaisesRegex(
        ValueError,
        'The number of classes from groundtruth labels and `num_classes` '
        'should equal'):
      metric.update_state(y_true=y_true, y_pred=y_pred)


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/volumetric_models/losses/segmentation_losses.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Losses used for segmentation models."""
from typing import Optional, Sequence

import tensorflow as tf


class SegmentationLossDiceScore(object):
  """Semantic segmentation loss using generalized dice score.

  Dice score (DSC) is a similarity measure that equals twice the number of
  elements common to both sets divided by the sum of the number of elements
  in each set. It is commonly used to evaluate segmentation performance by
  measuring the overlap of predicted and groundtruth regions.
  (https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient)

  The generalized dice score is the dice score weighted by the volume of
  groundtruth labels per class. The adaptive dice score adds weights to the
  generalized dice score. It assigns larger weights to lower dice scores, so
  that wrong predictions contribute more to the total loss. The model is then
  trained to focus more on these hard examples.
  """

  def __init__(self,
               metric_type: Optional[str] = None,
               axis: Optional[Sequence[int]] = (1, 2, 3)):
    """Initializes dice score loss object.

    Args:
      metric_type: An optional `str` specifying the type of the dice score to
        compute. Compute generalized or adaptive dice score if metric type is
        `generalized` or `adaptive`; otherwise compute original dice score.
      axis: An optional sequence of `int` specifying the axis to perform
        reduce ops for raw dice score.
    """
    self._dice_score = 0
    self._metric_type = metric_type
    self._axis = axis

  def __call__(self, logits: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
    """Computes and returns a loss based on 1 - dice score.

    Args:
      logits: A Tensor of the prediction.
      labels: A Tensor of the groundtruth label.

    Returns:
      The loss value of (1 - dice score).
    """
    labels = tf.cast(labels, logits.dtype)

    if labels.get_shape().ndims < 2 or logits.get_shape().ndims < 2:
      raise ValueError('The labels and logits must be at least rank 2.')

    epsilon = tf.keras.backend.epsilon()
    axis = list(range(len(logits.shape) - 1))

    # Calculate intersections and unions per class.
    intersection = tf.reduce_sum(labels * logits, axis=axis)
    union = tf.reduce_sum(labels + logits, axis=axis)

    if self._metric_type == 'generalized':
      # Calculate the volume of groundtruth labels.
      w = tf.math.reciprocal(
          tf.square(tf.reduce_sum(labels, axis=axis)) + epsilon)

      # Calculate the weighted dice score and normalizer.
      dice = 2 * tf.reduce_sum(w * intersection) + epsilon
      normalizer = tf.reduce_sum(w * union) + epsilon
      dice = tf.cast(dice, dtype=tf.float32)
      normalizer = tf.cast(normalizer, dtype=tf.float32)
      return 1 - tf.reduce_mean(dice / normalizer)
    elif self._metric_type == 'adaptive':
      dice = 2.0 * (intersection + epsilon) / (union + epsilon)
      # Calculate weights based on Dice scores.
      weights = tf.exp(-1.0 * dice)
      # Multiply weights by corresponding scores and get sum.
      weighted_dice = tf.reduce_sum(weights * dice)
      # Calculate normalization factor.
      normalizer = tf.cast(
          tf.size(input=dice), dtype=tf.float32) * tf.exp(-1.0)
      weighted_dice = tf.cast(weighted_dice, dtype=tf.float32)
      return 1 - tf.reduce_mean(weighted_dice / normalizer)
    else:
      summation = tf.reduce_sum(
          labels, axis=self._axis) + tf.reduce_sum(logits, axis=self._axis)
      dice = (2 * tf.reduce_sum(labels * logits, axis=self._axis) + epsilon) / (
          summation + epsilon)
      return 1 - tf.reduce_mean(dice)
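To make the raw dice branch concrete, here is a small worked example (illustrative, not part of the commit). With constant predictions of 0.5 and all-ones labels, the per-class dice score is roughly 2·0.5 / (1 + 0.5) ≈ 0.67, so the returned loss is about 0.33, consistent with the loss test below.

```python
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.losses import segmentation_losses

loss_fn = segmentation_losses.SegmentationLossDiceScore()  # raw dice score
logits = tf.constant(0.5, shape=[1, 8, 8, 8, 1], dtype=tf.float32)
labels = tf.ones([1, 8, 8, 8, 1], dtype=tf.float32)

# dice = (2 * sum(labels * logits) + eps) / (sum(labels) + sum(logits) + eps)
loss = loss_fn(logits=logits, labels=labels)  # approximately 1 - 0.67 = 0.33
```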
official/vision/beta/projects/volumetric_models/losses/segmentation_losses_test.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Tests for segmentation_losses.py."""
from absl.testing import parameterized
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.losses import segmentation_losses


class SegmentationLossDiceScoreTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters((None, 0.5, 0.3), ('generalized', 0.5, 0.3),
                            ('adaptive', 0.5, 0.07))
  def test_supported_loss(self, metric_type, output, expected_score):
    loss = segmentation_losses.SegmentationLossDiceScore(
        metric_type=metric_type)
    logits = tf.constant(output, shape=[1, 128, 128, 128, 1], dtype=tf.float32)
    labels = tf.ones(shape=[1, 128, 128, 128, 1], dtype=tf.float32)
    actual_score = loss(logits=logits, labels=labels)
    self.assertAlmostEqual(actual_score.numpy(), expected_score, places=1)


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/volumetric_models/modeling/backbones/__init__.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

# Lint as: python3
"""Backbones package definition."""
from official.vision.beta.projects.volumetric_models.modeling.backbones.unet_3d import UNet3D
official/vision/beta/projects/volumetric_models/modeling/backbones/unet_3d.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Contains definitions of the 3D UNet model encoder part.

[1] Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf
    Ronneberger. 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
    Annotation. arXiv:1606.06650.
"""
from typing import Any, Mapping, Sequence

# Import libraries
import tensorflow as tf

from official.modeling import hyperparams
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d

layers = tf.keras.layers


@tf.keras.utils.register_keras_serializable(package='Vision')
class UNet3D(tf.keras.Model):
  """Class to build 3D UNet backbone."""

  def __init__(self,
               model_id: int,
               input_specs: layers.InputSpec = layers.InputSpec(
                   shape=[None, None, None, None, 3]),
               pool_size: Sequence[int] = (2, 2, 2),
               kernel_size: Sequence[int] = (3, 3, 3),
               base_filters: int = 32,
               kernel_regularizer: tf.keras.regularizers.Regularizer = None,
               activation: str = 'relu',
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               use_sync_bn: bool = False,
               use_batch_normalization: bool = False,
               **kwargs):
    """3D UNet backbone initialization function.

    Args:
      model_id: The depth of the UNet3D backbone model. The greater the depth,
        the more max pooling layers will be added to the model. Lowering the
        depth may reduce the amount of memory required for training.
      input_specs: The specs of the input tensor. It specifies a 5D input of
        [batch, height, width, volume, channel] for `channel_last` data format
        or [batch, channel, height, width, volume] for `channel_first` data
        format.
      pool_size: The pooling size for the max pooling operations.
      kernel_size: The kernel size for 3D convolution.
      base_filters: The number of filters that the first layer in the
        convolution network will have. Following layers will contain a
        multiple of this number. Lowering this number will likely reduce the
        amount of memory required to train the model.
      kernel_regularizer: A tf.keras.regularizers.Regularizer object for
        Conv3D. Default to None.
      activation: The name of the activation function.
      norm_momentum: The normalization momentum for the moving average.
      norm_epsilon: A float added to variance to avoid dividing by zero.
      use_sync_bn: If True, use synchronized batch normalization.
      use_batch_normalization: If set to True, use batch normalization after
        convolution and before activation. Default to False.
      **kwargs: Keyword arguments to be passed.
    """
    self._model_id = model_id
    self._input_specs = input_specs
    self._pool_size = pool_size
    self._kernel_size = kernel_size
    self._activation = activation
    self._base_filters = base_filters
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._use_sync_bn = use_sync_bn
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._kernel_regularizer = kernel_regularizer
    self._use_batch_normalization = use_batch_normalization

    # Build 3D UNet.
    inputs = tf.keras.Input(
        shape=input_specs.shape[1:], dtype=input_specs.dtype)

    x = inputs
    endpoints = {}
    # Add levels with max pooling to downsample input.
    for layer_depth in range(model_id):
      # Two convolutions are applied sequentially without downsampling.
      filter_num = base_filters * (2**layer_depth)
      x2 = nn_blocks_3d.BasicBlock3DVolume(
          filters=[filter_num, filter_num * 2],
          strides=(1, 1, 1),
          kernel_size=self._kernel_size,
          kernel_regularizer=self._kernel_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon,
          use_batch_normalization=self._use_batch_normalization)(
              x)
      if layer_depth < model_id - 1:
        x = layers.MaxPool3D(
            pool_size=pool_size,
            strides=(2, 2, 2),
            padding='valid',
            data_format=tf.keras.backend.image_data_format())(
                x2)
      else:
        x = x2
      endpoints[str(layer_depth + 1)] = x2

    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super(UNet3D, self).__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def get_config(self) -> Mapping[str, Any]:
    return {
        'model_id': self._model_id,
        'pool_size': self._pool_size,
        'kernel_size': self._kernel_size,
        'activation': self._activation,
        'base_filters': self._base_filters,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'use_sync_bn': self._use_sync_bn,
        'kernel_regularizer': self._kernel_regularizer,
        'use_batch_normalization': self._use_batch_normalization
    }

  @classmethod
  def from_config(cls, config: Mapping[str, Any], custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """Returns a dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('unet_3d')
def build_unet3d(
    input_specs: tf.keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Builds 3D UNet backbone from a config."""
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'unet_3d', (f'Inconsistent backbone type '
                                      f'{backbone_type}')

  return UNet3D(
      model_id=backbone_cfg.model_id,
      input_specs=input_specs,
      pool_size=backbone_cfg.pool_size,
      base_filters=backbone_cfg.base_filters,
      kernel_regularizer=l2_regularizer,
      activation=norm_activation_config.activation,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      use_sync_bn=norm_activation_config.use_sync_bn,
      use_batch_normalization=backbone_cfg.use_batch_normalization)
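An illustrative standalone instantiation of the backbone (not part of the commit); the channel counts follow the per-level doubling asserted in `unet_3d_test.py` below.

```python
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d

tf.keras.backend.set_image_data_format('channels_last')

# A depth-4 encoder: each level halves height/width/volume and doubles the
# channel count (64, 128, 256, 512 with the default base_filters=32).
backbone = unet_3d.UNet3D(model_id=4)
endpoints = backbone(tf.keras.Input(shape=(64, 64, 64, 3), batch_size=1))
print(backbone.output_specs)  # {'1': ..., '2': ..., '3': ..., '4': ...}
```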
official/vision/beta/projects/volumetric_models/modeling/backbones/unet_3d_test.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

# Lint as: python3
"""Tests for the 3D UNet backbone."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d


class UNet3DTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      ([128, 64], 4),
      ([256, 128], 6),
  )
  def test_network_creation(self, input_size, model_id):
    """Tests creation of UNet3D family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    network = unet_3d.UNet3D(model_id=model_id)
    inputs = tf.keras.Input(
        shape=(input_size[0], input_size[0], input_size[1], 3), batch_size=1)
    endpoints = network(inputs)

    for layer_depth in range(model_id):
      self.assertAllEqual([
          1, input_size[0] / 2**layer_depth, input_size[0] / 2**layer_depth,
          input_size[1] / 2**layer_depth, 64 * 2**layer_depth
      ], endpoints[str(layer_depth + 1)].shape.as_list())

  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        model_id=4,
        pool_size=(2, 2, 2),
        kernel_size=(3, 3, 3),
        activation='relu',
        base_filters=32,
        kernel_regularizer=None,
        norm_momentum=0.99,
        norm_epsilon=0.001,
        use_sync_bn=False,
        use_batch_normalization=True)
    network = unet_3d.UNet3D(**kwargs)

    expected_config = dict(kwargs)
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = unet_3d.UNet3D.from_config(network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/volumetric_models/modeling/decoders/__init__.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

# Lint as: python3
"""Decoders package definition."""
from official.vision.beta.projects.volumetric_models.modeling.decoders.unet_3d_decoder import UNet3DDecoder
official/vision/beta/projects/volumetric_models/modeling/decoders/factory.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

# Lint as: python3
"""Decoder factory method."""
from typing import Mapping

import tensorflow as tf

from official.vision.beta.projects.volumetric_models.modeling import decoders


def build_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Builds a decoder from a config.

  Args:
    input_specs: A `dict` of input specifications. A dictionary consists of
      {level: TensorShape} from a backbone.
    model_config: A OneOfConfig. Model config.
    l2_regularizer: A tf.keras.regularizers.Regularizer instance. Default to
      None.

  Returns:
    A tf.keras.Model instance of the decoder.
  """
  decoder_type = model_config.decoder.type
  decoder_cfg = model_config.decoder.get()
  norm_activation_config = model_config.norm_activation

  if decoder_type == 'identity':
    decoder = None
  elif decoder_type == 'unet_3d_decoder':
    decoder = decoders.UNet3DDecoder(
        model_id=decoder_cfg.model_id,
        input_specs=input_specs,
        pool_size=decoder_cfg.pool_size,
        kernel_regularizer=l2_regularizer,
        activation=norm_activation_config.activation,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon,
        use_sync_bn=norm_activation_config.use_sync_bn,
        use_batch_normalization=decoder_cfg.use_batch_normalization,
        use_deconvolution=decoder_cfg.use_deconvolution)
  else:
    raise ValueError('Decoder {!r} not implemented'.format(decoder_type))

  return decoder
official/vision/beta/projects/volumetric_models/modeling/decoders/unet_3d_decoder.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Contains definitions of the 3D UNet model decoder part.

[1] Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf
    Ronneberger. 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
    Annotation. arXiv:1606.06650.
"""
from typing import Any, Sequence, Dict, Mapping

import tensorflow as tf

from official.vision.beta.projects.volumetric_models.modeling import nn_blocks_3d

layers = tf.keras.layers


@tf.keras.utils.register_keras_serializable(package='Vision')
class UNet3DDecoder(tf.keras.Model):
  """Class to build 3D UNet decoder."""

  def __init__(self,
               model_id: int,
               input_specs: Mapping[str, tf.TensorShape],
               pool_size: Sequence[int] = (2, 2, 2),
               kernel_size: Sequence[int] = (3, 3, 3),
               kernel_regularizer: tf.keras.regularizers.Regularizer = None,
               activation: str = 'relu',
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               use_sync_bn: bool = False,
               use_batch_normalization: bool = False,
               use_deconvolution: bool = False,
               **kwargs):
    """3D UNet decoder initialization function.

    Args:
      model_id: The depth of the UNet3D backbone model. The greater the depth,
        the more max pooling layers will be added to the model. Lowering the
        depth may reduce the amount of memory required for training.
      input_specs: The input specifications. A dictionary consists of
        {level: TensorShape} from a backbone.
      pool_size: The pooling size for the max pooling operations.
      kernel_size: The kernel size for 3D convolution.
      kernel_regularizer: A tf.keras.regularizers.Regularizer object for
        Conv3D. Default to None.
      activation: The name of the activation function.
      norm_momentum: The normalization momentum for the moving average.
      norm_epsilon: A float added to variance to avoid dividing by zero.
      use_sync_bn: If True, use synchronized batch normalization.
      use_batch_normalization: If set to True, use batch normalization after
        convolution and before activation. Default to False.
      use_deconvolution: If set to True, the model will use transpose
        convolution (deconvolution) instead of up-sampling. This increases the
        amount of memory required during training. Default to False.
      **kwargs: Keyword arguments to be passed.
    """
    self._config_dict = {
        'model_id': model_id,
        'input_specs': input_specs,
        'pool_size': pool_size,
        'kernel_size': kernel_size,
        'kernel_regularizer': kernel_regularizer,
        'activation': activation,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'use_sync_bn': use_sync_bn,
        'use_batch_normalization': use_batch_normalization,
        'use_deconvolution': use_deconvolution
    }
    if use_sync_bn:
      self._norm = layers.experimental.SyncBatchNormalization
    else:
      self._norm = layers.BatchNormalization
    self._use_batch_normalization = use_batch_normalization

    if tf.keras.backend.image_data_format() == 'channels_last':
      channel_dim = -1
    else:
      channel_dim = 1

    # Build 3D UNet.
    inputs = self._build_input_pyramid(input_specs, model_id)

    # Add levels with up-convolution or up-sampling.
    x = inputs[str(model_id)]
    for layer_depth in range(model_id - 1, 0, -1):
      # Apply deconvolution or upsampling.
      if use_deconvolution:
        x = layers.Conv3DTranspose(
            filters=x.get_shape().as_list()[channel_dim],
            kernel_size=pool_size,
            strides=(2, 2, 2))(
                x)
      else:
        x = layers.UpSampling3D(size=pool_size)(x)

      # Concatenate upsampled features with input features from one layer up.
      x = tf.concat([x, tf.cast(inputs[str(layer_depth)], dtype=x.dtype)],
                    axis=channel_dim)

      filter_num = inputs[str(layer_depth)].get_shape().as_list()[channel_dim]
      x = nn_blocks_3d.BasicBlock3DVolume(
          filters=[filter_num, filter_num],
          strides=(1, 1, 1),
          kernel_size=kernel_size,
          kernel_regularizer=kernel_regularizer,
          activation=activation,
          use_sync_bn=use_sync_bn,
          norm_momentum=norm_momentum,
          norm_epsilon=norm_epsilon,
          use_batch_normalization=use_batch_normalization)(
              x)

    feats = {'1': x}
    self._output_specs = {l: feats[l].get_shape() for l in feats}

    super(UNet3DDecoder, self).__init__(inputs=inputs, outputs=feats, **kwargs)

  def _build_input_pyramid(self, input_specs: Dict[str, tf.TensorShape],
                           depth: int) -> Dict[str, tf.Tensor]:
    """Builds input pyramid features."""
    assert isinstance(input_specs, dict)
    if len(input_specs.keys()) > depth:
      raise ValueError(
          'Backbone depth should be equal to 3D UNet decoder\'s depth.')

    inputs = {}
    for level, spec in input_specs.items():
      inputs[level] = tf.keras.Input(shape=spec[1:])
    return inputs

  def get_config(self) -> Mapping[str, Any]:
    return self._config_dict

  @classmethod
  def from_config(cls, config: Mapping[str, Any], custom_objects=None):
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs
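An illustrative pairing of the encoder and decoder (not part of the commit), following `unet_3d_decoder_test.py` below: the decoder is built from the backbone's output specs and restores the full input resolution at level '1'.

```python
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d
from official.vision.beta.projects.volumetric_models.modeling.decoders import unet_3d_decoder

tf.keras.backend.set_image_data_format('channels_last')

# Backbone and decoder must share the same depth; the decoder consumes the
# backbone's {level: TensorShape} output specs.
backbone = unet_3d.UNet3D(model_id=4)
decoder = unet_3d_decoder.UNet3DDecoder(
    model_id=4, input_specs=backbone.output_specs)

inputs = tf.keras.Input(shape=(128, 128, 64, 3), batch_size=1)
feats = decoder(backbone(inputs))  # feats['1']: [1, 128, 128, 64, 64]
```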
official/vision/beta/projects/volumetric_models/modeling/decoders/unet_3d_decoder_test.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

# Lint as: python3
"""Tests for the 3D UNet decoder."""
# Import libraries
from absl.testing import parameterized
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d
from official.vision.beta.projects.volumetric_models.modeling.decoders import unet_3d_decoder


class UNet3DDecoderTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      ([128, 64], 4),
      ([256, 128], 6),
  )
  def test_network_creation(self, input_size, model_id):
    """Tests creation of UNet3D family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    # `input_size` consists of [spatial size, volume size].
    inputs = tf.keras.Input(
        shape=(input_size[0], input_size[0], input_size[1], 3), batch_size=1)
    backbone = unet_3d.UNet3D(model_id=model_id)
    network = unet_3d_decoder.UNet3DDecoder(
        model_id=model_id, input_specs=backbone.output_specs)

    endpoints = backbone(inputs)
    feats = network(endpoints)

    self.assertIn('1', feats)
    self.assertAllEqual(
        [1, input_size[0], input_size[0], input_size[1], 64],
        feats['1'].shape.as_list())

  def test_serialize_deserialize(self):
    # Create a network object that sets all of its config options.
    kwargs = dict(
        model_id=4,
        input_specs=unet_3d.UNet3D(model_id=4).output_specs,
        pool_size=(2, 2, 2),
        kernel_size=(3, 3, 3),
        kernel_regularizer=None,
        activation='relu',
        norm_momentum=0.99,
        norm_epsilon=0.001,
        use_sync_bn=False,
        use_batch_normalization=True,
        use_deconvolution=True)
    network = unet_3d_decoder.UNet3DDecoder(**kwargs)

    expected_config = dict(kwargs)
    self.assertEqual(network.get_config(), expected_config)

    # Create another network object from the first object's config.
    new_network = unet_3d_decoder.UNet3DDecoder.from_config(
        network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/volumetric_models/modeling/factory.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Factory methods to build models."""
# Import libraries
import tensorflow as tf

from official.modeling import hyperparams
from official.vision.beta.modeling import segmentation_model
from official.vision.beta.modeling.backbones import factory as backbone_factory
from official.vision.beta.projects.volumetric_models.modeling.decoders import factory as decoder_factory
from official.vision.beta.projects.volumetric_models.modeling.heads import segmentation_heads_3d


def build_segmentation_model_3d(
    input_specs: tf.keras.layers.InputSpec,
    model_config: hyperparams.Config,
    l2_regularizer: tf.keras.regularizers.Regularizer = None
) -> tf.keras.Model:
  """Builds the segmentation model."""
  norm_activation_config = model_config.norm_activation
  backbone = backbone_factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_activation_config,
      l2_regularizer=l2_regularizer)

  decoder = decoder_factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_config = model_config.head

  head = segmentation_heads_3d.SegmentationHead3D(
      num_classes=model_config.num_classes,
      level=head_config.level,
      num_convs=head_config.num_convs,
      num_filters=head_config.num_filters,
      upsample_factor=head_config.upsample_factor,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      output_logits=head_config.output_logits)

  model = segmentation_model.SegmentationModel(backbone, decoder, head)
  return model
official/vision/beta/projects/volumetric_models/modeling/factory_test.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0.

"""Tests for factory.py."""
from absl.testing import parameterized
import tensorflow as tf

from official.vision.beta.projects.volumetric_models.configs import semantic_segmentation_3d as exp_cfg
from official.vision.beta.projects.volumetric_models.modeling import factory
from official.vision.beta.projects.volumetric_models.modeling.backbones import unet_3d  # pylint: disable=unused-import


class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(((128, 128, 128), 5e-5), ((64, 64, 64), None))
  def test_unet3d_builder(self, input_size, weight_decay):
    num_classes = 3
    input_specs = tf.keras.layers.InputSpec(
        shape=[None, input_size[0], input_size[1], input_size[2], 3])
    model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=num_classes)
    l2_regularizer = (
        tf.keras.regularizers.l2(weight_decay) if weight_decay else None)
    model = factory.build_segmentation_model_3d(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)
    self.assertIsInstance(
        model, tf.keras.Model,
        'Output should be a tf.keras.Model instance but got %s' % type(model))


if __name__ == '__main__':
  tf.test.main()