Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
dbaec326
Commit
dbaec326
authored
Jul 09, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Jul 09, 2021
Browse files
Refactor decoder factory to allow registering other decoders.
PiperOrigin-RevId: 383944185
parent
bc71d8e9
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
221 additions
and
38 deletions
+221
-38
official/vision/beta/projects/volumetric_models/modeling/decoders/factory.py
...a/projects/volumetric_models/modeling/decoders/factory.py
+88
-32
official/vision/beta/projects/volumetric_models/modeling/decoders/factory_test.py
...jects/volumetric_models/modeling/decoders/factory_test.py
+80
-0
official/vision/beta/projects/volumetric_models/modeling/decoders/unet_3d_decoder.py
...ts/volumetric_models/modeling/decoders/unet_3d_decoder.py
+40
-1
official/vision/beta/projects/volumetric_models/modeling/factory_test.py
.../beta/projects/volumetric_models/modeling/factory_test.py
+3
-1
official/vision/beta/projects/volumetric_models/registry_imports.py
...ision/beta/projects/volumetric_models/registry_imports.py
+1
-0
official/vision/beta/projects/volumetric_models/serving/semantic_segmentation_3d.py
...cts/volumetric_models/serving/semantic_segmentation_3d.py
+3
-1
official/vision/beta/projects/volumetric_models/serving/semantic_segmentation_3d_test.py
...olumetric_models/serving/semantic_segmentation_3d_test.py
+4
-2
official/vision/beta/projects/volumetric_models/tasks/semantic_segmentation_3d_test.py
.../volumetric_models/tasks/semantic_segmentation_3d_test.py
+2
-1
No files found.
official/vision/beta/projects/volumetric_models/modeling/decoders/factory.py
View file @
dbaec326
...
@@ -12,49 +12,105 @@
...
@@ -12,49 +12,105 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Decoder registers and factory method.
"""factory method."""
One can register a new decoder model by the following two steps:
1 Import the factory and register the build in the decoder file.
2 Import the decoder class and add a build in __init__.py.
```
# my_decoder.py
from modeling.decoders import factory
class MyDecoder():
...
@factory.register_decoder_builder('my_decoder')
def build_my_decoder():
return MyDecoder()
# decoders/__init__.py adds import
from modeling.decoders.my_decoder import MyDecoder
```
If one wants the MyDecoder class to be used only by those binary
then don't imported the decoder module in decoders/__init__.py, but import it
in place that uses it.
"""
from
typing
import
Union
,
Mapping
,
Optional
# Import libraries
from
typing
import
Mapping
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.core
import
registry
from
official.modeling
import
hyperparams
_REGISTERED_DECODER_CLS
=
{}
def
register_decoder_builder
(
key
:
str
):
"""Decorates a builder of decoder class.
The builder should be a Callable (a class or a function).
This decorator supports registration of decoder builder as follows:
```
class MyDecoder(tf.keras.Model):
pass
@register_decoder_builder('mydecoder')
def builder(input_specs, config, l2_reg):
return MyDecoder(...)
# Builds a MyDecoder object.
my_decoder = build_decoder_3d(input_specs, config, l2_reg)
```
Args:
key: A `str` of key to look up the builder.
Returns:
A callable for using as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return
registry
.
register
(
_REGISTERED_DECODER_CLS
,
key
)
@
register_decoder_builder
(
'identity'
)
def
build_identity
(
input_specs
:
Optional
[
Mapping
[
str
,
tf
.
TensorShape
]]
=
None
,
model_config
:
Optional
[
hyperparams
.
Config
]
=
None
,
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
)
->
None
:
del
input_specs
,
model_config
,
l2_regularizer
# Unused by identity decoder.
return
None
def
build_decoder
(
def
build_decoder
(
input_specs
:
Mapping
[
str
,
tf
.
TensorShape
],
input_specs
:
Mapping
[
str
,
tf
.
TensorShape
],
model_config
,
model_config
:
hyperparams
.
Config
,
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
)
->
tf
.
keras
.
Model
:
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
**
kwargs
)
->
Union
[
None
,
tf
.
keras
.
Model
,
tf
.
keras
.
layers
.
Layer
]:
"""Builds decoder from a config.
"""Builds decoder from a config.
Args:
Args:
input_specs: `dict` input specifications. A dictionary consists of
input_specs:
A
`dict`
of
input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
{level: TensorShape} from a backbone.
model_config: A OneOfConfig. Model config.
model_config: A `OneOfConfig` of model config.
l2_regularizer: tf.keras.regularizers.Regularizer instance. Default to None.
l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
None.
**kwargs: Additional keyword args to be passed to decoder builder.
Returns:
Returns:
A
tf.keras.Model
instance of the decoder.
A
n
instance of the decoder.
"""
"""
decoder_type
=
model_config
.
decoder
.
type
decoder_builder
=
registry
.
lookup
(
_REGISTERED_DECODER_CLS
,
decoder_cfg
=
model_config
.
decoder
.
get
()
model_config
.
decoder
.
type
)
norm_activation_config
=
model_config
.
norm_activation
return
decoder_builder
(
if
decoder_type
==
'identity'
:
decoder
=
None
elif
decoder_type
==
'unet_3d_decoder'
:
decoder
=
decoders
.
UNet3DDecoder
(
model_id
=
decoder_cfg
.
model_id
,
input_specs
=
input_specs
,
input_specs
=
input_specs
,
pool_size
=
decoder_cfg
.
pool_size
,
model_config
=
model_config
,
kernel_regularizer
=
l2_regularizer
,
l2_regularizer
=
l2_regularizer
,
activation
=
norm_activation_config
.
activation
,
**
kwargs
)
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_batch_normalization
=
decoder_cfg
.
use_batch_normalization
,
use_deconvolution
=
decoder_cfg
.
use_deconvolution
)
else
:
raise
ValueError
(
'Decoder {!r} not implement'
.
format
(
decoder_type
))
return
decoder
official/vision/beta/projects/volumetric_models/modeling/decoders/factory_test.py
0 → 100644
View file @
dbaec326
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for factory functions."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
official.vision.beta.projects.volumetric_models.configs
import
decoders
as
decoders_cfg
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
semantic_segmentation_3d_exp
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.modeling.decoders
import
factory
class
FactoryTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
combinations
.
generate
(
combinations
.
combine
(
model_id
=
[
2
,
3
],))
def
test_unet_3d_decoder_creation
(
self
,
model_id
):
"""Test creation of UNet 3D decoder."""
# Create test input for decoders based on input model_id.
input_specs
=
{}
for
level
in
range
(
model_id
):
input_specs
[
str
(
level
+
1
)]
=
tf
.
TensorShape
(
[
1
,
128
//
(
2
**
level
),
128
//
(
2
**
level
),
128
//
(
2
**
level
),
1
])
network
=
decoders
.
UNet3DDecoder
(
model_id
=
model_id
,
input_specs
=
input_specs
,
use_sync_bn
=
True
,
use_batch_normalization
=
True
,
use_deconvolution
=
True
)
model_config
=
semantic_segmentation_3d_exp
.
SemanticSegmentationModel3D
()
model_config
.
num_classes
=
2
model_config
.
num_channels
=
1
model_config
.
input_size
=
[
None
,
None
,
None
]
model_config
.
decoder
=
decoders_cfg
.
Decoder
(
type
=
'unet_3d_decoder'
,
unet_3d_decoder
=
decoders_cfg
.
UNet3DDecoder
(
model_id
=
model_id
))
factory_network
=
factory
.
build_decoder
(
input_specs
=
input_specs
,
model_config
=
model_config
)
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
print
(
network_config
)
print
(
factory_network_config
)
self
.
assertEqual
(
network_config
,
factory_network_config
)
def
test_identity_creation
(
self
):
"""Test creation of identity decoder."""
model_config
=
semantic_segmentation_3d_exp
.
SemanticSegmentationModel3D
()
model_config
.
num_classes
=
2
model_config
.
num_channels
=
3
model_config
.
input_size
=
[
None
,
None
,
None
]
model_config
.
decoder
=
decoders_cfg
.
Decoder
(
type
=
'identity'
,
identity
=
decoders_cfg
.
Identity
())
factory_network
=
factory
.
build_decoder
(
input_specs
=
None
,
model_config
=
model_config
)
self
.
assertIsNone
(
factory_network
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/volumetric_models/modeling/decoders/unet_3d_decoder.py
View file @
dbaec326
...
@@ -19,10 +19,13 @@ Ronneberger. 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
...
@@ -19,10 +19,13 @@ Ronneberger. 3D U-Net: Learning Dense Volumetric Segmentation from Sparse
Annotation. arXiv:1606.06650.
Annotation. arXiv:1606.06650.
"""
"""
from
typing
import
Any
,
Sequence
,
Dict
,
Mapping
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Sequence
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.vision.beta.projects.volumetric_models.modeling
import
nn_blocks_3d
from
official.vision.beta.projects.volumetric_models.modeling
import
nn_blocks_3d
from
official.vision.beta.projects.volumetric_models.modeling.decoders
import
factory
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
...
@@ -152,3 +155,39 @@ class UNet3DDecoder(tf.keras.Model):
...
@@ -152,3 +155,39 @@ class UNet3DDecoder(tf.keras.Model):
def
output_specs
(
self
)
->
Mapping
[
str
,
tf
.
TensorShape
]:
def
output_specs
(
self
)
->
Mapping
[
str
,
tf
.
TensorShape
]:
"""A dict of {level: TensorShape} pairs for the model output."""
"""A dict of {level: TensorShape} pairs for the model output."""
return
self
.
_output_specs
return
self
.
_output_specs
@
factory
.
register_decoder_builder
(
'unet_3d_decoder'
)
def
build_unet_3d_decoder
(
input_specs
:
Mapping
[
str
,
tf
.
TensorShape
],
model_config
:
hyperparams
.
Config
,
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
)
->
tf
.
keras
.
Model
:
"""Builds UNet3D decoder from a config.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
model_config: A OneOfConfig. Model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
None.
Returns:
A `tf.keras.Model` instance of the UNet3D decoder.
"""
decoder_type
=
model_config
.
decoder
.
type
decoder_cfg
=
model_config
.
decoder
.
get
()
assert
decoder_type
==
'unet_3d_decoder'
,
(
f
'Inconsistent decoder type '
f
'
{
decoder_type
}
'
)
norm_activation_config
=
model_config
.
norm_activation
return
UNet3DDecoder
(
model_id
=
decoder_cfg
.
model_id
,
input_specs
=
input_specs
,
pool_size
=
decoder_cfg
.
pool_size
,
kernel_regularizer
=
l2_regularizer
,
activation
=
norm_activation_config
.
activation
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_batch_normalization
=
decoder_cfg
.
use_batch_normalization
,
use_deconvolution
=
decoder_cfg
.
use_deconvolution
)
official/vision/beta/projects/volumetric_models/modeling/factory_test.py
View file @
dbaec326
...
@@ -17,9 +17,11 @@
...
@@ -17,9 +17,11 @@
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint: disable=unused-import
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.modeling
import
factory
from
official.vision.beta.projects.volumetric_models.modeling
import
factory
from
official.vision.beta.projects.volumetric_models.modeling.backbones
import
unet_3d
# pylint: disable=unused-import
class
SegmentationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
SegmentationModelBuilderTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/vision/beta/projects/volumetric_models/registry_imports.py
View file @
dbaec326
...
@@ -17,4 +17,5 @@
...
@@ -17,4 +17,5 @@
# pylint: disable=unused-import
# pylint: disable=unused-import
from
official.common
import
registry_imports
from
official.common
import
registry_imports
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.tasks
import
semantic_segmentation_3d
from
official.vision.beta.projects.volumetric_models.tasks
import
semantic_segmentation_3d
official/vision/beta/projects/volumetric_models/serving/semantic_segmentation_3d.py
View file @
dbaec326
...
@@ -18,8 +18,10 @@ from typing import Mapping
...
@@ -18,8 +18,10 @@ from typing import Mapping
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint: disable=unused-import
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.modeling
import
factory
from
official.vision.beta.projects.volumetric_models.modeling
import
factory
from
official.vision.beta.projects.volumetric_models.modeling.backbones
import
unet_3d
# pylint: disable=unused-import
from
official.vision.beta.serving
import
export_base
from
official.vision.beta.serving
import
export_base
...
...
official/vision/beta/projects/volumetric_models/serving/semantic_segmentation_3d_test.py
View file @
dbaec326
...
@@ -20,9 +20,11 @@ from absl.testing import parameterized
...
@@ -20,9 +20,11 @@ from absl.testing import parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
# pylint: disable=unused-import
from
official.vision.beta.projects.volumetric_models.configs
import
semantic_segmentation_3d
as
exp_cfg
from
official.vision.beta.projects.volumetric_models.modeling.backbones
import
unet_3d
# pylint: disable=unused-import
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.serving
import
semantic_segmentation_3d
from
official.vision.beta.projects.volumetric_models.serving
import
semantic_segmentation_3d
...
...
official/vision/beta/projects/volumetric_models/tasks/semantic_segmentation_3d_test.py
View file @
dbaec326
...
@@ -28,7 +28,8 @@ from official.core import exp_factory
...
@@ -28,7 +28,8 @@ from official.core import exp_factory
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.dataloaders
import
tfexample_utils
from
official.vision.beta.projects.volumetric_models.evaluation
import
segmentation_metrics
from
official.vision.beta.projects.volumetric_models.evaluation
import
segmentation_metrics
from
official.vision.beta.projects.volumetric_models.modeling.backbones
import
unet_3d
from
official.vision.beta.projects.volumetric_models.modeling
import
backbones
from
official.vision.beta.projects.volumetric_models.modeling
import
decoders
from
official.vision.beta.projects.volumetric_models.tasks
import
semantic_segmentation_3d
as
img_seg_task
from
official.vision.beta.projects.volumetric_models.tasks
import
semantic_segmentation_3d
as
img_seg_task
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment