Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2ee42597
Commit
2ee42597
authored
Jun 25, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Jun 25, 2021
Browse files
Internal change
PiperOrigin-RevId: 381516130
parent
afb34072
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1827 additions
and
0 deletions
+1827
-0
official/vision/beta/projects/volumetric_models/modeling/factory_test.py
.../beta/projects/volumetric_models/modeling/factory_test.py
+45
-0
official/vision/beta/projects/volumetric_models/modeling/heads/segmentation_heads_3d.py
...volumetric_models/modeling/heads/segmentation_heads_3d.py
+174
-0
official/vision/beta/projects/volumetric_models/modeling/heads/segmentation_heads_3d_test.py
...etric_models/modeling/heads/segmentation_heads_3d_test.py
+59
-0
official/vision/beta/projects/volumetric_models/modeling/nn_blocks_3d.py
.../beta/projects/volumetric_models/modeling/nn_blocks_3d.py
+503
-0
official/vision/beta/projects/volumetric_models/modeling/nn_blocks_3d_test.py
.../projects/volumetric_models/modeling/nn_blocks_3d_test.py
+75
-0
official/vision/beta/projects/volumetric_models/modeling/segmentation_model_test.py
...cts/volumetric_models/modeling/segmentation_model_test.py
+78
-0
official/vision/beta/projects/volumetric_models/registry_imports.py
...ision/beta/projects/volumetric_models/registry_imports.py
+20
-0
official/vision/beta/projects/volumetric_models/serving/export_saved_model.py
.../projects/volumetric_models/serving/export_saved_model.py
+125
-0
official/vision/beta/projects/volumetric_models/serving/semantic_segmentation_3d.py
...cts/volumetric_models/serving/semantic_segmentation_3d.py
+57
-0
official/vision/beta/projects/volumetric_models/serving/semantic_segmentation_3d_test.py
...olumetric_models/serving/semantic_segmentation_3d_test.py
+109
-0
official/vision/beta/projects/volumetric_models/tasks/semantic_segmentation_3d.py
...jects/volumetric_models/tasks/semantic_segmentation_3d.py
+345
-0
official/vision/beta/projects/volumetric_models/tasks/semantic_segmentation_3d_test.py
.../volumetric_models/tasks/semantic_segmentation_3d_test.py
+102
-0
official/vision/beta/projects/volumetric_models/train.py
official/vision/beta/projects/volumetric_models/train.py
+32
-0
official/vision/beta/projects/volumetric_models/train_test.py
...cial/vision/beta/projects/volumetric_models/train_test.py
+103
-0
No files found.
official/vision/beta/projects/volumetric_models/modeling/factory_test.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for factory.py."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.vision.beta.projects.volumetric_models.modeling
import
factory
from
official.vision.beta.projects.volumetric_models.modeling.backbones
import
unet_3d
# pylint: disable=unused-import
class
SegmentationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(((
128
,
128
,
128
),
5e-5
),
((
64
,
64
,
64
),
None
))
def
test_unet3d_builder
(
self
,
input_size
,
weight_decay
):
num_classes
=
3
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
,
input_size
[
0
],
input_size
[
1
],
input_size
[
2
],
3
])
model_config
=
exp_cfg
.
SemanticSegmentationModel3D
(
num_classes
=
num_classes
)
l2_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
weight_decay
)
if
weight_decay
else
None
)
model
=
factory
.
build_segmentation_model_3d
(
input_specs
=
input_specs
,
model_config
=
model_config
,
l2_regularizer
=
l2_regularizer
)
self
.
assertIsInstance
(
model
,
tf
.
keras
.
Model
,
'Output should be a tf.keras.Model instance but got %s'
%
type
(
model
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/volumetric_models/modeling/heads/segmentation_heads_3d.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Segmentation heads."""
from
typing
import
Any
,
Union
,
Sequence
,
Mapping
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
SegmentationHead3D
(
tf
.
keras
.
layers
.
Layer
):
"""Segmentation head for 3D input."""
def
__init__
(
self
,
num_classes
:
int
,
level
:
Union
[
int
,
str
],
num_convs
:
int
=
2
,
num_filters
:
int
=
256
,
upsample_factor
:
int
=
1
,
activation
:
str
=
'relu'
,
use_sync_bn
:
bool
=
False
,
norm_momentum
:
float
=
0.99
,
norm_epsilon
:
float
=
0.001
,
kernel_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
bias_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
output_logits
:
bool
=
True
,
**
kwargs
):
"""Initialize params to build segmentation head.
Args:
num_classes: `int` number of mask classification categories. The number of
classes does not include background class.
level: `int` or `str`, level to use to build segmentation head.
num_convs: `int` number of stacked convolution before the last prediction
layer.
num_filters: `int` number to specify the number of filters used. Default
is 256.
upsample_factor: `int` number to specify the upsampling factor to generate
finer mask. Default 1 means no upsampling is applied.
activation: `string`, indicating which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: `bool`, whether to use synchronized batch normalization
across different replicas.
norm_momentum: `float`, the momentum parameter of the normalization
layers.
norm_epsilon: `float`, the epsilon parameter of the normalization layers.
kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer
kernel.
bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias.
output_logits: A `bool` of whether to output logits or not. Default
is True. If set to False, output softmax.
**kwargs: other keyword arguments passed to Layer.
"""
super
(
SegmentationHead3D
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
'num_classes'
:
num_classes
,
'level'
:
level
,
'num_convs'
:
num_convs
,
'num_filters'
:
num_filters
,
'upsample_factor'
:
upsample_factor
,
'activation'
:
activation
,
'use_sync_bn'
:
use_sync_bn
,
'norm_momentum'
:
norm_momentum
,
'norm_epsilon'
:
norm_epsilon
,
'kernel_regularizer'
:
kernel_regularizer
,
'bias_regularizer'
:
bias_regularizer
,
'output_logits'
:
output_logits
}
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
self
.
_activation
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
:
Union
[
tf
.
TensorShape
,
Sequence
[
tf
.
TensorShape
]]):
"""Creates the variables of the segmentation head."""
conv_op
=
tf
.
keras
.
layers
.
Conv3D
conv_kwargs
=
{
'kernel_size'
:
(
3
,
3
,
3
),
'padding'
:
'same'
,
'use_bias'
:
False
,
'kernel_initializer'
:
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.01
),
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
}
final_kernel_size
=
(
1
,
1
,
1
)
bn_op
=
(
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
if
self
.
_config_dict
[
'use_sync_bn'
]
else
tf
.
keras
.
layers
.
BatchNormalization
)
bn_kwargs
=
{
'axis'
:
self
.
_bn_axis
,
'momentum'
:
self
.
_config_dict
[
'norm_momentum'
],
'epsilon'
:
self
.
_config_dict
[
'norm_epsilon'
],
}
# Segmentation head layers.
self
.
_convs
=
[]
self
.
_norms
=
[]
for
i
in
range
(
self
.
_config_dict
[
'num_convs'
]):
conv_name
=
'segmentation_head_conv_{}'
.
format
(
i
)
self
.
_convs
.
append
(
conv_op
(
name
=
conv_name
,
filters
=
self
.
_config_dict
[
'num_filters'
],
**
conv_kwargs
))
norm_name
=
'segmentation_head_norm_{}'
.
format
(
i
)
self
.
_norms
.
append
(
bn_op
(
name
=
norm_name
,
**
bn_kwargs
))
self
.
_classifier
=
conv_op
(
name
=
'segmentation_output'
,
filters
=
self
.
_config_dict
[
'num_classes'
],
kernel_size
=
final_kernel_size
,
padding
=
'valid'
,
activation
=
None
,
bias_initializer
=
tf
.
zeros_initializer
(),
kernel_initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.01
),
kernel_regularizer
=
self
.
_config_dict
[
'kernel_regularizer'
],
bias_regularizer
=
self
.
_config_dict
[
'bias_regularizer'
])
super
(
SegmentationHead3D
,
self
).
build
(
input_shape
)
def
call
(
self
,
backbone_output
:
Mapping
[
str
,
tf
.
Tensor
],
decoder_output
:
Mapping
[
str
,
tf
.
Tensor
])
->
tf
.
Tensor
:
"""Forward pass of the segmentation head.
Args:
backbone_output: a dict of tensors
- key: `str`, the level of the multilevel features.
- values: `Tensor`, the feature map tensors, whose shape is [batch,
height_l, width_l, channels].
decoder_output: a dict of tensors
- key: `str`, the level of the multilevel features.
- values: `Tensor`, the feature map tensors, whose shape is [batch,
height_l, width_l, channels].
Returns:
segmentation prediction mask: `Tensor`, the segmentation mask scores
predicted from input feature.
"""
x
=
decoder_output
[
str
(
self
.
_config_dict
[
'level'
])]
for
conv
,
norm
in
zip
(
self
.
_convs
,
self
.
_norms
):
x
=
conv
(
x
)
x
=
norm
(
x
)
x
=
self
.
_activation
(
x
)
x
=
tf
.
keras
.
layers
.
UpSampling3D
(
size
=
self
.
_config_dict
[
'upsample_factor'
])(
x
)
x
=
self
.
_classifier
(
x
)
return
x
if
self
.
_config_dict
[
'output_logits'
]
else
tf
.
keras
.
layers
.
Softmax
(
dtype
=
'float32'
)(
x
)
def
get_config
(
self
)
->
Mapping
[
str
,
Any
]:
return
self
.
_config_dict
@
classmethod
def
from_config
(
cls
,
config
:
Mapping
[
str
,
Any
]):
return
cls
(
**
config
)
official/vision/beta/projects/volumetric_models/modeling/heads/segmentation_heads_3d_test.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for segmentation_heads.py."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.projects.volumetric_models.modeling.heads
import
segmentation_heads_3d
class
SegmentationHead3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
1
,
0
),
(
2
,
1
),
)
def
test_forward
(
self
,
level
,
num_convs
):
head
=
segmentation_heads_3d
.
SegmentationHead3D
(
num_classes
=
10
,
level
=
level
,
num_convs
=
num_convs
)
backbone_features
=
{
'1'
:
np
.
random
.
rand
(
2
,
128
,
128
,
128
,
16
),
'2'
:
np
.
random
.
rand
(
2
,
64
,
64
,
64
,
16
),
}
decoder_features
=
{
'1'
:
np
.
random
.
rand
(
2
,
128
,
128
,
128
,
16
),
'2'
:
np
.
random
.
rand
(
2
,
64
,
64
,
64
,
16
),
}
logits
=
head
(
backbone_features
,
decoder_features
)
if
str
(
level
)
in
decoder_features
:
self
.
assertAllEqual
(
logits
.
numpy
().
shape
,
[
2
,
decoder_features
[
str
(
level
)].
shape
[
1
],
decoder_features
[
str
(
level
)].
shape
[
2
],
decoder_features
[
str
(
level
)].
shape
[
3
],
10
])
def
test_serialize_deserialize
(
self
):
head
=
segmentation_heads_3d
.
SegmentationHead3D
(
num_classes
=
10
,
level
=
3
)
config
=
head
.
get_config
()
new_head
=
segmentation_heads_3d
.
SegmentationHead3D
.
from_config
(
config
)
self
.
assertAllEqual
(
head
.
get_config
(),
new_head
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/volumetric_models/modeling/nn_blocks_3d.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common building blocks for neural networks."""
from
typing
import
Sequence
,
Union
# Import libraries
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.layers
import
nn_layers
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
BasicBlock3DVolume
(
tf
.
keras
.
layers
.
Layer
):
"""A basic 3d convolution block."""
def
__init__
(
self
,
filters
:
Union
[
int
,
Sequence
[
int
]],
strides
:
Union
[
int
,
Sequence
[
int
]],
kernel_size
:
Union
[
int
,
Sequence
[
int
]],
kernel_initializer
:
str
=
'VarianceScaling'
,
kernel_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
bias_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
activation
:
str
=
'relu'
,
use_sync_bn
:
bool
=
False
,
norm_momentum
:
float
=
0.99
,
norm_epsilon
:
float
=
0.001
,
use_batch_normalization
:
bool
=
False
,
**
kwargs
):
"""Creates a basic 3d convolution block applying one or more convolutions.
Args:
filters: A list of `int` numbers or an `int` number of filters. Given an
`int` input, a single convolution is applied; otherwise a series of
convolutions are applied.
strides: An integer or tuple/list of 3 integers, specifying the strides of
the convolution along each spatial dimension. Can be a single integer to
specify the same value for all spatial dimensions.
kernel_size: An integer or tuple/list of 3 integers, specifying the depth,
height and width of the 3D convolution window. Can be a single integer
to specify the same value for all spatial dimensions.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
use_batch_normalization: Wheher to use batch normalizaion or not.
**kwargs: keyword arguments to be passed.
"""
super
().
__init__
(
**
kwargs
)
if
isinstance
(
filters
,
int
):
self
.
_filters
=
[
filters
]
else
:
self
.
_filters
=
filters
self
.
_strides
=
strides
self
.
_kernel_size
=
kernel_size
self
.
_kernel_initializer
=
kernel_initializer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
self
.
_activation
=
activation
self
.
_use_sync_bn
=
use_sync_bn
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
self
.
_use_batch_normalization
=
use_batch_normalization
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
tf
.
keras
.
layers
.
BatchNormalization
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
:
tf
.
TensorShape
):
"""Builds the basic 3d convolution block."""
self
.
_convs
=
[]
self
.
_norms
=
[]
for
filters
in
self
.
_filters
:
self
.
_convs
.
append
(
tf
.
keras
.
layers
.
Conv3D
(
filters
=
filters
,
kernel_size
=
self
.
_kernel_size
,
strides
=
self
.
_strides
,
padding
=
'same'
,
data_format
=
tf
.
keras
.
backend
.
image_data_format
(),
activation
=
None
))
self
.
_norms
.
append
(
self
.
_norm
(
axis
=
self
.
_bn_axis
))
super
(
BasicBlock3DVolume
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
"""Returns the config of the basic 3d convolution block."""
config
=
{
'filters'
:
self
.
_filters
,
'strides'
:
self
.
_strides
,
'kernel_size'
:
self
.
_kernel_size
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
,
'use_batch_normalization'
:
self
.
_use_batch_normalization
}
base_config
=
super
(
BasicBlock3DVolume
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
:
tf
.
Tensor
,
training
:
bool
=
None
)
->
tf
.
Tensor
:
"""Runs forward pass on the input tensor."""
x
=
inputs
for
conv
,
norm
in
zip
(
self
.
_convs
,
self
.
_norms
):
x
=
conv
(
x
)
if
self
.
_use_batch_normalization
:
x
=
norm
(
x
)
x
=
self
.
_activation_fn
(
x
)
return
x
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
ResidualBlock3DVolume
(
tf
.
keras
.
layers
.
Layer
):
"""A residual 3d block."""
def
__init__
(
self
,
filters
,
strides
,
use_projection
=
False
,
se_ratio
=
None
,
stochastic_depth_drop_rate
=
None
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
**
kwargs
):
"""A residual 3d block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super
().
__init__
(
**
kwargs
)
self
.
_filters
=
filters
self
.
_strides
=
strides
self
.
_use_projection
=
use_projection
self
.
_se_ratio
=
se_ratio
self
.
_use_sync_bn
=
use_sync_bn
self
.
_activation
=
activation
self
.
_stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
self
.
_kernel_initializer
=
kernel_initializer
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
tf
.
keras
.
layers
.
BatchNormalization
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
):
if
self
.
_use_projection
:
self
.
_shortcut
=
tf
.
keras
.
layers
.
Conv3D
(
filters
=
self
.
_filters
,
kernel_size
=
1
,
strides
=
self
.
_strides
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm0
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv1
=
tf
.
keras
.
layers
.
Conv3D
(
filters
=
self
.
_filters
,
kernel_size
=
3
,
strides
=
self
.
_strides
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm1
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv2
=
tf
.
keras
.
layers
.
Conv3D
(
filters
=
self
.
_filters
,
kernel_size
=
3
,
strides
=
1
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm2
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
if
self
.
_se_ratio
and
self
.
_se_ratio
>
0
and
self
.
_se_ratio
<=
1
:
self
.
_squeeze_excitation
=
nn_layers
.
SqueezeExcitation
(
in_filters
=
self
.
_filters
,
out_filters
=
self
.
_filters
,
se_ratio
=
self
.
_se_ratio
,
use_3d_input
=
True
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
else
:
self
.
_squeeze_excitation
=
None
if
self
.
_stochastic_depth_drop_rate
:
self
.
_stochastic_depth
=
nn_layers
.
StochasticDepth
(
self
.
_stochastic_depth_drop_rate
)
else
:
self
.
_stochastic_depth
=
None
super
(
ResidualBlock3DVolume
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
config
=
{
'filters'
:
self
.
_filters
,
'strides'
:
self
.
_strides
,
'use_projection'
:
self
.
_use_projection
,
'se_ratio'
:
self
.
_se_ratio
,
'stochastic_depth_drop_rate'
:
self
.
_stochastic_depth_drop_rate
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
}
base_config
=
super
(
ResidualBlock3DVolume
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
,
training
=
None
):
shortcut
=
inputs
if
self
.
_use_projection
:
shortcut
=
self
.
_shortcut
(
shortcut
)
shortcut
=
self
.
_norm0
(
shortcut
)
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_norm1
(
x
)
x
=
self
.
_activation_fn
(
x
)
x
=
self
.
_conv2
(
x
)
x
=
self
.
_norm2
(
x
)
if
self
.
_squeeze_excitation
:
x
=
self
.
_squeeze_excitation
(
x
)
if
self
.
_stochastic_depth
:
x
=
self
.
_stochastic_depth
(
x
,
training
=
training
)
return
self
.
_activation_fn
(
x
+
shortcut
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
BottleneckBlock3DVolume
(
tf
.
keras
.
layers
.
Layer
):
"""A standard bottleneck block."""
def
__init__
(
self
,
filters
,
strides
,
dilation_rate
=
1
,
use_projection
=
False
,
se_ratio
=
None
,
stochastic_depth_drop_rate
=
None
,
kernel_initializer
=
'VarianceScaling'
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
**
kwargs
):
"""A standard bottleneck 3d block with BN after convolutions.
Args:
filters: `int` number of filters for the first two convolutions. Note that
the third and final convolution will use 4 times as many filters.
strides: `int` block stride. If greater than 1, this block will ultimately
downsample the input.
dilation_rate: `int` dilation_rate of convolutions. Default to 1.
use_projection: `bool` for whether this block should use a projection
shortcut (versus the default identity shortcut). This is usually `True`
for the first block of a block group, which may change the number of
filters and the resolution.
se_ratio: `float` or None. Ratio of the Squeeze-and-Excitation layer.
stochastic_depth_drop_rate: `float` or None. if not None, drop rate for
the stochastic depth layer.
kernel_initializer: kernel_initializer for convolutional layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
Default to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
Default to None.
activation: `str` name of the activation function.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization omentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super
().
__init__
(
**
kwargs
)
self
.
_filters
=
filters
self
.
_strides
=
strides
self
.
_dilation_rate
=
dilation_rate
self
.
_use_projection
=
use_projection
self
.
_se_ratio
=
se_ratio
self
.
_use_sync_bn
=
use_sync_bn
self
.
_activation
=
activation
self
.
_stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
self
.
_kernel_initializer
=
kernel_initializer
self
.
_norm_momentum
=
norm_momentum
self
.
_norm_epsilon
=
norm_epsilon
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_bias_regularizer
=
bias_regularizer
if
use_sync_bn
:
self
.
_norm
=
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
else
:
self
.
_norm
=
tf
.
keras
.
layers
.
BatchNormalization
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
self
.
_activation_fn
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
):
if
self
.
_use_projection
:
self
.
_shortcut
=
tf
.
keras
.
layers
.
Conv3D
(
filters
=
self
.
_filters
*
4
,
kernel_size
=
1
,
strides
=
self
.
_strides
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm0
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv1
=
tf
.
keras
.
layers
.
Conv3D
(
filters
=
self
.
_filters
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm1
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv2
=
tf
.
keras
.
layers
.
Conv3D
(
filters
=
self
.
_filters
,
kernel_size
=
3
,
strides
=
self
.
_strides
,
dilation_rate
=
self
.
_dilation_rate
,
padding
=
'same'
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm2
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
self
.
_conv3
=
tf
.
keras
.
layers
.
Conv3D
(
filters
=
self
.
_filters
*
4
,
kernel_size
=
1
,
strides
=
1
,
use_bias
=
False
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
self
.
_norm3
=
self
.
_norm
(
axis
=
self
.
_bn_axis
,
momentum
=
self
.
_norm_momentum
,
epsilon
=
self
.
_norm_epsilon
)
if
self
.
_se_ratio
and
self
.
_se_ratio
>
0
and
self
.
_se_ratio
<=
1
:
self
.
_squeeze_excitation
=
nn_layers
.
SqueezeExcitation
(
in_filters
=
self
.
_filters
*
4
,
out_filters
=
self
.
_filters
*
4
,
se_ratio
=
self
.
_se_ratio
,
use_3d_input
=
True
,
kernel_initializer
=
self
.
_kernel_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
)
else
:
self
.
_squeeze_excitation
=
None
if
self
.
_stochastic_depth_drop_rate
:
self
.
_stochastic_depth
=
nn_layers
.
StochasticDepth
(
self
.
_stochastic_depth_drop_rate
)
else
:
self
.
_stochastic_depth
=
None
super
(
BottleneckBlock3DVolume
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
config
=
{
'filters'
:
self
.
_filters
,
'strides'
:
self
.
_strides
,
'dilation_rate'
:
self
.
_dilation_rate
,
'use_projection'
:
self
.
_use_projection
,
'se_ratio'
:
self
.
_se_ratio
,
'stochastic_depth_drop_rate'
:
self
.
_stochastic_depth_drop_rate
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'activation'
:
self
.
_activation
,
'use_sync_bn'
:
self
.
_use_sync_bn
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
}
base_config
=
super
(
BottleneckBlock3DVolume
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
,
training
=
None
):
shortcut
=
inputs
if
self
.
_use_projection
:
shortcut
=
self
.
_shortcut
(
shortcut
)
shortcut
=
self
.
_norm0
(
shortcut
)
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_norm1
(
x
)
x
=
self
.
_activation_fn
(
x
)
x
=
self
.
_conv2
(
x
)
x
=
self
.
_norm2
(
x
)
x
=
self
.
_activation_fn
(
x
)
x
=
self
.
_conv3
(
x
)
x
=
self
.
_norm3
(
x
)
if
self
.
_squeeze_excitation
:
x
=
self
.
_squeeze_excitation
(
x
)
if
self
.
_stochastic_depth
:
x
=
self
.
_stochastic_depth
(
x
,
training
=
training
)
return
self
.
_activation_fn
(
x
+
shortcut
)
official/vision/beta/projects/volumetric_models/modeling/nn_blocks_3d_test.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for 3D volumeric convoluion blocks."""
# Import libraries
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.projects.volumetric_models.modeling
import
nn_blocks_3d
class
NNBlocks3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
((
128
,
128
,
32
,
1
),
(
256
,
256
,
16
,
2
))
def
test_bottleneck_block_3d_volume_creation
(
self
,
spatial_size
,
volume_size
,
filters
,
strides
):
inputs
=
tf
.
keras
.
Input
(
shape
=
(
spatial_size
,
spatial_size
,
volume_size
,
filters
*
4
),
batch_size
=
1
)
block
=
nn_blocks_3d
.
BottleneckBlock3DVolume
(
filters
=
filters
,
strides
=
strides
,
use_projection
=
True
)
features
=
block
(
inputs
)
self
.
assertAllEqual
([
1
,
spatial_size
//
strides
,
spatial_size
//
strides
,
volume_size
//
strides
,
filters
*
4
],
features
.
shape
.
as_list
())
@
parameterized
.
parameters
((
128
,
128
,
32
,
1
),
(
256
,
256
,
64
,
2
))
def
test_residual_block_3d_volume_creation
(
self
,
spatial_size
,
volume_size
,
filters
,
strides
):
inputs
=
tf
.
keras
.
Input
(
shape
=
(
spatial_size
,
spatial_size
,
volume_size
,
filters
),
batch_size
=
1
)
block
=
nn_blocks_3d
.
ResidualBlock3DVolume
(
filters
=
filters
,
strides
=
strides
,
use_projection
=
True
)
features
=
block
(
inputs
)
self
.
assertAllEqual
([
1
,
spatial_size
//
strides
,
spatial_size
//
strides
,
volume_size
//
strides
,
filters
],
features
.
shape
.
as_list
())
@
parameterized
.
parameters
((
128
,
128
,
64
,
1
,
3
),
(
256
,
256
,
128
,
2
,
1
))
def
test_basic_block_3d_volume_creation
(
self
,
spatial_size
,
volume_size
,
filters
,
strides
,
kernel_size
):
inputs
=
tf
.
keras
.
Input
(
shape
=
(
spatial_size
,
spatial_size
,
volume_size
,
filters
),
batch_size
=
1
)
block
=
nn_blocks_3d
.
BasicBlock3DVolume
(
filters
=
filters
,
strides
=
strides
,
kernel_size
=
kernel_size
)
features
=
block
(
inputs
)
self
.
assertAllEqual
([
1
,
spatial_size
//
strides
,
spatial_size
//
strides
,
volume_size
//
strides
,
filters
],
features
.
shape
.
as_list
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/volumetric_models/modeling/segmentation_model_test.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for segmentation network."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
segmentation_model
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.modeling.heads
import
segmentation_heads_3d
class
SegmentationNetworkUNet3DTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
([
32
,
32
],
4
),
([
64
,
64
],
4
),
([
64
,
64
],
2
),
([
128
,
64
],
2
),
)
def
test_segmentation_network_unet3d_creation
(
self
,
input_size
,
depth
):
"""Test for creation of a segmentation network."""
num_classes
=
2
inputs
=
np
.
random
.
rand
(
2
,
input_size
[
0
],
input_size
[
0
],
input_size
[
1
],
3
)
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
backbone
=
backbones
.
UNet3D
(
model_id
=
depth
)
decoder
=
decoders
.
UNet3DDecoder
(
model_id
=
depth
,
input_specs
=
backbone
.
output_specs
)
head
=
segmentation_heads_3d
.
SegmentationHead3D
(
num_classes
,
level
=
1
,
num_convs
=
0
)
model
=
segmentation_model
.
SegmentationModel
(
backbone
=
backbone
,
decoder
=
decoder
,
head
=
head
)
logits
=
model
(
inputs
)
self
.
assertAllEqual
(
[
2
,
input_size
[
0
],
input_size
[
0
],
input_size
[
1
],
num_classes
],
logits
.
numpy
().
shape
)
def
test_serialize_deserialize
(
self
):
"""Validate the network can be serialized and deserialized."""
num_classes
=
3
backbone
=
backbones
.
UNet3D
(
model_id
=
4
)
decoder
=
decoders
.
UNet3DDecoder
(
model_id
=
4
,
input_specs
=
backbone
.
output_specs
)
head
=
segmentation_heads_3d
.
SegmentationHead3D
(
num_classes
,
level
=
1
,
num_convs
=
0
)
model
=
segmentation_model
.
SegmentationModel
(
backbone
=
backbone
,
decoder
=
decoder
,
head
=
head
)
config
=
model
.
get_config
()
new_model
=
segmentation_model
.
SegmentationModel
.
from_config
(
config
)
# Validate that the config can be forced to JSON.
_
=
new_model
.
to_json
()
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
model
.
get_config
(),
new_model
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/volumetric_models/registry_imports.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from
official.common
import
registry_imports
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.tasks
import
semantic_segmentation_3d
official/vision/beta/projects/volumetric_models/serving/export_saved_model.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r
"""Volumetric model export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=1 \
--input_image_size=128,128,128 \
--num_channels=1
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from
absl
import
app
from
absl
import
flags
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.vision.beta.projects.volumetric_models.serving
import
semantic_segmentation_3d
from
official.vision.beta.serving
import
export_saved_model_lib
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
'experiment'
,
None
,
'experiment type, e.g. retinanet_resnetfpn_coco'
)
flags
.
DEFINE_string
(
'export_dir'
,
None
,
'The export directory.'
)
flags
.
DEFINE_string
(
'checkpoint_path'
,
None
,
'Checkpoint path.'
)
flags
.
DEFINE_multi_string
(
'config_file'
,
default
=
None
,
help
=
'YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.'
)
flags
.
DEFINE_string
(
'params_override'
,
''
,
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.'
)
flags
.
DEFINE_integer
(
'batch_size'
,
None
,
'The batch size.'
)
flags
.
DEFINE_string
(
'input_type'
,
'image_tensor'
,
'One of `image_tensor`, `image_bytes`, `tf_example`.'
)
flags
.
DEFINE_list
(
'input_image_size'
,
None
,
'The comma-separated string of three integers representing the '
'height, width and depth of the input to the model.'
)
flags
.
DEFINE_integer
(
'num_channels'
,
1
,
'The number of channels of input image.'
)
flags
.
register_validator
(
'input_image_size'
,
lambda
value
:
value
is
not
None
and
len
(
value
)
==
3
,
message
=
'--input_image_size must be comma-separated string of three '
'integers representing the height, width and depth of the input to '
'the model.'
)
def
main
(
_
):
flags
.
mark_flag_as_required
(
'export_dir'
)
flags
.
mark_flag_as_required
(
'checkpoint_path'
)
params
=
exp_factory
.
get_exp_config
(
FLAGS
.
experiment
)
for
config_file
in
FLAGS
.
config_file
or
[]:
params
=
hyperparams
.
override_params_dict
(
params
,
config_file
,
is_strict
=
True
)
if
FLAGS
.
params_override
:
params
=
hyperparams
.
override_params_dict
(
params
,
FLAGS
.
params_override
,
is_strict
=
True
)
params
.
validate
()
params
.
lock
()
input_image_size
=
FLAGS
.
input_image_size
export_module
=
semantic_segmentation_3d
.
SegmentationModule
(
params
=
params
,
batch_size
=
1
,
input_image_size
=
input_image_size
,
num_channels
=
FLAGS
.
num_channels
)
export_saved_model_lib
.
export_inference_graph
(
input_type
=
FLAGS
.
input_type
,
batch_size
=
FLAGS
.
batch_size
,
input_image_size
=
input_image_size
,
params
=
params
,
checkpoint_path
=
FLAGS
.
checkpoint_path
,
export_dir
=
FLAGS
.
export_dir
,
num_channels
=
FLAGS
.
num_channels
,
export_module
=
export_module
,
export_checkpoint_subdir
=
'checkpoint'
,
export_saved_model_subdir
=
'saved_model'
)
if
__name__
==
'__main__'
:
app
.
run
(
main
)
official/vision/beta/projects/volumetric_models/serving/semantic_segmentation_3d.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""3D semantic segmentation input and model functions for serving/inference."""
from
typing
import
Mapping
import
tensorflow
as
tf
from
official.vision.beta.projects.volumetric_models.modeling
import
factory
from
official.vision.beta.projects.volumetric_models.modeling.backbones
import
unet_3d
# pylint: disable=unused-import
from
official.vision.beta.serving
import
export_base
class
SegmentationModule
(
export_base
.
ExportModule
):
"""Segmentation Module."""
def
_build_model
(
self
)
->
tf
.
keras
.
Model
:
"""Builds and returns a segmentation model."""
num_channels
=
self
.
params
.
task
.
model
.
num_channels
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
self
.
_batch_size
]
+
self
.
_input_image_size
+
[
num_channels
])
return
factory
.
build_segmentation_model_3d
(
input_specs
=
input_specs
,
model_config
=
self
.
params
.
task
.
model
,
l2_regularizer
=
None
)
def
serve
(
self
,
images
:
tf
.
Tensor
)
->
Mapping
[
str
,
tf
.
Tensor
]:
"""Casts an image tensor to float and runs inference.
Args:
images: A uint8 tf.Tensor of shape [batch_size, None, None, None,
num_channels].
Returns:
A dictionary holding segmentation outputs.
"""
with
tf
.
device
(
'cpu:0'
):
images
=
tf
.
cast
(
images
,
dtype
=
tf
.
float32
)
outputs
=
self
.
inference_step
(
images
)
output_key
=
'logits'
if
self
.
params
.
task
.
model
.
head
.
output_logits
else
'probs'
return
{
output_key
:
outputs
}
official/vision/beta/projects/volumetric_models/serving/semantic_segmentation_3d_test.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for semantic_segmentation_3d export lib."""
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
# pylint: disable=unused-import
from
official.vision.beta.projects.volumetric_models.modeling.backbones
import
unet_3d
# pylint: disable=unused-import
from
official.vision.beta.projects.volumetric_models.serving
import
semantic_segmentation_3d
class
SemanticSegmentationExportTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
self
.
_num_channels
=
2
self
.
_input_image_size
=
[
32
,
32
,
32
]
self
.
_params
=
exp_factory
.
get_exp_config
(
'seg_unet3d_test'
)
input_shape
=
self
.
_input_image_size
+
[
self
.
_num_channels
]
self
.
_image_array
=
np
.
zeros
(
shape
=
input_shape
,
dtype
=
np
.
uint8
)
def
_get_segmentation_module
(
self
):
return
semantic_segmentation_3d
.
SegmentationModule
(
self
.
_params
,
batch_size
=
1
,
input_image_size
=
self
.
_input_image_size
,
num_channels
=
self
.
_num_channels
)
def
_export_from_module
(
self
,
module
,
input_type
:
str
,
save_directory
:
str
):
signatures
=
module
.
get_inference_signatures
(
{
input_type
:
'serving_default'
})
tf
.
saved_model
.
save
(
module
,
save_directory
,
signatures
=
signatures
)
def
_get_dummy_input
(
self
,
input_type
):
"""Get dummy input for the given input type."""
if
input_type
==
'image_tensor'
:
image_tensor
=
tf
.
convert_to_tensor
(
self
.
_image_array
,
dtype
=
tf
.
uint8
)
return
tf
.
expand_dims
(
image_tensor
,
axis
=
0
)
if
input_type
==
'image_bytes'
:
return
[
self
.
_image_array
.
tostring
()]
if
input_type
==
'tf_example'
:
encoded_image
=
self
.
_image_array
.
tostring
()
example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
'image/encoded'
:
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
encoded_image
])),
})).
SerializeToString
()
return
[
example
]
@
parameterized
.
parameters
(
{
'input_type'
:
'image_tensor'
},
{
'input_type'
:
'image_bytes'
},
{
'input_type'
:
'tf_example'
},
)
def
test_export
(
self
,
input_type
:
str
=
'image_tensor'
):
tmp_dir
=
self
.
get_temp_dir
()
module
=
self
.
_get_segmentation_module
()
self
.
_export_from_module
(
module
,
input_type
,
tmp_dir
)
# Check if model is successfully exported.
self
.
assertTrue
(
tf
.
io
.
gfile
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'saved_model.pb'
)))
self
.
assertTrue
(
tf
.
io
.
gfile
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.index'
)))
self
.
assertTrue
(
tf
.
io
.
gfile
.
exists
(
os
.
path
.
join
(
tmp_dir
,
'variables'
,
'variables.data-00000-of-00001'
)))
# Get inference signature from loaded SavedModel.
imported
=
tf
.
saved_model
.
load
(
tmp_dir
)
segmentation_fn
=
imported
.
signatures
[
'serving_default'
]
images
=
self
.
_get_dummy_input
(
input_type
)
image_tensor
=
self
.
_get_dummy_input
(
input_type
=
'image_tensor'
)
# Perform inference using loaded SavedModel and model instance and check if
# outputs equal.
expected_output
=
module
.
model
(
image_tensor
,
training
=
False
)
out
=
segmentation_fn
(
tf
.
constant
(
images
))
self
.
assertAllClose
(
out
[
'logits'
].
numpy
(),
expected_output
.
numpy
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/volumetric_models/tasks/semantic_segmentation_3d.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image segmentation task definition."""
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Sequence
,
Union
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
dataset_fn
from
official.core
import
base_task
from
official.core
import
input_reader
from
official.core
import
task_factory
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.vision.beta.projects.volumetric_models.dataloaders
import
segmentation_input_3d
from
official.vision.beta.projects.volumetric_models.evaluation
import
segmentation_metrics
from
official.vision.beta.projects.volumetric_models.losses
import
segmentation_losses
from
official.vision.beta.projects.volumetric_models.modeling
import
factory
@
task_factory
.
register_task_cls
(
exp_cfg
.
SemanticSegmentation3DTask
)
class
SemanticSegmentation3DTask
(
base_task
.
Task
):
"""A task for semantic segmentation."""
def
build_model
(
self
)
->
tf
.
keras
.
Model
:
"""Builds segmentation model."""
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
]
+
self
.
task_config
.
model
.
input_size
+
[
self
.
task_config
.
model
.
num_channels
],
dtype
=
self
.
task_config
.
train_data
.
dtype
)
l2_weight_decay
=
self
.
task_config
.
losses
.
l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
model
=
factory
.
build_segmentation_model_3d
(
input_specs
=
input_specs
,
model_config
=
self
.
task_config
.
model
,
l2_regularizer
=
l2_regularizer
)
# Create a dummy input and call model instance to initialize the model. This
# is needed when launching multiple experiments using the same model
# directory. Since there is already a trained model, forward pass will not
# run and the model will never be built. This is only done when spatial
# partitioning is not enabled; otherwise it will fail with OOM due to
# extremely large input.
if
(
not
self
.
task_config
.
train_input_partition_dims
)
and
(
not
self
.
task_config
.
eval_input_partition_dims
):
dummy_input
=
tf
.
random
.
uniform
(
shape
=
[
1
]
+
list
(
input_specs
.
shape
[
1
:]))
_
=
model
(
dummy_input
)
return
model
def
initialize
(
self
,
model
:
tf
.
keras
.
Model
):
"""Loads pretrained checkpoint."""
if
not
self
.
task_config
.
init_checkpoint
:
return
ckpt_dir_or_file
=
self
.
task_config
.
init_checkpoint
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
# Restoring checkpoint.
if
'all'
in
self
.
task_config
.
init_checkpoint_modules
:
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
assert_consumed
()
else
:
ckpt_items
=
{}
if
'backbone'
in
self
.
task_config
.
init_checkpoint_modules
:
ckpt_items
.
update
(
backbone
=
model
.
backbone
)
if
'decoder'
in
self
.
task_config
.
init_checkpoint_modules
:
ckpt_items
.
update
(
decoder
=
model
.
decoder
)
ckpt
=
tf
.
train
.
Checkpoint
(
**
ckpt_items
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
def
build_inputs
(
self
,
params
,
input_context
=
None
)
->
tf
.
data
.
Dataset
:
"""Builds classification input."""
decoder
=
segmentation_input_3d
.
Decoder
(
image_field_key
=
params
.
image_field_key
,
label_field_key
=
params
.
label_field_key
)
parser
=
segmentation_input_3d
.
Parser
(
input_size
=
params
.
input_size
,
num_classes
=
params
.
num_classes
,
num_channels
=
params
.
num_channels
,
image_field_key
=
params
.
image_field_key
,
label_field_key
=
params
.
label_field_key
,
dtype
=
params
.
dtype
,
label_dtype
=
params
.
label_dtype
)
reader
=
input_reader
.
InputReader
(
params
,
dataset_fn
=
dataset_fn
.
pick_dataset_fn
(
params
.
file_type
),
decoder_fn
=
decoder
.
decode
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
))
dataset
=
reader
.
read
(
input_context
=
input_context
)
return
dataset
def
build_losses
(
self
,
labels
:
tf
.
Tensor
,
model_outputs
:
tf
.
Tensor
,
aux_losses
=
None
)
->
tf
.
Tensor
:
"""Segmentation loss.
Args:
labels: labels.
model_outputs: Output logits of the classifier.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
segmentation_loss_fn
=
segmentation_losses
.
SegmentationLossDiceScore
(
metric_type
=
'adaptive'
)
total_loss
=
segmentation_loss_fn
(
model_outputs
,
labels
)
if
aux_losses
:
total_loss
+=
tf
.
add_n
(
aux_losses
)
return
total_loss
def
build_metrics
(
self
,
training
:
bool
=
True
)
->
Sequence
[
tf
.
keras
.
metrics
.
Metric
]:
"""Gets streaming metrics for training/validation."""
metrics
=
[]
num_classes
=
self
.
task_config
.
model
.
num_classes
if
training
:
metrics
.
extend
([
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'train_categorical_accuracy'
,
dtype
=
tf
.
float32
)
])
else
:
self
.
metrics
=
[
segmentation_metrics
.
DiceScore
(
num_classes
=
num_classes
,
metric_type
=
'generalized'
,
per_class_metric
=
self
.
task_config
.
evaluation
.
report_per_class_metric
,
name
=
'val_generalized_dice'
,
dtype
=
tf
.
float32
)
]
return
metrics
def
train_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
metrics
:
Optional
[
Sequence
[
tf
.
keras
.
metrics
.
Metric
]]
=
None
)
->
Dict
[
Any
,
Any
]:
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features
,
labels
=
inputs
input_partition_dims
=
self
.
task_config
.
train_input_partition_dims
if
input_partition_dims
:
strategy
=
tf
.
distribute
.
get_strategy
()
features
=
strategy
.
experimental_split_to_logical_devices
(
features
,
input_partition_dims
)
num_replicas
=
tf
.
distribute
.
get_strategy
().
num_replicas_in_sync
with
tf
.
GradientTape
()
as
tape
:
outputs
=
model
(
features
,
training
=
True
)
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
if
self
.
task_config
.
model
.
head
.
output_logits
:
outputs
=
tf
.
nn
.
softmax
(
outputs
)
# Computes per-replica loss.
loss
=
self
.
build_losses
(
labels
=
labels
,
model_outputs
=
outputs
,
aux_losses
=
model
.
losses
)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss
=
loss
/
num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
scaled_loss
=
optimizer
.
get_scaled_loss
(
scaled_loss
)
tvars
=
model
.
trainable_variables
grads
=
tape
.
gradient
(
scaled_loss
,
tvars
)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
optimizer
.
apply_gradients
(
list
(
zip
(
grads
,
tvars
)))
logs
=
{
self
.
loss
:
loss
}
# Compute all metrics within strategy scope for training.
if
metrics
:
labels
=
tf
.
cast
(
labels
,
tf
.
float32
)
outputs
=
tf
.
cast
(
outputs
,
tf
.
float32
)
self
.
process_metrics
(
metrics
,
labels
,
outputs
)
logs
.
update
({
m
.
name
:
m
.
result
()
for
m
in
metrics
})
return
logs
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
:
Optional
[
Sequence
[
tf
.
keras
.
metrics
.
Metric
]]
=
None
)
->
Dict
[
Any
,
Any
]:
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features
,
labels
=
inputs
input_partition_dims
=
self
.
task_config
.
eval_input_partition_dims
if
input_partition_dims
:
strategy
=
tf
.
distribute
.
get_strategy
()
features
=
strategy
.
experimental_split_to_logical_devices
(
features
,
input_partition_dims
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
if
self
.
task_config
.
model
.
head
.
output_logits
:
outputs
=
tf
.
nn
.
softmax
(
outputs
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
# Compute dice score metrics on CPU.
for
metric
in
self
.
metrics
:
labels
=
tf
.
cast
(
labels
,
tf
.
float32
)
outputs
=
tf
.
cast
(
outputs
,
tf
.
float32
)
logs
.
update
({
metric
.
name
:
(
labels
,
outputs
)})
return
logs
def
inference_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
)
->
tf
.
Tensor
:
"""Performs the forward step."""
return
model
(
inputs
,
training
=
False
)
def
aggregate_logs
(
self
,
state
:
Optional
[
Sequence
[
Union
[
segmentation_metrics
.
DiceScore
,
tf
.
keras
.
metrics
.
Metric
]]]
=
None
,
step_outputs
:
Optional
[
Mapping
[
str
,
Any
]]
=
None
)
->
Sequence
[
tf
.
keras
.
metrics
.
Metric
]:
"""Aggregates statistics to compute metrics over training.
Args:
state: A sequence of tf.keras.metrics.Metric objects. Each element records
a metric.
step_outputs: A dictionary of [metric_name, (labels, output)] from a step.
Returns:
An updated sequence of tf.keras.metrics.Metric objects.
"""
if
state
is
None
:
for
metric
in
self
.
metrics
:
metric
.
reset_states
()
state
=
self
.
metrics
for
metric
in
self
.
metrics
:
labels
=
step_outputs
[
metric
.
name
][
0
]
predictions
=
step_outputs
[
metric
.
name
][
1
]
# If `step_output` is distributed, it contains a tuple of Tensors instead
# of a single Tensor, so we need to concatenate them along the batch
# dimension in this case to have a single Tensor.
if
isinstance
(
labels
,
tuple
):
labels
=
tf
.
concat
(
list
(
labels
),
axis
=
0
)
if
isinstance
(
predictions
,
tuple
):
predictions
=
tf
.
concat
(
list
(
predictions
),
axis
=
0
)
labels
=
tf
.
cast
(
labels
,
tf
.
float32
)
predictions
=
tf
.
cast
(
predictions
,
tf
.
float32
)
metric
.
update_state
(
labels
,
predictions
)
return
state
def
reduce_aggregated_logs
(
self
,
aggregated_logs
:
Optional
[
Mapping
[
str
,
Any
]]
=
None
,
global_step
:
Optional
[
tf
.
Tensor
]
=
None
)
->
Mapping
[
str
,
float
]:
"""Reduces logs to obtain per-class metrics if needed.
Args:
aggregated_logs: An optional dictionary containing aggregated logs.
global_step: An optional `tf.Tensor` of current global training steps.
Returns:
The reduced logs containing per-class metrics and overall metrics.
Raises:
ValueError: If `self.metrics` does not contain exactly 1 metric object.
"""
result
=
{}
if
len
(
self
.
metrics
)
!=
1
:
raise
ValueError
(
'Exact one metric must be present, but {0} are '
'present.'
.
format
(
len
(
self
.
metrics
)))
metric
=
self
.
metrics
[
0
].
result
().
numpy
()
if
self
.
task_config
.
evaluation
.
report_per_class_metric
:
for
i
,
metric_val
in
enumerate
(
metric
):
metric_name
=
self
.
metrics
[
0
].
name
+
'/class_{0}'
.
format
(
i
-
1
)
if
i
>
0
else
self
.
metrics
[
0
].
name
result
.
update
({
metric_name
:
metric_val
})
return
result
official/vision/beta/projects/volumetric_models/tasks/semantic_segmentation_3d_test.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for semantic segmentation task."""
# pylint: disable=unused-import
import
functools
import
os
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.projects.volumetric_models.evaluation
import
segmentation_metrics
from
official.vision.beta.projects.volumetric_models.modeling.backbones
import
unet_3d
from
official.vision.beta.projects.volumetric_models.tasks
import
semantic_segmentation_3d
as
img_seg_task
class
SemanticSegmentationTaskTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
data_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'data'
)
tf
.
io
.
gfile
.
makedirs
(
data_dir
)
self
.
_data_path
=
os
.
path
.
join
(
data_dir
,
'data.tfrecord'
)
# pylint: disable=g-complex-comprehension
examples
=
[
tfexample_utils
.
create_3d_image_test_example
(
image_height
=
32
,
image_width
=
32
,
image_volume
=
32
,
image_channel
=
2
)
for
_
in
range
(
20
)
]
# pylint: enable=g-complex-comprehension
tfexample_utils
.
dump_to_tfrecord
(
self
.
_data_path
,
tf_examples
=
examples
)
@
parameterized
.
parameters
((
'seg_unet3d_test'
,))
def
test_task
(
self
,
config_name
):
config
=
exp_factory
.
get_exp_config
(
config_name
)
config
.
task
.
train_data
.
input_path
=
self
.
_data_path
config
.
task
.
train_data
.
global_batch_size
=
4
config
.
task
.
train_data
.
shuffle_buffer_size
=
4
config
.
task
.
validation_data
.
input_path
=
self
.
_data_path
config
.
task
.
validation_data
.
shuffle_buffer_size
=
4
config
.
task
.
evaluation
.
report_per_class_metric
=
True
task
=
img_seg_task
.
SemanticSegmentation3DTask
(
config
.
task
)
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
strategy
=
tf
.
distribute
.
get_strategy
()
dataset
=
orbit
.
utils
.
make_distributed_dataset
(
strategy
,
task
.
build_inputs
,
config
.
task
.
train_data
)
iterator
=
iter
(
dataset
)
opt_factory
=
optimization
.
OptimizerFactory
(
config
.
trainer
.
optimizer_config
)
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
logs
=
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
# Check if training loss is produced.
self
.
assertIn
(
'loss'
,
logs
)
# Obtain distributed outputs.
distributed_outputs
=
strategy
.
run
(
functools
.
partial
(
task
.
validation_step
,
model
=
model
,
metrics
=
task
.
build_metrics
(
training
=
False
)),
args
=
(
next
(
iterator
),))
outputs
=
tf
.
nest
.
map_structure
(
strategy
.
experimental_local_results
,
distributed_outputs
)
# Check if validation loss is produced.
self
.
assertIn
(
'loss'
,
outputs
)
# Check if state is updated.
state
=
task
.
aggregate_logs
(
state
=
None
,
step_outputs
=
outputs
)
self
.
assertLen
(
state
,
1
)
self
.
assertIsInstance
(
state
[
0
],
segmentation_metrics
.
DiceScore
)
# Check if all metrics are produced.
result
=
task
.
reduce_aggregated_logs
(
aggregated_logs
=
{},
global_step
=
1
)
self
.
assertIn
(
'val_generalized_dice'
,
result
)
self
.
assertIn
(
'val_generalized_dice/class_0'
,
result
)
self
.
assertIn
(
'val_generalized_dice/class_1'
,
result
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/volumetric_models/train.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver."""
from
absl
import
app
import
gin
# pylint: disable=unused-import
from
official.common
import
flags
as
tfm_flags
from
official.vision.beta
import
train
from
official.vision.beta.projects.volumetric_models
import
registry_imports
# pylint: disable=unused-import
def
main
(
_
):
train
.
main
(
_
)
if
__name__
==
'__main__'
:
tfm_flags
.
define_flags
()
app
.
run
(
main
)
official/vision/beta/projects/volumetric_models/train_test.py
0 → 100644
View file @
2ee42597
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for train."""
import
json
import
os
from
absl
import
flags
from
absl
import
logging
from
absl.testing
import
flagsaver
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.projects.volumetric_models
import
train
as
train_lib
FLAGS
=
flags
.
FLAGS
class
TrainTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
self
.
_model_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'model_dir'
)
tf
.
io
.
gfile
.
makedirs
(
self
.
_model_dir
)
data_dir
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'data'
)
tf
.
io
.
gfile
.
makedirs
(
data_dir
)
self
.
_data_path
=
os
.
path
.
join
(
data_dir
,
'data.tfrecord'
)
# pylint: disable=g-complex-comprehension
examples
=
[
tfexample_utils
.
create_3d_image_test_example
(
image_height
=
32
,
image_width
=
32
,
image_volume
=
32
,
image_channel
=
2
)
for
_
in
range
(
2
)
]
# pylint: enable=g-complex-comprehension
tfexample_utils
.
dump_to_tfrecord
(
self
.
_data_path
,
tf_examples
=
examples
)
def
test_run
(
self
):
saved_flag_values
=
flagsaver
.
save_flag_values
()
train_lib
.
tfm_flags
.
define_flags
()
FLAGS
.
mode
=
'train'
FLAGS
.
model_dir
=
self
.
_model_dir
FLAGS
.
experiment
=
'seg_unet3d_test'
logging
.
info
(
'Test pipeline correctness.'
)
params_override
=
json
.
dumps
({
'runtime'
:
{
'mixed_precision_dtype'
:
'float32'
,
},
'trainer'
:
{
'train_steps'
:
1
,
'validation_steps'
:
1
,
},
'task'
:
{
'model'
:
{
'backbone'
:
{
'unet_3d'
:
{
'model_id'
:
4
,
},
},
'decoder'
:
{
'unet_3d_decoder'
:
{
'model_id'
:
4
,
},
},
},
'train_data'
:
{
'input_path'
:
self
.
_data_path
,
'file_type'
:
'tfrecord'
,
'global_batch_size'
:
2
,
},
'validation_data'
:
{
'input_path'
:
self
.
_data_path
,
'file_type'
:
'tfrecord'
,
'global_batch_size'
:
2
,
}
}
})
FLAGS
.
params_override
=
params_override
train_lib
.
main
(
'unused_args'
)
FLAGS
.
mode
=
'eval'
with
train_lib
.
gin
.
unlock_config
():
train_lib
.
main
(
'unused_args'
)
flagsaver
.
restore_flag_values
(
saved_flag_values
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment