Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9d2a7242
Commit
9d2a7242
authored
Sep 10, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Sep 10, 2020
Browse files
Support for hourglass-10,20,32 and 52 and function to compute hourglass depth.
PiperOrigin-RevId: 330992374
parent
9b8b13e8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
123 additions
and
1 deletion
+123
-1
research/object_detection/models/keras_models/hourglass_network.py
...object_detection/models/keras_models/hourglass_network.py
+98
-1
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
...tection/models/keras_models/hourglass_network_tf2_test.py
+25
-0
No files found.
research/object_detection/models/keras_models/hourglass_network.py
View file @
9d2a7242
...
@@ -372,8 +372,76 @@ class HourglassNetwork(tf.keras.Model):
...
@@ -372,8 +372,76 @@ class HourglassNetwork(tf.keras.Model):
return
self
.
num_hourglasses
return
self
.
num_hourglasses
def
_layer_depth
(
layer
):
"""Compute depth of Conv/Residual blocks or lists of them."""
if
isinstance
(
layer
,
list
):
return
sum
([
_layer_depth
(
l
)
for
l
in
layer
])
elif
isinstance
(
layer
,
ConvolutionalBlock
):
return
1
elif
isinstance
(
layer
,
ResidualBlock
):
return
2
else
:
raise
ValueError
(
'Unknown layer - {}'
.
format
(
layer
))
def
_encoder_decoder_depth
(
network
):
"""Helper function to compute depth of encoder-decoder blocks."""
encoder_block2_layers
=
_layer_depth
(
network
.
encoder_block2
)
decoder_block_layers
=
_layer_depth
(
network
.
decoder_block
)
if
isinstance
(
network
.
inner_block
[
0
],
EncoderDecoderBlock
):
assert
len
(
network
.
inner_block
)
==
1
,
'Inner block is expected as length 1.'
inner_block_layers
=
_encoder_decoder_depth
(
network
.
inner_block
[
0
])
return
inner_block_layers
+
encoder_block2_layers
+
decoder_block_layers
elif
isinstance
(
network
.
inner_block
[
0
],
ResidualBlock
):
return
(
encoder_block2_layers
+
decoder_block_layers
+
_layer_depth
(
network
.
inner_block
))
else
:
raise
ValueError
(
'Unknown inner block type.'
)
def
hourglass_depth
(
network
):
"""Helper function to verify depth of hourglass backbone."""
input_conv_layers
=
3
# 1 ResidualBlock and 1 ConvBlock
# Only intermediate_conv2 and intermediate_residual are applied before
# sending inputs to the later stages.
intermediate_layers
=
(
_layer_depth
(
network
.
intermediate_conv2
)
+
_layer_depth
(
network
.
intermediate_residual
)
)
# network.output_conv is applied before sending input to the later stages
output_layers
=
_layer_depth
(
network
.
output_conv
)
encoder_decoder_layers
=
sum
(
_encoder_decoder_depth
(
net
)
for
net
in
network
.
hourglass_network
)
return
(
input_conv_layers
+
encoder_decoder_layers
+
intermediate_layers
+
output_layers
)
def
hourglass_104
():
def
hourglass_104
():
"""The Hourglass-104 backbone."""
"""The Hourglass-104 backbone.
The architecture parameters are taken from [1].
Returns:
network: An HourglassNetwork object implementing the Hourglass-104
backbone.
[1]: https://arxiv.org/abs/1904.07850
"""
return
HourglassNetwork
(
return
HourglassNetwork
(
channel_dims
=
[
128
,
256
,
256
,
384
,
384
,
384
,
512
],
channel_dims
=
[
128
,
256
,
256
,
384
,
384
,
384
,
512
],
...
@@ -381,3 +449,32 @@ def hourglass_104():
...
@@ -381,3 +449,32 @@ def hourglass_104():
num_stages
=
5
,
num_stages
=
5
,
blocks_per_stage
=
[
2
,
2
,
2
,
2
,
2
,
4
],
blocks_per_stage
=
[
2
,
2
,
2
,
2
,
2
,
4
],
)
)
def
single_stage_hourglass
(
blocks_per_stage
,
num_channels
):
nc
=
num_channels
channel_dims
=
[
nc
,
nc
*
2
,
nc
*
2
,
nc
*
3
,
nc
*
3
,
nc
*
3
,
nc
*
4
]
num_stages
=
len
(
blocks_per_stage
)
-
1
channel_dims
=
channel_dims
[:
num_stages
+
2
]
return
HourglassNetwork
(
channel_dims
=
channel_dims
,
num_hourglasses
=
1
,
num_stages
=
num_stages
,
blocks_per_stage
=
blocks_per_stage
,
)
def
hourglass_10
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
],
num_channels
)
def
hourglass_20
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
,
1
,
2
],
num_channels
)
def
hourglass_32
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
,
2
,
2
,
2
],
num_channels
)
def
hourglass_52
(
num_channels
):
return
single_stage_hourglass
([
2
,
2
,
2
,
2
,
2
,
4
],
num_channels
)
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
View file @
9d2a7242
...
@@ -96,5 +96,30 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -96,5 +96,30 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
(
outputs
[
1
].
shape
,
(
2
,
16
,
16
,
6
))
self
.
assertEqual
(
outputs
[
1
].
shape
,
(
2
,
16
,
16
,
6
))
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
HourglassDepthTest
(
tf
.
test
.
TestCase
):
def
test_hourglass_104
(
self
):
net
=
hourglass
.
hourglass_104
()
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
104
)
def
test_hourglass_10
(
self
):
net
=
hourglass
.
hourglass_10
(
2
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
10
)
def
test_hourglass_20
(
self
):
net
=
hourglass
.
hourglass_20
(
2
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
20
)
def
test_hourglass_32
(
self
):
net
=
hourglass
.
hourglass_32
(
2
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
32
)
def
test_hourglass_52
(
self
):
net
=
hourglass
.
hourglass_52
(
2
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
52
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment