Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3bac1426
Commit
3bac1426
authored
Sep 11, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Sep 11, 2020
Browse files
Fixes and tests for hourglass variants.
PiperOrigin-RevId: 331166835
parent
643d492b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
15 deletions
+44
-15
research/object_detection/models/keras_models/hourglass_network.py
...object_detection/models/keras_models/hourglass_network.py
+27
-11
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
...tection/models/keras_models/hourglass_network_tf2_test.py
+17
-4
No files found.
research/object_detection/models/keras_models/hourglass_network.py
View file @
3bac1426
...
...
@@ -226,7 +226,12 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
residual_channels
=
out_channels
for
i
in
range
(
num_blocks
-
1
):
# Only use the stride at the first block so we don't repeatedly downsample
# the input
stride
=
initial_stride
if
i
==
0
else
1
# If the stide is more than 1, we cannot use an identity layer for the
# skip connection and are forced to use a conv for the skip connection.
skip_conv
=
stride
>
1
blocks
.
append
(
...
...
@@ -234,8 +239,18 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
skip_conv
=
skip_conv
)
)
skip_conv
=
residual_channels
!=
out_channels
blocks
.
append
(
ResidualBlock
(
out_channels
=
out_channels
,
skip_conv
=
skip_conv
))
if
num_blocks
==
1
:
# If there is only 1 block, the for loop above is not run,
# therefore we honor the requested stride in the last residual block
stride
=
initial_stride
# We are forced to use a conv in the skip connection if stride > 1
skip_conv
=
stride
>
1
else
:
stride
=
1
skip_conv
=
residual_channels
!=
out_channels
blocks
.
append
(
ResidualBlock
(
out_channels
=
out_channels
,
skip_conv
=
skip_conv
,
stride
=
stride
))
return
blocks
...
...
@@ -494,7 +509,7 @@ def hourglass_104():
)
def
single_stage_hourglass
(
blocks_per_stage
,
num_channels
):
def
single_stage_hourglass
(
blocks_per_stage
,
num_channels
,
downsample
=
True
):
nc
=
num_channels
channel_dims
=
[
nc
,
nc
*
2
,
nc
*
2
,
nc
*
3
,
nc
*
3
,
nc
*
3
,
nc
*
4
]
num_stages
=
len
(
blocks_per_stage
)
-
1
...
...
@@ -504,20 +519,21 @@ def single_stage_hourglass(blocks_per_stage, num_channels):
num_hourglasses
=
1
,
num_stages
=
num_stages
,
blocks_per_stage
=
blocks_per_stage
,
downsample
=
downsample
)
def
hourglass_10
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
],
num_channels
)
def
hourglass_10
(
num_channels
,
downsample
=
True
):
return
single_stage_hourglass
([
1
,
1
],
num_channels
,
downsample
)
def
hourglass_20
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
,
1
,
2
],
num_channels
)
def
hourglass_20
(
num_channels
,
downsample
=
True
):
return
single_stage_hourglass
([
1
,
2
,
2
],
num_channels
,
downsample
)
def
hourglass_32
(
num_channels
):
return
single_stage_hourglass
([
1
,
1
,
2
,
2
,
2
],
num_channels
)
def
hourglass_32
(
num_channels
,
downsample
=
True
):
return
single_stage_hourglass
([
2
,
2
,
2
,
2
],
num_channels
,
downsample
)
def
hourglass_52
(
num_channels
):
return
single_stage_hourglass
([
2
,
2
,
2
,
2
,
2
,
4
],
num_channels
)
def
hourglass_52
(
num_channels
,
downsample
=
True
):
return
single_stage_hourglass
([
2
,
2
,
2
,
2
,
2
,
4
],
num_channels
,
downsample
)
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
View file @
3bac1426
...
...
@@ -111,21 +111,34 @@ class HourglassDepthTest(tf.test.TestCase):
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
104
)
def
test_hourglass_10
(
self
):
net
=
hourglass
.
hourglass_10
(
2
)
net
=
hourglass
.
hourglass_10
(
2
,
downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
10
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_20
(
self
):
net
=
hourglass
.
hourglass_20
(
2
)
net
=
hourglass
.
hourglass_20
(
2
,
downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
20
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_32
(
self
):
net
=
hourglass
.
hourglass_32
(
2
)
net
=
hourglass
.
hourglass_32
(
2
,
downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
32
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
def
test_hourglass_52
(
self
):
net
=
hourglass
.
hourglass_52
(
2
)
net
=
hourglass
.
hourglass_52
(
2
,
downsample
=
False
)
self
.
assertEqual
(
hourglass
.
hourglass_depth
(
net
),
52
)
outputs
=
net
(
tf
.
zeros
((
2
,
32
,
32
,
3
)))
self
.
assertEqual
(
outputs
[
0
].
shape
,
(
2
,
32
,
32
,
4
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment