Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a8518117
Commit
a8518117
authored
Sep 10, 2020
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Sep 10, 2020
Browse files
Make downsampling optional in hourglass.
PiperOrigin-RevId: 331013782
parent
9d2a7242
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
6 deletions
+55
-6
research/object_detection/models/keras_models/hourglass_network.py
...object_detection/models/keras_models/hourglass_network.py
+49
-6
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
...tection/models/keras_models/hourglass_network_tf2_test.py
+6
-0
No files found.
research/object_detection/models/keras_models/hourglass_network.py
View file @
a8518117
...
@@ -174,6 +174,36 @@ class InputDownsampleBlock(tf.keras.layers.Layer):
...
@@ -174,6 +174,36 @@ class InputDownsampleBlock(tf.keras.layers.Layer):
return
self
.
residual_block
(
self
.
conv_block
(
inputs
))
return
self
.
residual_block
(
self
.
conv_block
(
inputs
))
class
InputConvBlock
(
tf
.
keras
.
layers
.
Layer
):
"""Block for the initial feature convolution.
This block is used in the hourglass network when we don't want to downsample
the input.
"""
def
__init__
(
self
,
out_channels_initial_conv
,
out_channels_residual_block
):
"""Initializes the downsample block.
Args:
out_channels_initial_conv: int, the desired number of output channels
in the initial conv layer.
out_channels_residual_block: int, the desired number of output channels
in the underlying residual block.
"""
super
(
InputConvBlock
,
self
).
__init__
()
# TODO(vighneshb) explore if 3x3 works here.
self
.
conv_block
=
ConvolutionalBlock
(
kernel_size
=
7
,
out_channels
=
out_channels_initial_conv
,
stride
=
1
,
padding
=
'valid'
)
self
.
residual_block
=
ResidualBlock
(
out_channels
=
out_channels_residual_block
,
stride
=
1
,
skip_conv
=
True
)
def
call
(
self
,
inputs
):
return
self
.
residual_block
(
self
.
conv_block
(
inputs
))
def
_make_repeated_residual_blocks
(
out_channels
,
num_blocks
,
def
_make_repeated_residual_blocks
(
out_channels
,
num_blocks
,
initial_stride
=
1
,
residual_channels
=
None
):
initial_stride
=
1
,
residual_channels
=
None
):
"""Stack Residual blocks one after the other.
"""Stack Residual blocks one after the other.
...
@@ -285,7 +315,7 @@ class HourglassNetwork(tf.keras.Model):
...
@@ -285,7 +315,7 @@ class HourglassNetwork(tf.keras.Model):
"""The hourglass network."""
"""The hourglass network."""
def
__init__
(
self
,
num_stages
,
channel_dims
,
blocks_per_stage
,
def
__init__
(
self
,
num_stages
,
channel_dims
,
blocks_per_stage
,
num_hourglasses
):
num_hourglasses
,
downsample
=
True
):
"""Intializes the feature extractor.
"""Intializes the feature extractor.
Args:
Args:
...
@@ -300,15 +330,24 @@ class HourglassNetwork(tf.keras.Model):
...
@@ -300,15 +330,24 @@ class HourglassNetwork(tf.keras.Model):
stage in the hourglass network
stage in the hourglass network
num_hourglasses: int, number of hourglas networks to stack
num_hourglasses: int, number of hourglas networks to stack
sequentially.
sequentially.
downsample: bool, if set, downsamples the input by a factor of 4 before
applying the rest of the network.
"""
"""
super
(
HourglassNetwork
,
self
).
__init__
()
super
(
HourglassNetwork
,
self
).
__init__
()
self
.
num_hourglasses
=
num_hourglasses
self
.
num_hourglasses
=
num_hourglasses
self
.
downsample_input
=
InputDownsampleBlock
(
self
.
downsample
=
downsample
out_channels_initial_conv
=
channel_dims
[
0
],
if
downsample
:
out_channels_residual_block
=
channel_dims
[
1
]
self
.
downsample_input
=
InputDownsampleBlock
(
)
out_channels_initial_conv
=
channel_dims
[
0
],
out_channels_residual_block
=
channel_dims
[
1
]
)
else
:
self
.
conv_input
=
InputConvBlock
(
out_channels_initial_conv
=
channel_dims
[
0
],
out_channels_residual_block
=
channel_dims
[
1
]
)
self
.
hourglass_network
=
[]
self
.
hourglass_network
=
[]
self
.
output_conv
=
[]
self
.
output_conv
=
[]
...
@@ -343,7 +382,11 @@ class HourglassNetwork(tf.keras.Model):
...
@@ -343,7 +382,11 @@ class HourglassNetwork(tf.keras.Model):
def
call
(
self
,
inputs
):
def
call
(
self
,
inputs
):
inputs
=
self
.
downsample_input
(
inputs
)
if
self
.
downsample
:
inputs
=
self
.
downsample_input
(
inputs
)
else
:
inputs
=
self
.
conv_input
(
inputs
)
outputs
=
[]
outputs
=
[]
for
i
in
range
(
self
.
num_hourglasses
):
for
i
in
range
(
self
.
num_hourglasses
):
...
...
research/object_detection/models/keras_models/hourglass_network_tf2_test.py
View file @
a8518117
...
@@ -78,6 +78,12 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -78,6 +78,12 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
output
=
layer
(
np
.
zeros
((
2
,
32
,
32
,
8
),
dtype
=
np
.
float32
))
output
=
layer
(
np
.
zeros
((
2
,
32
,
32
,
8
),
dtype
=
np
.
float32
))
self
.
assertEqual
(
output
.
shape
,
(
2
,
8
,
8
,
8
))
self
.
assertEqual
(
output
.
shape
,
(
2
,
8
,
8
,
8
))
def
test_input_conv_block
(
self
):
layer
=
hourglass
.
InputConvBlock
(
out_channels_initial_conv
=
4
,
out_channels_residual_block
=
8
)
output
=
layer
(
np
.
zeros
((
2
,
32
,
32
,
8
),
dtype
=
np
.
float32
))
self
.
assertEqual
(
output
.
shape
,
(
2
,
32
,
32
,
8
))
def
test_encoder_decoder_block
(
self
):
def
test_encoder_decoder_block
(
self
):
layer
=
hourglass
.
EncoderDecoderBlock
(
layer
=
hourglass
.
EncoderDecoderBlock
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment