Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b0ccdb11
Commit
b0ccdb11
authored
Sep 28, 2020
by
Shixin Luo
Browse files
resolve conflict with master
parents
e61588cd
1611a8c5
Changes
210
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
549 additions
and
40 deletions
+549
-40
research/object_detection/models/center_net_hourglass_feature_extractor_tf2_test.py
...models/center_net_hourglass_feature_extractor_tf2_test.py
+2
-1
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
...ction/models/center_net_mobilenet_v2_feature_extractor.py
+0
-2
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
...n/models/center_net_mobilenet_v2_fpn_feature_extractor.py
+142
-0
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
...center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
+46
-0
research/object_detection/models/keras_models/hourglass_network.py
...object_detection/models/keras_models/hourglass_network.py
+273
-32
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
...tection/models/keras_models/hourglass_network_tf2_test.py
+60
-2
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+23
-0
research/slim/datasets/imagenet.py
research/slim/datasets/imagenet.py
+1
-1
research/slim/nets/mobilenet/mobilenet_example.ipynb
research/slim/nets/mobilenet/mobilenet_example.ipynb
+1
-1
research/slim/nets/mobilenet_v1.py
research/slim/nets/mobilenet_v1.py
+1
-1
No files found.
research/object_detection/models/center_net_hourglass_feature_extractor_tf2_test.py
View file @
b0ccdb11
...
...
@@ -30,7 +30,8 @@ class CenterNetHourglassFeatureExtractorTest(test_case.TestCase):
net
=
hourglass_network
.
HourglassNetwork
(
num_stages
=
4
,
blocks_per_stage
=
[
2
,
3
,
4
,
5
,
6
],
channel_dims
=
[
4
,
6
,
8
,
10
,
12
,
14
],
num_hourglasses
=
2
)
input_channel_dims
=
4
,
channel_dims_per_stage
=
[
6
,
8
,
10
,
12
,
14
],
num_hourglasses
=
2
)
model
=
hourglass
.
CenterNetHourglassFeatureExtractor
(
net
)
def
graph_fn
():
...
...
research/object_detection/models/center_net_mobilenet_v2_feature_extractor.py
View file @
b0ccdb11
...
...
@@ -53,8 +53,6 @@ class CenterNetMobileNetV2FeatureExtractor(
output
=
self
.
_network
(
self
.
_network
.
input
)
# TODO(nkhadke): Try out MobileNet+FPN next (skip connections are cheap and
# should help with performance).
# MobileNet by itself transforms a 224x224x3 volume into a 7x7x1280, which
# leads to a stride of 32. We perform upsampling to get it to a target
# stride of 4.
...
...
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor.py
0 → 100644
View file @
b0ccdb11
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MobileNet V2[1] + FPN[2] feature extractor for CenterNet[3] meta architecture.

[1]: https://arxiv.org/abs/1801.04381
[2]: https://arxiv.org/abs/1612.03144.
[3]: https://arxiv.org/abs/1904.07850
"""

import tensorflow.compat.v1 as tf

from object_detection.meta_architectures import center_net_meta_arch
from object_detection.models.keras_models import mobilenet_v2 as mobilenetv2

# Names of the MobileNet V2 layers whose outputs feed the FPN top-down
# pathway. Ordered from shallowest (highest resolution) to deepest; the last
# entry ('out_relu') is the backbone's final activation.
_MOBILENET_V2_FPN_SKIP_LAYERS = [
    'block_2_add', 'block_5_add', 'block_9_add', 'out_relu'
]
class CenterNetMobileNetV2FPNFeatureExtractor(
    center_net_meta_arch.CenterNetFeatureExtractor):
  """The MobileNet V2 with FPN skip layers feature extractor for CenterNet."""

  def __init__(self,
               mobilenet_v2_net,
               channel_means=(0., 0., 0.),
               channel_stds=(1., 1., 1.),
               bgr_ordering=False):
    """Initializes the feature extractor.

    Args:
      mobilenet_v2_net: The underlying mobilenet_v2 network to use.
      channel_means: A tuple of floats, denoting the mean of each channel
        which will be subtracted from it.
      channel_stds: A tuple of floats, denoting the standard deviation of each
        channel. Each channel will be divided by its standard deviation value.
      bgr_ordering: bool, if set will change the channel ordering to be in the
        [blue, red, green] order.
    """
    super(CenterNetMobileNetV2FPNFeatureExtractor, self).__init__(
        channel_means=channel_means,
        channel_stds=channel_stds,
        bgr_ordering=bgr_ordering)

    self._network = mobilenet_v2_net
    output = self._network(self._network.input)

    # Add pyramid feature network on every layer that has stride 2.
    skip_outputs = [
        self._network.get_layer(skip_layer_name).output
        for skip_layer_name in _MOBILENET_V2_FPN_SKIP_LAYERS
    ]
    self._fpn_model = tf.keras.models.Model(
        inputs=self._network.input, outputs=skip_outputs)
    fpn_outputs = self._fpn_model(self._network.input)

    # Construct the top-down feature maps -- we start with an output of
    # 7x7x1280, which we continually upsample, apply a residual on and merge.
    # This results in a 56x56x24 output volume.
    top_layer = fpn_outputs[-1]
    # 1x1 conv to bring the deepest backbone output down to 64 channels
    # before starting the top-down pathway.
    residual_op = tf.keras.layers.Conv2D(
        filters=64, kernel_size=1, strides=1, padding='same')
    top_down = residual_op(top_layer)

    num_filters_list = [64, 32, 24]
    for i, num_filters in enumerate(num_filters_list):
      # Index of the bottom-up skip output to merge at this level
      # (deepest skip first).
      level_ind = len(num_filters_list) - 1 - i
      # Upsample.
      upsample_op = tf.keras.layers.UpSampling2D(2, interpolation='nearest')
      top_down = upsample_op(top_down)

      # Residual (skip-connection) from bottom-up pathway.
      residual_op = tf.keras.layers.Conv2D(
          filters=num_filters, kernel_size=1, strides=1, padding='same')
      residual = residual_op(fpn_outputs[level_ind])

      # Merge.
      top_down = top_down + residual
      # Width of the next level; the final iteration stays at 24 channels.
      next_num_filters = num_filters_list[i + 1] if i + 1 <= 2 else 24
      conv = tf.keras.layers.Conv2D(
          filters=next_num_filters, kernel_size=3, strides=1, padding='same')
      top_down = conv(top_down)
      top_down = tf.keras.layers.BatchNormalization()(top_down)
      top_down = tf.keras.layers.ReLU()(top_down)

    output = top_down

    # Replace the bare backbone with the full backbone+FPN graph so that all
    # downstream accessors (call/get_model/load weights) see the FPN output.
    self._network = tf.keras.models.Model(
        inputs=self._network.input, outputs=output)

  def preprocess(self, resized_inputs):
    """Applies CenterNet normalization then MobileNetV2 input scaling."""
    resized_inputs = super(CenterNetMobileNetV2FPNFeatureExtractor,
                           self).preprocess(resized_inputs)
    return tf.keras.applications.mobilenet_v2.preprocess_input(resized_inputs)

  def load_feature_extractor_weights(self, path):
    self._network.load_weights(path)

  def get_base_model(self):
    return self._network

  def call(self, inputs):
    # CenterNet expects a list of feature maps; this extractor produces one.
    return [self._network(inputs)]

  @property
  def out_stride(self):
    """The stride in the output image of the network."""
    return 4

  @property
  def num_feature_outputs(self):
    """The number of feature outputs returned by the feature extractor."""
    return 1

  def get_model(self):
    return self._network
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering):
  """The MobileNetV2+FPN backbone for CenterNet.

  Args:
    channel_means: A tuple of floats, denoting the mean of each channel
      which will be subtracted from it.
    channel_stds: A tuple of floats, denoting the standard deviation of each
      channel. Each channel will be divided by its standard deviation value.
    bgr_ordering: bool, if set will change the channel ordering to be in the
      [blue, red, green] order.

  Returns:
    A CenterNetMobileNetV2FPNFeatureExtractor wrapping a freshly constructed
    MobileNetV2 backbone.
  """
  # Set to is_training to True for now.
  network = mobilenetv2.mobilenet_v2(True, include_top=False)
  return CenterNetMobileNetV2FPNFeatureExtractor(
      network,
      channel_means=channel_means,
      channel_stds=channel_stds,
      bgr_ordering=bgr_ordering)
research/object_detection/models/center_net_mobilenet_v2_fpn_feature_extractor_tf2_test.py
0 → 100644
View file @
b0ccdb11
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Testing mobilenet_v2+FPN feature extractor for CenterNet."""
import unittest
import numpy as np
import tensorflow.compat.v1 as tf

from object_detection.models import center_net_mobilenet_v2_fpn_feature_extractor
from object_detection.models.keras_models import mobilenet_v2
from object_detection.utils import test_case
from object_detection.utils import tf_version


@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):

  def test_center_net_mobilenet_v2_fpn_feature_extractor(self):
    """Verifies the extractor maps 224x224x3 inputs to 56x56x24 features."""
    net = mobilenet_v2.mobilenet_v2(True, include_top=False)
    model = (
        center_net_mobilenet_v2_fpn_feature_extractor
        .CenterNetMobileNetV2FPNFeatureExtractor(net))

    def graph_fn():
      img = np.zeros((8, 224, 224, 3), dtype=np.float32)
      processed_img = model.preprocess(img)
      return model(processed_img)

    outputs = self.execute(graph_fn, [])
    # Stride of 4: 224 / 4 = 56; final FPN level has 24 channels.
    self.assertEqual(outputs.shape, (8, 56, 56, 24))


if __name__ == '__main__':
  tf.test.main()
research/object_detection/models/keras_models/hourglass_network.py
View file @
b0ccdb11
...
...
@@ -174,8 +174,38 @@ class InputDownsampleBlock(tf.keras.layers.Layer):
return
self
.
residual_block
(
self
.
conv_block
(
inputs
))
class
InputConvBlock
(
tf
.
keras
.
layers
.
Layer
):
"""Block for the initial feature convolution.
This block is used in the hourglass network when we don't want to downsample
the input.
"""
def
__init__
(
self
,
out_channels_initial_conv
,
out_channels_residual_block
):
"""Initializes the downsample block.
Args:
out_channels_initial_conv: int, the desired number of output channels
in the initial conv layer.
out_channels_residual_block: int, the desired number of output channels
in the underlying residual block.
"""
super
(
InputConvBlock
,
self
).
__init__
()
self
.
conv_block
=
ConvolutionalBlock
(
kernel_size
=
3
,
out_channels
=
out_channels_initial_conv
,
stride
=
1
,
padding
=
'valid'
)
self
.
residual_block
=
ResidualBlock
(
out_channels
=
out_channels_residual_block
,
stride
=
1
,
skip_conv
=
True
)
def
call
(
self
,
inputs
):
return
self
.
residual_block
(
self
.
conv_block
(
inputs
))
def
_make_repeated_residual_blocks
(
out_channels
,
num_blocks
,
initial_stride
=
1
,
residual_channels
=
None
):
initial_stride
=
1
,
residual_channels
=
None
,
initial_skip_conv
=
False
):
"""Stack Residual blocks one after the other.
Args:
...
...
@@ -184,6 +214,9 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
initial_stride: int, the stride of the initial residual block.
residual_channels: int, the desired number of output channels in the
intermediate residual blocks. If not specifed, we use out_channels.
initial_skip_conv: bool, if set, the first residual block uses a skip
convolution. This is useful when the number of channels in the input
are not the same as residual_channels.
Returns:
blocks: A list of residual blocks to be applied in sequence.
...
...
@@ -196,16 +229,34 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
residual_channels
=
out_channels
for
i
in
range
(
num_blocks
-
1
):
# Only use the stride at the first block so we don't repeatedly downsample
# the input
stride
=
initial_stride
if
i
==
0
else
1
# If the stide is more than 1, we cannot use an identity layer for the
# skip connection and are forced to use a conv for the skip connection.
skip_conv
=
stride
>
1
if
i
==
0
and
initial_skip_conv
:
skip_conv
=
True
blocks
.
append
(
ResidualBlock
(
out_channels
=
residual_channels
,
stride
=
stride
,
skip_conv
=
skip_conv
)
)
skip_conv
=
residual_channels
!=
out_channels
blocks
.
append
(
ResidualBlock
(
out_channels
=
out_channels
,
skip_conv
=
skip_conv
))
if
num_blocks
==
1
:
# If there is only 1 block, the for loop above is not run,
# therefore we honor the requested stride in the last residual block
stride
=
initial_stride
# We are forced to use a conv in the skip connection if stride > 1
skip_conv
=
stride
>
1
else
:
stride
=
1
skip_conv
=
residual_channels
!=
out_channels
blocks
.
append
(
ResidualBlock
(
out_channels
=
out_channels
,
skip_conv
=
skip_conv
,
stride
=
stride
))
return
blocks
...
...
@@ -222,7 +273,8 @@ def _apply_blocks(inputs, blocks):
class
EncoderDecoderBlock
(
tf
.
keras
.
layers
.
Layer
):
"""An encoder-decoder block which recursively defines the hourglass network."""
def
__init__
(
self
,
num_stages
,
channel_dims
,
blocks_per_stage
):
def
__init__
(
self
,
num_stages
,
channel_dims
,
blocks_per_stage
,
stagewise_downsample
=
True
,
encoder_decoder_shortcut
=
True
):
"""Initializes the encoder-decoder block.
Args:
...
...
@@ -237,6 +289,10 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
blocks_per_stage: int list, number of residual blocks to use at each
stage. `blocks_per_stage[0]` defines the number of blocks at the
current stage and `blocks_per_stage[1:]` is used at further stages.
stagewise_downsample: bool, whether or not to downsample before passing
inputs to the next stage.
encoder_decoder_shortcut: bool, whether or not to use shortcut
connections between encoder and decoder.
"""
super
(
EncoderDecoderBlock
,
self
).
__init__
()
...
...
@@ -244,17 +300,26 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
out_channels
=
channel_dims
[
0
]
out_channels_downsampled
=
channel_dims
[
1
]
self
.
encoder_block1
=
_make_repeated_residual_blocks
(
out_channels
=
out_channels
,
num_blocks
=
blocks_per_stage
[
0
],
initial_stride
=
1
)
self
.
encoder_decoder_shortcut
=
encoder_decoder_shortcut
if
encoder_decoder_shortcut
:
self
.
merge_features
=
tf
.
keras
.
layers
.
Add
()
self
.
encoder_block1
=
_make_repeated_residual_blocks
(
out_channels
=
out_channels
,
num_blocks
=
blocks_per_stage
[
0
],
initial_stride
=
1
)
initial_stride
=
2
if
stagewise_downsample
else
1
self
.
encoder_block2
=
_make_repeated_residual_blocks
(
out_channels
=
out_channels_downsampled
,
num_blocks
=
blocks_per_stage
[
0
],
initial_stride
=
2
)
num_blocks
=
blocks_per_stage
[
0
],
initial_stride
=
initial_stride
,
initial_skip_conv
=
out_channels
!=
out_channels_downsampled
)
if
num_stages
>
1
:
self
.
inner_block
=
[
EncoderDecoderBlock
(
num_stages
-
1
,
channel_dims
[
1
:],
blocks_per_stage
[
1
:])
blocks_per_stage
[
1
:],
stagewise_downsample
=
stagewise_downsample
,
encoder_decoder_shortcut
=
encoder_decoder_shortcut
)
]
else
:
self
.
inner_block
=
_make_repeated_residual_blocks
(
...
...
@@ -264,13 +329,13 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
self
.
decoder_block
=
_make_repeated_residual_blocks
(
residual_channels
=
out_channels_downsampled
,
out_channels
=
out_channels
,
num_blocks
=
blocks_per_stage
[
0
])
self
.
upsample
=
tf
.
keras
.
layers
.
UpSampling2D
(
2
)
self
.
merge_features
=
tf
.
keras
.
layers
.
Add
(
)
self
.
upsample
=
tf
.
keras
.
layers
.
UpSampling2D
(
initial_stride
)
def
call
(
self
,
inputs
):
encoded_outputs
=
_apply_blocks
(
inputs
,
self
.
encoder_block1
)
if
self
.
encoder_decoder_shortcut
:
encoded_outputs
=
_apply_blocks
(
inputs
,
self
.
encoder_block1
)
encoded_downsampled_outputs
=
_apply_blocks
(
inputs
,
self
.
encoder_block2
)
inner_block_outputs
=
_apply_blocks
(
encoded_downsampled_outputs
,
self
.
inner_block
)
...
...
@@ -278,48 +343,68 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
decoded_outputs
=
_apply_blocks
(
inner_block_outputs
,
self
.
decoder_block
)
upsampled_outputs
=
self
.
upsample
(
decoded_outputs
)
return
self
.
merge_features
([
encoded_outputs
,
upsampled_outputs
])
if
self
.
encoder_decoder_shortcut
:
return
self
.
merge_features
([
encoded_outputs
,
upsampled_outputs
])
else
:
return
upsampled_outputs
class
HourglassNetwork
(
tf
.
keras
.
Model
):
"""The hourglass network."""
def
__init__
(
self
,
num_stages
,
channel_dims
,
blocks_per_stage
,
num_hourglasses
):
def
__init__
(
self
,
num_stages
,
input_channel_dims
,
channel_dims_per_stage
,
blocks_per_stage
,
num_hourglasses
,
initial_downsample
=
True
,
stagewise_downsample
=
True
,
encoder_decoder_shortcut
=
True
):
"""Intializes the feature extractor.
Args:
num_stages: int, Number of stages in the network. At each stage we have 2
encoder and 1 decoder blocks. The second encoder block downsamples the
input.
channel_dims: int list, the output channel dimensions of stages in
the network. `channel_dims[0]` and `channel_dims[1]` are used to define
the initial downsampling block. `channel_dims[1:]` is used to define
the hourglass network(s) which follow(s).
input_channel_dims: int, the number of channels in the input conv blocks.
channel_dims_per_stage: int list, the output channel dimensions of each
stage in the hourglass network.
blocks_per_stage: int list, number of residual blocks to use at each
stage in the hourglass network
num_hourglasses: int, number of hourglas networks to stack
sequentially.
initial_downsample: bool, if set, downsamples the input by a factor of 4
before applying the rest of the network. Downsampling is done with a 7x7
convolution kernel, otherwise a 3x3 kernel is used.
stagewise_downsample: bool, whether or not to downsample before passing
inputs to the next stage.
encoder_decoder_shortcut: bool, whether or not to use shortcut
connections between encoder and decoder.
"""
super
(
HourglassNetwork
,
self
).
__init__
()
self
.
num_hourglasses
=
num_hourglasses
self
.
downsample_input
=
InputDownsampleBlock
(
out_channels_initial_conv
=
channel_dims
[
0
],
out_channels_residual_block
=
channel_dims
[
1
]
)
self
.
initial_downsample
=
initial_downsample
if
initial_downsample
:
self
.
downsample_input
=
InputDownsampleBlock
(
out_channels_initial_conv
=
input_channel_dims
,
out_channels_residual_block
=
channel_dims_per_stage
[
0
]
)
else
:
self
.
conv_input
=
InputConvBlock
(
out_channels_initial_conv
=
input_channel_dims
,
out_channels_residual_block
=
channel_dims_per_stage
[
0
]
)
self
.
hourglass_network
=
[]
self
.
output_conv
=
[]
for
_
in
range
(
self
.
num_hourglasses
):
self
.
hourglass_network
.
append
(
EncoderDecoderBlock
(
num_stages
=
num_stages
,
channel_dims
=
channel_dims
[
1
:],
blocks_per_stage
=
blocks_per_stage
)
num_stages
=
num_stages
,
channel_dims
=
channel_dims_per_stage
,
blocks_per_stage
=
blocks_per_stage
,
stagewise_downsample
=
stagewise_downsample
,
encoder_decoder_shortcut
=
encoder_decoder_shortcut
)
)
self
.
output_conv
.
append
(
ConvolutionalBlock
(
kernel_size
=
3
,
out_channels
=
channel_dims
[
1
])
ConvolutionalBlock
(
kernel_size
=
3
,
out_channels
=
channel_dims_per_stage
[
0
])
)
self
.
intermediate_conv1
=
[]
...
...
@@ -329,21 +414,25 @@ class HourglassNetwork(tf.keras.Model):
for
_
in
range
(
self
.
num_hourglasses
-
1
):
self
.
intermediate_conv1
.
append
(
ConvolutionalBlock
(
kernel_size
=
1
,
out_channels
=
channel_dims
[
1
],
relu
=
False
)
kernel_size
=
1
,
out_channels
=
channel_dims
_per_stage
[
0
],
relu
=
False
)
)
self
.
intermediate_conv2
.
append
(
ConvolutionalBlock
(
kernel_size
=
1
,
out_channels
=
channel_dims
[
1
],
relu
=
False
)
kernel_size
=
1
,
out_channels
=
channel_dims
_per_stage
[
0
],
relu
=
False
)
)
self
.
intermediate_residual
.
append
(
ResidualBlock
(
out_channels
=
channel_dims
[
1
])
ResidualBlock
(
out_channels
=
channel_dims
_per_stage
[
0
])
)
self
.
intermediate_relu
=
tf
.
keras
.
layers
.
ReLU
()
def
call
(
self
,
inputs
):
inputs
=
self
.
downsample_input
(
inputs
)
if
self
.
initial_downsample
:
inputs
=
self
.
downsample_input
(
inputs
)
else
:
inputs
=
self
.
conv_input
(
inputs
)
outputs
=
[]
for
i
in
range
(
self
.
num_hourglasses
):
...
...
@@ -372,12 +461,164 @@ class HourglassNetwork(tf.keras.Model):
return
self
.
num_hourglasses
def
_layer_depth
(
layer
):
"""Compute depth of Conv/Residual blocks or lists of them."""
if
isinstance
(
layer
,
list
):
return
sum
([
_layer_depth
(
l
)
for
l
in
layer
])
elif
isinstance
(
layer
,
ConvolutionalBlock
):
return
1
elif
isinstance
(
layer
,
ResidualBlock
):
return
2
else
:
raise
ValueError
(
'Unknown layer - {}'
.
format
(
layer
))
def
_encoder_decoder_depth
(
network
):
"""Helper function to compute depth of encoder-decoder blocks."""
encoder_block2_layers
=
_layer_depth
(
network
.
encoder_block2
)
decoder_block_layers
=
_layer_depth
(
network
.
decoder_block
)
if
isinstance
(
network
.
inner_block
[
0
],
EncoderDecoderBlock
):
assert
len
(
network
.
inner_block
)
==
1
,
'Inner block is expected as length 1.'
inner_block_layers
=
_encoder_decoder_depth
(
network
.
inner_block
[
0
])
return
inner_block_layers
+
encoder_block2_layers
+
decoder_block_layers
elif
isinstance
(
network
.
inner_block
[
0
],
ResidualBlock
):
return
(
encoder_block2_layers
+
decoder_block_layers
+
_layer_depth
(
network
.
inner_block
))
else
:
raise
ValueError
(
'Unknown inner block type.'
)
def
hourglass_depth
(
network
):
"""Helper function to verify depth of hourglass backbone."""
input_conv_layers
=
3
# 1 ResidualBlock and 1 ConvBlock
# Only intermediate_conv2 and intermediate_residual are applied before
# sending inputs to the later stages.
intermediate_layers
=
(
_layer_depth
(
network
.
intermediate_conv2
)
+
_layer_depth
(
network
.
intermediate_residual
)
)
# network.output_conv is applied before sending input to the later stages
output_layers
=
_layer_depth
(
network
.
output_conv
)
encoder_decoder_layers
=
sum
(
_encoder_decoder_depth
(
net
)
for
net
in
network
.
hourglass_network
)
return
(
input_conv_layers
+
encoder_decoder_layers
+
intermediate_layers
+
output_layers
)
def
hourglass_104
():
"""The Hourglass-104 backbone."""
"""The Hourglass-104 backbone.
The architecture parameters are taken from [1].
Returns:
network: An HourglassNetwork object implementing the Hourglass-104
backbone.
[1]: https://arxiv.org/abs/1904.07850
"""
return
HourglassNetwork
(
channel_dims
=
[
128
,
256
,
256
,
384
,
384
,
384
,
512
],
input_channel_dims
=
128
,
channel_dims_per_stage
=
[
256
,
256
,
384
,
384
,
384
,
512
],
num_hourglasses
=
2
,
num_stages
=
5
,
blocks_per_stage
=
[
2
,
2
,
2
,
2
,
2
,
4
],
)
def
single_stage_hourglass
(
input_channel_dims
,
channel_dims_per_stage
,
blocks_per_stage
,
initial_downsample
=
True
,
stagewise_downsample
=
True
,
encoder_decoder_shortcut
=
True
):
assert
len
(
channel_dims_per_stage
)
==
len
(
blocks_per_stage
)
return
HourglassNetwork
(
input_channel_dims
=
input_channel_dims
,
channel_dims_per_stage
=
channel_dims_per_stage
,
num_hourglasses
=
1
,
num_stages
=
len
(
channel_dims_per_stage
)
-
1
,
blocks_per_stage
=
blocks_per_stage
,
initial_downsample
=
initial_downsample
,
stagewise_downsample
=
stagewise_downsample
,
encoder_decoder_shortcut
=
encoder_decoder_shortcut
)
def
hourglass_10
(
num_channels
,
initial_downsample
=
True
):
nc
=
num_channels
return
single_stage_hourglass
(
input_channel_dims
=
nc
,
initial_downsample
=
initial_downsample
,
blocks_per_stage
=
[
1
,
1
],
channel_dims_per_stage
=
[
nc
*
2
,
nc
*
2
])
def
hourglass_20
(
num_channels
,
initial_downsample
=
True
):
nc
=
num_channels
return
single_stage_hourglass
(
input_channel_dims
=
nc
,
initial_downsample
=
initial_downsample
,
blocks_per_stage
=
[
1
,
2
,
2
],
channel_dims_per_stage
=
[
nc
*
2
,
nc
*
2
,
nc
*
3
])
def
hourglass_32
(
num_channels
,
initial_downsample
=
True
):
nc
=
num_channels
return
single_stage_hourglass
(
input_channel_dims
=
nc
,
initial_downsample
=
initial_downsample
,
blocks_per_stage
=
[
2
,
2
,
2
,
2
],
channel_dims_per_stage
=
[
nc
*
2
,
nc
*
2
,
nc
*
3
,
nc
*
3
])
def
hourglass_52
(
num_channels
,
initial_downsample
=
True
):
nc
=
num_channels
return
single_stage_hourglass
(
input_channel_dims
=
nc
,
initial_downsample
=
initial_downsample
,
blocks_per_stage
=
[
2
,
2
,
2
,
2
,
2
,
4
],
channel_dims_per_stage
=
[
nc
*
2
,
nc
*
2
,
nc
*
3
,
nc
*
3
,
nc
*
3
,
nc
*
4
])
def
hourglass_100
(
num_channels
,
initial_downsample
=
True
):
nc
=
num_channels
return
single_stage_hourglass
(
input_channel_dims
=
nc
,
initial_downsample
=
initial_downsample
,
blocks_per_stage
=
[
4
,
4
,
4
,
4
,
4
,
8
],
channel_dims_per_stage
=
[
nc
*
2
,
nc
*
2
,
nc
*
3
,
nc
*
3
,
nc
*
3
,
nc
*
4
])
def
hourglass_20_uniform_size
(
num_channels
):
nc
=
num_channels
return
single_stage_hourglass
(
input_channel_dims
=
nc
,
blocks_per_stage
=
[
1
,
2
,
2
],
channel_dims_per_stage
=
[
nc
*
2
,
nc
*
2
,
nc
*
3
],
initial_downsample
=
False
,
stagewise_downsample
=
False
)
def
hourglass_20_no_shortcut
(
num_channels
):
nc
=
num_channels
return
single_stage_hourglass
(
input_channel_dims
=
nc
,
blocks_per_stage
=
[
1
,
2
,
2
],
channel_dims_per_stage
=
[
nc
*
2
,
nc
*
2
,
nc
*
3
],
initial_downsample
=
False
,
encoder_decoder_shortcut
=
False
)
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
View file @
b0ccdb11
...
...
@@ -78,6 +78,12 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
output
=
layer
(
np
.
zeros
((
2
,
32
,
32
,
8
),
dtype
=
np
.
float32
))
self
.
assertEqual
(
output
.
shape
,
(
2
,
8
,
8
,
8
))
def
test_input_conv_block
(
self
):
layer
=
hourglass
.
InputConvBlock
(
out_channels_initial_conv
=
4
,
out_channels_residual_block
=
8
)
output
=
layer
(
np
.
zeros
((
2
,
32
,
32
,
8
),
dtype
=
np
.
float32
))
self
.
assertEqual
(
output
.
shape
,
(
2
,
32
,
32
,
8
))
def
test_encoder_decoder_block
(
self
):
layer
=
hourglass
.
EncoderDecoderBlock
(
...
...
@@ -89,12 +95,64 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
def
test_hourglass_feature_extractor
(
self
):
model
=
hourglass
.
HourglassNetwork
(
num_stages
=
4
,
blocks_per_stage
=
[
2
,
3
,
4
,
5
,
6
],
channel_dims
=
[
4
,
6
,
8
,
10
,
12
,
14
],
num_hourglasses
=
2
)
num_stages
=
4
,
blocks_per_stage
=
[
2
,
3
,
4
,
5
,
6
],
input_channel_dims
=
4
,
channel_dims
_per_stage
=
[
6
,
8
,
10
,
12
,
14
],
num_hourglasses
=
2
)
outputs
=
model
(
np
.
zeros
((
2
,
64
,
64
,
3
),
dtype
=
np
.
float32
))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
16
,
16
,
6
))
self
.
assertEqual
(
outputs
[
1
].
shape
,
(
2
,
16
,
16
,
6
))
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
HourglassDepthTest
(
tf
.
test
.
TestCase
):
def
test_hourglass_104
(
self
):
net
=
hourglass
.
hourglass_104
()
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
104
)
def
test_hourglass_10
(
self
):
net
=
hourglass
.
hourglass_10
(
2
,
initial_downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
10
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_20
(
self
):
net
=
hourglass
.
hourglass_20
(
2
,
initial_downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
20
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_32
(
self
):
net
=
hourglass
.
hourglass_32
(
2
,
initial_downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
32
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_52
(
self
):
net
=
hourglass
.
hourglass_52
(
2
,
initial_downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
52
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_20_uniform_size
(
self
):
net
=
hourglass
.
hourglass_20_uniform_size
(
2
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
20
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_100
(
self
):
net
=
hourglass
.
hourglass_100
(
2
,
initial_downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
100
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
research/object_detection/protos/center_net.proto
View file @
b0ccdb11
...
...
@@ -19,6 +19,9 @@ message CenterNet {
// Image resizer for preprocessing the input image.
optional
ImageResizer
image_resizer
=
3
;
// If set, all task heads will be constructed with separable convolutions.
optional
bool
use_depthwise
=
13
[
default
=
false
];
// Parameters which are related to object detection task.
message
ObjectDetection
{
// The original fields are moved to ObjectCenterParams or deleted.
...
...
@@ -245,6 +248,21 @@ message CenterNet {
}
optional
TrackEstimation
track_estimation_task
=
10
;
// Temporal offset prediction head similar to CenterTrack.
// Currently our implementation adopts LSTM, different from original paper.
// See go/lstd-centernet for more details.
// Tracking Objects as Points [3]
// [3]: https://arxiv.org/abs/2004.01177
message
TemporalOffsetEstimation
{
// Weight of the task loss. The total loss of the model will be the
// summation of task losses weighted by the weights.
optional
float
task_loss_weight
=
1
[
default
=
1.0
];
// Localization loss configuration for offset loss.
optional
LocalizationLoss
localization_loss
=
2
;
}
optional
TemporalOffsetEstimation
temporal_offset_task
=
12
;
}
...
...
@@ -263,4 +281,9 @@ message CenterNetFeatureExtractor {
// If set, will change channel order to be [blue, green, red]. This can be
// useful to be compatible with some pre-trained feature extractors.
optional
bool
bgr_ordering
=
4
[
default
=
false
];
// If set, the feature upsampling layers will be constructed with
// separable convolutions. This is typically applied to feature pyramid
// network if any.
optional
bool
use_depthwise
=
5
[
default
=
false
];
}
research/slim/datasets/imagenet.py
View file @
b0ccdb11
...
...
@@ -86,7 +86,7 @@ def create_readable_names_for_imagenet_labels():
"""
# pylint: disable=g-line-too-long
base_url
=
'https://raw.githubusercontent.com/tensorflow/models/master/research/
inception/inception
/data/'
base_url
=
'https://raw.githubusercontent.com/tensorflow/models/master/research/
slim
/data
sets
/'
synset_url
=
'{}/imagenet_lsvrc_2015_synsets.txt'
.
format
(
base_url
)
synset_to_human_url
=
'{}/imagenet_metadata.txt'
.
format
(
base_url
)
...
...
research/slim/nets/mobilenet/mobilenet_example.ipynb
View file @
b0ccdb11
...
...
@@ -116,7 +116,7 @@
"source": [
"from __future__ import print_function\n",
"from IPython import display \n",
"checkpoint_name = 'mobilenet_v2_1.0_224' #@param\n",
"
base_name =
checkpoint_name = 'mobilenet_v2_1.0_224' #@param\n",
"url = 'https://storage.googleapis.com/mobilenet_v2/checkpoints/' + checkpoint_name + '.tgz'\n",
"print('Downloading from ', url)\n",
"!wget {url}\n",
...
...
research/slim/nets/mobilenet_v1.py
View file @
b0ccdb11
...
...
@@ -155,7 +155,7 @@ def _fixed_padding(inputs, kernel_size, rate=1):
input, either intact (if kernel_size == 1) or padded (if kernel_size > 1).
"""
kernel_size_effective
=
[
kernel_size
[
0
]
+
(
kernel_size
[
0
]
-
1
)
*
(
rate
-
1
),
kernel_size
[
0
]
+
(
kernel_size
[
0
]
-
1
)
*
(
rate
-
1
)]
kernel_size
[
1
]
+
(
kernel_size
[
1
]
-
1
)
*
(
rate
-
1
)]
pad_total
=
[
kernel_size_effective
[
0
]
-
1
,
kernel_size_effective
[
1
]
-
1
]
pad_beg
=
[
pad_total
[
0
]
//
2
,
pad_total
[
1
]
//
2
]
pad_end
=
[
pad_total
[
0
]
-
pad_beg
[
0
],
pad_total
[
1
]
-
pad_beg
[
1
]]
...
...
Prev
1
…
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment