Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
601953a4
Commit
601953a4
authored
Jul 23, 2021
by
Fan Yang
Committed by
A. Unique TensorFlower
Jul 23, 2021
Browse files
Make register for decoders so custom decoders can be registered.
PiperOrigin-RevId: 386565375
parent
6803e44e
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
396 additions
and
62 deletions
+396
-62
official/vision/beta/modeling/decoders/aspp.py
official/vision/beta/modeling/decoders/aspp.py
+47
-1
official/vision/beta/modeling/decoders/factory.py
official/vision/beta/modeling/decoders/factory.py
+100
-56
official/vision/beta/modeling/decoders/factory_test.py
official/vision/beta/modeling/decoders/factory_test.py
+155
-0
official/vision/beta/modeling/decoders/fpn.py
official/vision/beta/modeling/decoders/fpn.py
+43
-0
official/vision/beta/modeling/decoders/nasfpn.py
official/vision/beta/modeling/decoders/nasfpn.py
+47
-1
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+4
-4
No files found.
official/vision/beta/modeling/decoders/aspp.py
View file @
601953a4
...
@@ -13,12 +13,15 @@
...
@@ -13,12 +13,15 @@
# limitations under the License.
# limitations under the License.
"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
from
typing
import
Any
,
List
,
Optional
,
Mapping
from
typing
import
Any
,
List
,
Mapping
,
Optional
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.vision
import
keras_cv
from
official.vision
import
keras_cv
from
official.vision.beta.modeling.decoders
import
factory
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
...
@@ -128,3 +131,46 @@ class ASPP(tf.keras.layers.Layer):
...
@@ -128,3 +131,46 @@ class ASPP(tf.keras.layers.Layer):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
return
cls
(
**
config
)
@
factory
.
register_decoder_builder
(
'aspp'
)
def
build_aspp_decoder
(
input_specs
:
Mapping
[
str
,
tf
.
TensorShape
],
model_config
:
hyperparams
.
Config
,
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
)
->
tf
.
keras
.
Model
:
"""Builds ASPP decoder from a config.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone. Note this is for consistent
interface, and is not used by ASPP decoder.
model_config: A OneOfConfig. Model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
None.
Returns:
A `tf.keras.Model` instance of the ASPP decoder.
Raises:
ValueError: If the model_config.decoder.type is not `aspp`.
"""
del
input_specs
# input_specs is not used by ASPP decoder.
decoder_type
=
model_config
.
decoder
.
type
decoder_cfg
=
model_config
.
decoder
.
get
()
if
decoder_type
!=
'aspp'
:
raise
ValueError
(
f
'Inconsistent decoder type
{
decoder_type
}
. '
'Need to be `aspp`.'
)
norm_activation_config
=
model_config
.
norm_activation
return
ASPP
(
level
=
decoder_cfg
.
level
,
dilation_rates
=
decoder_cfg
.
dilation_rates
,
num_filters
=
decoder_cfg
.
num_filters
,
pool_kernel_size
=
decoder_cfg
.
pool_kernel_size
,
dropout_rate
=
decoder_cfg
.
dropout_rate
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
activation
=
norm_activation_config
.
activation
,
kernel_regularizer
=
l2_regularizer
)
official/vision/beta/modeling/decoders/factory.py
View file @
601953a4
...
@@ -12,80 +12,124 @@
...
@@ -12,80 +12,124 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Decoder registers and factory method.
"""Contains the factory method to create decoders."""
from
typing
import
Mapping
,
Optional
One can register a new decoder model by the following two steps:
1 Import the factory and register the build in the decoder file.
2 Import the decoder class and add a build in __init__.py.
```
# my_decoder.py
from modeling.decoders import factory
class MyDecoder():
...
@factory.register_decoder_builder('my_decoder')
def build_my_decoder():
return MyDecoder()
# decoders/__init__.py adds import
from modeling.decoders.my_decoder import MyDecoder
```
If one wants the MyDecoder class to be used only by those binary
then don't imported the decoder module in decoders/__init__.py, but import it
in place that uses it.
"""
from
typing
import
Any
,
Callable
,
Mapping
,
Optional
,
Union
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
registry
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.vision.beta.modeling
import
decoders
_REGISTERED_DECODER_CLS
=
{}
def
register_decoder_builder
(
key
:
str
)
->
Callable
[...,
Any
]:
"""Decorates a builder of decoder class.
The builder should be a Callable (a class or a function).
This decorator supports registration of decoder builder as follows:
```
class MyDecoder(tf.keras.Model):
pass
@register_decoder_builder('mydecoder')
def builder(input_specs, config, l2_reg):
return MyDecoder(...)
# Builds a MyDecoder object.
my_decoder = build_decoder_3d(input_specs, config, l2_reg)
```
Args:
key: A `str` of key to look up the builder.
Returns:
A callable for using as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return
registry
.
register
(
_REGISTERED_DECODER_CLS
,
key
)
@
register_decoder_builder
(
'identity'
)
def
build_identity
(
input_specs
:
Optional
[
Mapping
[
str
,
tf
.
TensorShape
]]
=
None
,
model_config
:
Optional
[
hyperparams
.
Config
]
=
None
,
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
)
->
None
:
"""Builds identity decoder from a config.
All the input arguments are not used by identity decoder but kept here to
ensure the interface is consistent.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
model_config: A `OneOfConfig` of model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
None.
Returns:
An instance of the identity decoder.
"""
del
input_specs
,
model_config
,
l2_regularizer
# Unused by identity decoder.
def
build_decoder
(
def
build_decoder
(
input_specs
:
Mapping
[
str
,
tf
.
TensorShape
],
input_specs
:
Mapping
[
str
,
tf
.
TensorShape
],
model_config
:
hyperparams
.
Config
,
model_config
:
hyperparams
.
Config
,
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
l2_regularizer
:
tf
.
keras
.
regularizers
.
Regularizer
=
None
,
)
->
tf
.
keras
.
Model
:
**
kwargs
)
->
Union
[
None
,
tf
.
keras
.
Model
,
tf
.
keras
.
layers
.
Layer
]
:
"""Builds decoder from a config.
"""Builds decoder from a config.
A decoder can be a keras.Model, a keras.layers.Layer, or None. If it is not
None, the decoder will take features from the backbone as input and generate
decoded feature maps. If it is None, such as an identity decoder, the decoder
is skipped and features from the backbone are regarded as model output.
Args:
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
{level: TensorShape} from a backbone.
model_config: A OneOfConfig
. M
odel config.
model_config: A
`
OneOfConfig
` of m
odel config.
l2_regularizer: A `tf.keras.regularizers.Regularizer`
instance
. Default to
l2_regularizer: A `tf.keras.regularizers.Regularizer`
object
. Default to
None.
None.
**kwargs: Additional keyword args to be passed to decoder builder.
Returns:
Returns:
A
`tf.keras.Model`
instance of the decoder.
A
n
instance of the decoder.
"""
"""
decoder_type
=
model_config
.
decoder
.
type
decoder_builder
=
registry
.
lookup
(
_REGISTERED_DECODER_CLS
,
decoder_cfg
=
model_config
.
decoder
.
get
()
model_config
.
decoder
.
type
)
norm_activation_config
=
model_config
.
norm_activation
return
decoder_builder
(
if
decoder_type
==
'identity'
:
input_specs
=
input_specs
,
decoder
=
None
model_config
=
model_config
,
elif
decoder_type
==
'fpn'
:
l2_regularizer
=
l2_regularizer
,
decoder
=
decoders
.
FPN
(
**
kwargs
)
input_specs
=
input_specs
,
min_level
=
model_config
.
min_level
,
max_level
=
model_config
.
max_level
,
num_filters
=
decoder_cfg
.
num_filters
,
use_separable_conv
=
decoder_cfg
.
use_separable_conv
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
)
elif
decoder_type
==
'nasfpn'
:
decoder
=
decoders
.
NASFPN
(
input_specs
=
input_specs
,
min_level
=
model_config
.
min_level
,
max_level
=
model_config
.
max_level
,
num_filters
=
decoder_cfg
.
num_filters
,
num_repeats
=
decoder_cfg
.
num_repeats
,
use_separable_conv
=
decoder_cfg
.
use_separable_conv
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
)
elif
decoder_type
==
'aspp'
:
decoder
=
decoders
.
ASPP
(
level
=
decoder_cfg
.
level
,
dilation_rates
=
decoder_cfg
.
dilation_rates
,
num_filters
=
decoder_cfg
.
num_filters
,
pool_kernel_size
=
decoder_cfg
.
pool_kernel_size
,
dropout_rate
=
decoder_cfg
.
dropout_rate
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
activation
=
norm_activation_config
.
activation
,
kernel_regularizer
=
l2_regularizer
)
else
:
raise
ValueError
(
'Decoder {!r} not implement'
.
format
(
decoder_type
))
return
decoder
official/vision/beta/modeling/decoders/factory_test.py
0 → 100644
View file @
601953a4
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for decoder factory functions."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
official.vision.beta
import
configs
from
official.vision.beta.configs
import
decoders
as
decoders_cfg
from
official.vision.beta.modeling
import
decoders
from
official.vision.beta.modeling.decoders
import
factory
class
FactoryTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
combinations
.
generate
(
combinations
.
combine
(
num_filters
=
[
128
,
256
],
use_separable_conv
=
[
True
,
False
]))
def
test_fpn_decoder_creation
(
self
,
num_filters
,
use_separable_conv
):
"""Test creation of FPN decoder."""
min_level
=
3
max_level
=
7
input_specs
=
{}
for
level
in
range
(
min_level
,
max_level
):
input_specs
[
str
(
level
)]
=
tf
.
TensorShape
(
[
1
,
128
//
(
2
**
level
),
128
//
(
2
**
level
),
3
])
network
=
decoders
.
FPN
(
input_specs
=
input_specs
,
num_filters
=
num_filters
,
use_separable_conv
=
use_separable_conv
,
use_sync_bn
=
True
)
model_config
=
configs
.
retinanet
.
RetinaNet
()
model_config
.
min_level
=
min_level
model_config
.
max_level
=
max_level
model_config
.
num_classes
=
10
model_config
.
input_size
=
[
None
,
None
,
3
]
model_config
.
decoder
=
decoders_cfg
.
Decoder
(
type
=
'fpn'
,
fpn
=
decoders_cfg
.
FPN
(
num_filters
=
num_filters
,
use_separable_conv
=
use_separable_conv
))
factory_network
=
factory
.
build_decoder
(
input_specs
=
input_specs
,
model_config
=
model_config
)
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
self
.
assertEqual
(
network_config
,
factory_network_config
)
@
combinations
.
generate
(
combinations
.
combine
(
num_filters
=
[
128
,
256
],
num_repeats
=
[
3
,
5
],
use_separable_conv
=
[
True
,
False
]))
def
test_nasfpn_decoder_creation
(
self
,
num_filters
,
num_repeats
,
use_separable_conv
):
"""Test creation of NASFPN decoder."""
min_level
=
3
max_level
=
7
input_specs
=
{}
for
level
in
range
(
min_level
,
max_level
):
input_specs
[
str
(
level
)]
=
tf
.
TensorShape
(
[
1
,
128
//
(
2
**
level
),
128
//
(
2
**
level
),
3
])
network
=
decoders
.
NASFPN
(
input_specs
=
input_specs
,
num_filters
=
num_filters
,
num_repeats
=
num_repeats
,
use_separable_conv
=
use_separable_conv
,
use_sync_bn
=
True
)
model_config
=
configs
.
retinanet
.
RetinaNet
()
model_config
.
min_level
=
min_level
model_config
.
max_level
=
max_level
model_config
.
num_classes
=
10
model_config
.
input_size
=
[
None
,
None
,
3
]
model_config
.
decoder
=
decoders_cfg
.
Decoder
(
type
=
'nasfpn'
,
nasfpn
=
decoders_cfg
.
NASFPN
(
num_filters
=
num_filters
,
num_repeats
=
num_repeats
,
use_separable_conv
=
use_separable_conv
))
factory_network
=
factory
.
build_decoder
(
input_specs
=
input_specs
,
model_config
=
model_config
)
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
self
.
assertEqual
(
network_config
,
factory_network_config
)
@
combinations
.
generate
(
combinations
.
combine
(
level
=
[
3
,
4
],
dilation_rates
=
[[
6
,
12
,
18
],
[
6
,
12
]],
num_filters
=
[
128
,
256
]))
def
test_aspp_decoder_creation
(
self
,
level
,
dilation_rates
,
num_filters
):
"""Test creation of ASPP decoder."""
input_specs
=
{
'1'
:
tf
.
TensorShape
([
1
,
128
,
128
,
3
])}
network
=
decoders
.
ASPP
(
level
=
level
,
dilation_rates
=
dilation_rates
,
num_filters
=
num_filters
,
use_sync_bn
=
True
)
model_config
=
configs
.
semantic_segmentation
.
SemanticSegmentationModel
()
model_config
.
num_classes
=
10
model_config
.
input_size
=
[
None
,
None
,
3
]
model_config
.
decoder
=
decoders_cfg
.
Decoder
(
type
=
'aspp'
,
aspp
=
decoders_cfg
.
ASPP
(
level
=
level
,
dilation_rates
=
dilation_rates
,
num_filters
=
num_filters
))
factory_network
=
factory
.
build_decoder
(
input_specs
=
input_specs
,
model_config
=
model_config
)
network_config
=
network
.
get_config
()
factory_network_config
=
factory_network
.
get_config
()
self
.
assertEqual
(
network_config
,
factory_network_config
)
def
test_identity_decoder_creation
(
self
):
"""Test creation of identity decoder."""
model_config
=
configs
.
retinanet
.
RetinaNet
()
model_config
.
num_classes
=
2
model_config
.
input_size
=
[
None
,
None
,
3
]
model_config
.
decoder
=
decoders_cfg
.
Decoder
(
type
=
'identity'
,
identity
=
decoders_cfg
.
Identity
())
factory_network
=
factory
.
build_decoder
(
input_specs
=
None
,
model_config
=
model_config
)
self
.
assertIsNone
(
factory_network
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/modeling/decoders/fpn.py
View file @
601953a4
...
@@ -16,9 +16,12 @@
...
@@ -16,9 +16,12 @@
from
typing
import
Any
,
Mapping
,
Optional
from
typing
import
Any
,
Mapping
,
Optional
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling.decoders
import
factory
from
official.vision.beta.ops
import
spatial_transform_ops
from
official.vision.beta.ops
import
spatial_transform_ops
...
@@ -187,3 +190,43 @@ class FPN(tf.keras.Model):
...
@@ -187,3 +190,43 @@ class FPN(tf.keras.Model):
def
output_specs
(
self
)
->
Mapping
[
str
,
tf
.
TensorShape
]:
def
output_specs
(
self
)
->
Mapping
[
str
,
tf
.
TensorShape
]:
"""A dict of {level: TensorShape} pairs for the model output."""
"""A dict of {level: TensorShape} pairs for the model output."""
return
self
.
_output_specs
return
self
.
_output_specs
@
factory
.
register_decoder_builder
(
'fpn'
)
def
build_fpn_decoder
(
input_specs
:
Mapping
[
str
,
tf
.
TensorShape
],
model_config
:
hyperparams
.
Config
,
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
)
->
tf
.
keras
.
Model
:
"""Builds FPN decoder from a config.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
model_config: A OneOfConfig. Model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
None.
Returns:
A `tf.keras.Model` instance of the FPN decoder.
Raises:
ValueError: If the model_config.decoder.type is not `fpn`.
"""
decoder_type
=
model_config
.
decoder
.
type
decoder_cfg
=
model_config
.
decoder
.
get
()
if
decoder_type
!=
'fpn'
:
raise
ValueError
(
f
'Inconsistent decoder type
{
decoder_type
}
. '
'Need to be `fpn`.'
)
norm_activation_config
=
model_config
.
norm_activation
return
FPN
(
input_specs
=
input_specs
,
min_level
=
model_config
.
min_level
,
max_level
=
model_config
.
max_level
,
num_filters
=
decoder_cfg
.
num_filters
,
use_separable_conv
=
decoder_cfg
.
use_separable_conv
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
)
official/vision/beta/modeling/decoders/nasfpn.py
View file @
601953a4
...
@@ -13,12 +13,16 @@
...
@@ -13,12 +13,16 @@
# limitations under the License.
# limitations under the License.
"""Contains definitions of NAS-FPN."""
"""Contains definitions of NAS-FPN."""
from
typing
import
Any
,
Mapping
,
List
,
Tuple
,
Optional
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Tuple
# Import libraries
# Import libraries
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
hyperparams
from
official.vision.beta.modeling.decoders
import
factory
from
official.vision.beta.ops
import
spatial_transform_ops
from
official.vision.beta.ops
import
spatial_transform_ops
...
@@ -316,3 +320,45 @@ class NASFPN(tf.keras.Model):
...
@@ -316,3 +320,45 @@ class NASFPN(tf.keras.Model):
def
output_specs
(
self
)
->
Mapping
[
str
,
tf
.
TensorShape
]:
def
output_specs
(
self
)
->
Mapping
[
str
,
tf
.
TensorShape
]:
"""A dict of {level: TensorShape} pairs for the model output."""
"""A dict of {level: TensorShape} pairs for the model output."""
return
self
.
_output_specs
return
self
.
_output_specs
@
factory
.
register_decoder_builder
(
'nasfpn'
)
def
build_nasfpn_decoder
(
input_specs
:
Mapping
[
str
,
tf
.
TensorShape
],
model_config
:
hyperparams
.
Config
,
l2_regularizer
:
Optional
[
tf
.
keras
.
regularizers
.
Regularizer
]
=
None
)
->
tf
.
keras
.
Model
:
"""Builds NASFPN decoder from a config.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
model_config: A OneOfConfig. Model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
None.
Returns:
A `tf.keras.Model` instance of the NASFPN decoder.
Raises:
ValueError: If the model_config.decoder.type is not `nasfpn`.
"""
decoder_type
=
model_config
.
decoder
.
type
decoder_cfg
=
model_config
.
decoder
.
get
()
if
decoder_type
!=
'nasfpn'
:
raise
ValueError
(
f
'Inconsistent decoder type
{
decoder_type
}
. '
'Need to be `nasfpn`.'
)
norm_activation_config
=
model_config
.
norm_activation
return
NASFPN
(
input_specs
=
input_specs
,
min_level
=
model_config
.
min_level
,
max_level
=
model_config
.
max_level
,
num_filters
=
decoder_cfg
.
num_filters
,
num_repeats
=
decoder_cfg
.
num_repeats
,
use_separable_conv
=
decoder_cfg
.
use_separable_conv
,
activation
=
norm_activation_config
.
activation
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
,
kernel_regularizer
=
l2_regularizer
)
official/vision/beta/modeling/factory.py
View file @
601953a4
...
@@ -24,10 +24,10 @@ from official.vision.beta.configs import retinanet as retinanet_cfg
...
@@ -24,10 +24,10 @@ from official.vision.beta.configs import retinanet as retinanet_cfg
from
official.vision.beta.configs
import
semantic_segmentation
as
segmentation_cfg
from
official.vision.beta.configs
import
semantic_segmentation
as
segmentation_cfg
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.modeling
import
classification_model
from
official.vision.beta.modeling
import
classification_model
from
official.vision.beta.modeling
import
decoders
from
official.vision.beta.modeling
import
maskrcnn_model
from
official.vision.beta.modeling
import
maskrcnn_model
from
official.vision.beta.modeling
import
retinanet_model
from
official.vision.beta.modeling
import
retinanet_model
from
official.vision.beta.modeling
import
segmentation_model
from
official.vision.beta.modeling
import
segmentation_model
from
official.vision.beta.modeling.decoders
import
factory
as
decoder_factory
from
official.vision.beta.modeling.heads
import
dense_prediction_heads
from
official.vision.beta.modeling.heads
import
dense_prediction_heads
from
official.vision.beta.modeling.heads
import
instance_heads
from
official.vision.beta.modeling.heads
import
instance_heads
from
official.vision.beta.modeling.heads
import
segmentation_heads
from
official.vision.beta.modeling.heads
import
segmentation_heads
...
@@ -78,7 +78,7 @@ def build_maskrcnn(
...
@@ -78,7 +78,7 @@ def build_maskrcnn(
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
backbone
(
tf
.
keras
.
Input
(
input_specs
.
shape
[
1
:]))
backbone
(
tf
.
keras
.
Input
(
input_specs
.
shape
[
1
:]))
decoder
=
decoder
_
factory
.
build_decoder
(
decoder
=
decoder
s
.
factory
.
build_decoder
(
input_specs
=
backbone
.
output_specs
,
input_specs
=
backbone
.
output_specs
,
model_config
=
model_config
,
model_config
=
model_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
...
@@ -253,7 +253,7 @@ def build_retinanet(
...
@@ -253,7 +253,7 @@ def build_retinanet(
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
backbone
(
tf
.
keras
.
Input
(
input_specs
.
shape
[
1
:]))
backbone
(
tf
.
keras
.
Input
(
input_specs
.
shape
[
1
:]))
decoder
=
decoder
_
factory
.
build_decoder
(
decoder
=
decoder
s
.
factory
.
build_decoder
(
input_specs
=
backbone
.
output_specs
,
input_specs
=
backbone
.
output_specs
,
model_config
=
model_config
,
model_config
=
model_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
...
@@ -313,7 +313,7 @@ def build_segmentation_model(
...
@@ -313,7 +313,7 @@ def build_segmentation_model(
norm_activation_config
=
norm_activation_config
,
norm_activation_config
=
norm_activation_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
decoder
=
decoder
_
factory
.
build_decoder
(
decoder
=
decoder
s
.
factory
.
build_decoder
(
input_specs
=
backbone
.
output_specs
,
input_specs
=
backbone
.
output_specs
,
model_config
=
model_config
,
model_config
=
model_config
,
l2_regularizer
=
l2_regularizer
)
l2_regularizer
=
l2_regularizer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment