Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a708019f
Commit
a708019f
authored
Oct 13, 2020
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Oct 13, 2020
Browse files
Internal change
PiperOrigin-RevId: 336939703
parent
cd39fd58
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
360 additions
and
0 deletions
+360
-0
official/vision/beta/modeling/heads/segmentation_heads.py
official/vision/beta/modeling/heads/segmentation_heads.py
+151
-0
official/vision/beta/modeling/heads/segmentation_heads_test.py
...ial/vision/beta/modeling/heads/segmentation_heads_test.py
+49
-0
official/vision/beta/modeling/segmentation_model.py
official/vision/beta/modeling/segmentation_model.py
+75
-0
official/vision/beta/modeling/segmentation_model_test.py
official/vision/beta/modeling/segmentation_model_test.py
+85
-0
No files found.
official/vision/beta/modeling/heads/segmentation_heads.py
0 → 100644
View file @
a708019f
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Segmentation heads."""
# Import libraries
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
from
official.vision.beta.ops
import
spatial_transform_ops
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
SegmentationHead
(
tf
.
keras
.
layers
.
Layer
):
"""Segmentation head."""
def
__init__
(
self
,
num_classes
,
level
,
num_convs
=
2
,
num_filters
=
256
,
upsample_factor
=
1
,
activation
=
'relu'
,
use_sync_bn
=
False
,
norm_momentum
=
0.99
,
norm_epsilon
=
0.001
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
**
kwargs
):
"""Initialize params to build segmentation head.
Args:
num_classes: `int` number of mask classification categories. The number of
classes does not include background class.
level: `int` or `str`, level to use to build segmentation head.
num_convs: `int` number of stacked convolution before the last prediction
layer.
num_filters: `int` number to specify the number of filters used.
Default is 256.
upsample_factor: `int` number to specify the upsampling factor to generate
finer mask. Default 1 means no upsampling is applied.
activation: `string`, indicating which activation is used, e.g. 'relu',
'swish', etc.
use_sync_bn: `bool`, whether to use synchronized batch normalization
across different replicas.
norm_momentum: `float`, the momentum parameter of the normalization
layers.
norm_epsilon: `float`, the epsilon parameter of the normalization layers.
kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer
kernel.
bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias.
**kwargs: other keyword arguments passed to Layer.
"""
super
(
SegmentationHead
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
'num_classes'
:
num_classes
,
'level'
:
level
,
'num_convs'
:
num_convs
,
'num_filters'
:
num_filters
,
'upsample_factor'
:
upsample_factor
,
'activation'
:
activation
,
'use_sync_bn'
:
use_sync_bn
,
'norm_momentum'
:
norm_momentum
,
'norm_epsilon'
:
norm_epsilon
,
'kernel_regularizer'
:
kernel_regularizer
,
'bias_regularizer'
:
bias_regularizer
,
}
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
self
.
_bn_axis
=
-
1
else
:
self
.
_bn_axis
=
1
self
.
_activation
=
tf_utils
.
get_activation
(
activation
)
def
build
(
self
,
input_shape
):
"""Creates the variables of the segmentation head."""
conv_op
=
tf
.
keras
.
layers
.
Conv2D
conv_kwargs
=
{
'kernel_size'
:
3
,
'padding'
:
'same'
,
'bias_initializer'
:
tf
.
zeros_initializer
(),
'kernel_initializer'
:
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.01
),
'kernel_regularizer'
:
self
.
_config_dict
[
'kernel_regularizer'
],
}
bn_op
=
(
tf
.
keras
.
layers
.
experimental
.
SyncBatchNormalization
if
self
.
_config_dict
[
'use_sync_bn'
]
else
tf
.
keras
.
layers
.
BatchNormalization
)
bn_kwargs
=
{
'axis'
:
self
.
_bn_axis
,
'momentum'
:
self
.
_config_dict
[
'norm_momentum'
],
'epsilon'
:
self
.
_config_dict
[
'norm_epsilon'
],
}
# Segmentation head layers.
self
.
_convs
=
[]
self
.
_norms
=
[]
for
i
in
range
(
self
.
_config_dict
[
'num_convs'
]):
conv_name
=
'segmentation_head_conv_{}'
.
format
(
i
)
self
.
_convs
.
append
(
conv_op
(
name
=
conv_name
,
filters
=
self
.
_config_dict
[
'num_filters'
],
**
conv_kwargs
))
norm_name
=
'segmentation_head_norm_{}'
.
format
(
i
)
self
.
_norms
.
append
(
bn_op
(
name
=
norm_name
,
**
bn_kwargs
))
self
.
_classifier
=
conv_op
(
name
=
'segmentation_output'
,
filters
=
self
.
_config_dict
[
'num_classes'
],
**
conv_kwargs
)
super
(
SegmentationHead
,
self
).
build
(
input_shape
)
def
call
(
self
,
features
):
"""Forward pass of the segmentation head.
Args:
features: a dict of tensors
- key: `str`, the level of the multilevel features.
- values: `Tensor`, the feature map tensors, whose shape is
[batch, height_l, width_l, channels].
Returns:
segmentation prediction mask: `Tensor`, the segmentation mask scores
predicted from input feature.
"""
x
=
features
[
str
(
self
.
_config_dict
[
'level'
])]
for
conv
,
norm
in
zip
(
self
.
_convs
,
self
.
_norms
):
x
=
conv
(
x
)
x
=
norm
(
x
)
x
=
self
.
_activation
(
x
)
x
=
spatial_transform_ops
.
nearest_upsampling
(
x
,
scale
=
self
.
_config_dict
[
'upsample_factor'
])
return
self
.
_classifier
(
x
)
def
get_config
(
self
):
return
self
.
_config_dict
@
classmethod
def
from_config
(
cls
,
config
):
return
cls
(
**
config
)
official/vision/beta/modeling/heads/segmentation_heads_test.py
0 → 100644
View file @
a708019f
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for segmentation_heads.py."""
# Import libraries
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.modeling.heads
import
segmentation_heads
class
SegmentationHeadTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
3
),
(
4
),
)
def
test_forward
(
self
,
level
):
head
=
segmentation_heads
.
SegmentationHead
(
num_classes
=
10
,
level
=
level
)
features
=
{
'3'
:
np
.
random
.
rand
(
2
,
128
,
128
,
16
),
'4'
:
np
.
random
.
rand
(
2
,
64
,
64
,
16
),
}
logits
=
head
(
features
)
self
.
assertAllEqual
(
logits
.
numpy
().
shape
,
[
2
,
features
[
str
(
level
)].
shape
[
1
],
features
[
str
(
level
)].
shape
[
2
],
10
])
def
test_serialize_deserialize
(
self
):
head
=
segmentation_heads
.
SegmentationHead
(
num_classes
=
10
,
level
=
3
)
config
=
head
.
get_config
()
new_head
=
segmentation_heads
.
SegmentationHead
.
from_config
(
config
)
self
.
assertAllEqual
(
head
.
get_config
(),
new_head
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/modeling/segmentation_model.py
0 → 100644
View file @
a708019f
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build segmentation models."""
# Import libraries
import
tensorflow
as
tf
layers
=
tf
.
keras
.
layers
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Vision'
)
class
SegmentationModel
(
tf
.
keras
.
Model
):
"""A Segmentation class model.
Input images are passed through backbone first. Decoder network is then
applied, and finally, segmentation head is applied on the output of the
decoder network. Layers such as ASPP should be part of decoder.
"""
def
__init__
(
self
,
backbone
,
decoder
,
head
,
**
kwargs
):
"""Segmentation initialization function.
Args:
backbone: a backbone network.
decoder: a decoder network. E.g. FPN.
head: segmentation head.
**kwargs: keyword arguments to be passed.
"""
super
(
SegmentationModel
,
self
).
__init__
(
**
kwargs
)
self
.
_config_dict
=
{
'backbone'
:
backbone
,
'decoder'
:
decoder
,
'head'
:
head
,
}
self
.
backbone
=
backbone
self
.
decoder
=
decoder
self
.
head
=
head
def
call
(
self
,
inputs
,
training
=
None
):
features
=
self
.
backbone
(
inputs
)
if
self
.
decoder
:
features
=
self
.
decoder
(
features
)
return
self
.
head
(
features
)
@
property
def
checkpoint_items
(
self
):
"""Returns a dictionary of items to be additionally checkpointed."""
items
=
dict
(
backbone
=
self
.
backbone
,
head
=
self
.
head
)
if
self
.
decoder
is
not
None
:
items
.
update
(
decoder
=
self
.
decoder
)
return
items
def
get_config
(
self
):
return
self
.
_config_dict
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
official/vision/beta/modeling/segmentation_model_test.py
0 → 100644
View file @
a708019f
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for segmentation network."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.modeling
import
segmentation_model
from
official.vision.beta.modeling.decoders
import
fpn
from
official.vision.beta.modeling.heads
import
segmentation_heads
class
SegmentationNetworkTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
@
parameterized
.
parameters
(
(
128
,
2
),
(
128
,
3
),
(
128
,
4
),
(
256
,
2
),
(
256
,
3
),
(
256
,
4
),
)
def
test_segmentation_network_creation
(
self
,
input_size
,
level
):
"""Test for creation of a segmentation network."""
num_classes
=
10
inputs
=
np
.
random
.
rand
(
2
,
input_size
,
input_size
,
3
)
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_last'
)
backbone
=
backbones
.
ResNet
(
model_id
=
50
)
decoder
=
fpn
.
FPN
(
input_specs
=
backbone
.
output_specs
,
min_level
=
2
,
max_level
=
7
)
head
=
segmentation_heads
.
SegmentationHead
(
num_classes
,
level
=
level
)
model
=
segmentation_model
.
SegmentationModel
(
backbone
=
backbone
,
decoder
=
decoder
,
head
=
head
)
logits
=
model
(
inputs
)
self
.
assertAllEqual
(
[
2
,
input_size
//
(
2
**
level
),
input_size
//
(
2
**
level
),
num_classes
],
logits
.
numpy
().
shape
)
def
test_serialize_deserialize
(
self
):
"""Validate the network can be serialized and deserialized."""
num_classes
=
3
backbone
=
backbones
.
ResNet
(
model_id
=
50
)
decoder
=
fpn
.
FPN
(
input_specs
=
backbone
.
output_specs
,
min_level
=
3
,
max_level
=
7
)
head
=
segmentation_heads
.
SegmentationHead
(
num_classes
,
level
=
3
)
model
=
segmentation_model
.
SegmentationModel
(
backbone
=
backbone
,
decoder
=
decoder
,
head
=
head
)
config
=
model
.
get_config
()
new_model
=
segmentation_model
.
SegmentationModel
.
from_config
(
config
)
# Validate that the config can be forced to JSON.
_
=
new_model
.
to_json
()
# If the serialization was successful, the new config should match the old.
self
.
assertAllEqual
(
model
.
get_config
(),
new_model
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment