Commit ee4b011b authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Support for more hourglass network configurations.

PiperOrigin-RevId: 333576935
parent 056859fa
......@@ -30,7 +30,8 @@ class CenterNetHourglassFeatureExtractorTest(test_case.TestCase):
net = hourglass_network.HourglassNetwork(
num_stages=4, blocks_per_stage=[2, 3, 4, 5, 6],
channel_dims=[4, 6, 8, 10, 12, 14], num_hourglasses=2)
input_channel_dims=4, channel_dims_per_stage=[6, 8, 10, 12, 14],
num_hourglasses=2)
model = hourglass.CenterNetHourglassFeatureExtractor(net)
def graph_fn():
......
......@@ -193,9 +193,8 @@ class InputConvBlock(tf.keras.layers.Layer):
super(InputConvBlock, self).__init__()
# TODO(vighneshb) explore if 3x3 works here.
self.conv_block = ConvolutionalBlock(
kernel_size=7, out_channels=out_channels_initial_conv, stride=1,
kernel_size=3, out_channels=out_channels_initial_conv, stride=1,
padding='valid')
self.residual_block = ResidualBlock(
out_channels=out_channels_residual_block, stride=1, skip_conv=True)
......@@ -205,7 +204,8 @@ class InputConvBlock(tf.keras.layers.Layer):
def _make_repeated_residual_blocks(out_channels, num_blocks,
initial_stride=1, residual_channels=None):
initial_stride=1, residual_channels=None,
initial_skip_conv=False):
"""Stack Residual blocks one after the other.
Args:
......@@ -214,6 +214,9 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
initial_stride: int, the stride of the initial residual block.
residual_channels: int, the desired number of output channels in the
intermediate residual blocks. If not specifed, we use out_channels.
initial_skip_conv: bool, if set, the first residual block uses a skip
convolution. This is useful when the number of channels in the input
are not the same as residual_channels.
Returns:
blocks: A list of residual blocks to be applied in sequence.
......@@ -234,6 +237,9 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
# skip connection and are forced to use a conv for the skip connection.
skip_conv = stride > 1
if i == 0 and initial_skip_conv:
skip_conv = True
blocks.append(
ResidualBlock(out_channels=residual_channels, stride=stride,
skip_conv=skip_conv)
......@@ -267,7 +273,8 @@ def _apply_blocks(inputs, blocks):
class EncoderDecoderBlock(tf.keras.layers.Layer):
"""An encoder-decoder block which recursively defines the hourglass network."""
def __init__(self, num_stages, channel_dims, blocks_per_stage):
def __init__(self, num_stages, channel_dims, blocks_per_stage,
stagewise_downsample=True, encoder_decoder_shortcut=True):
"""Initializes the encoder-decoder block.
Args:
......@@ -282,6 +289,10 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
blocks_per_stage: int list, number of residual blocks to use at each
stage. `blocks_per_stage[0]` defines the number of blocks at the
current stage and `blocks_per_stage[1:]` is used at further stages.
stagewise_downsample: bool, whether or not to downsample before passing
inputs to the next stage.
encoder_decoder_shortcut: bool, whether or not to use shortcut
connections between encoder and decoder.
"""
super(EncoderDecoderBlock, self).__init__()
......@@ -289,17 +300,26 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
out_channels = channel_dims[0]
out_channels_downsampled = channel_dims[1]
self.encoder_block1 = _make_repeated_residual_blocks(
out_channels=out_channels, num_blocks=blocks_per_stage[0],
initial_stride=1)
self.encoder_decoder_shortcut = encoder_decoder_shortcut
if encoder_decoder_shortcut:
self.merge_features = tf.keras.layers.Add()
self.encoder_block1 = _make_repeated_residual_blocks(
out_channels=out_channels, num_blocks=blocks_per_stage[0],
initial_stride=1)
initial_stride = 2 if stagewise_downsample else 1
self.encoder_block2 = _make_repeated_residual_blocks(
out_channels=out_channels_downsampled,
num_blocks=blocks_per_stage[0], initial_stride=2)
num_blocks=blocks_per_stage[0], initial_stride=initial_stride,
initial_skip_conv=out_channels != out_channels_downsampled)
if num_stages > 1:
self.inner_block = [
EncoderDecoderBlock(num_stages - 1, channel_dims[1:],
blocks_per_stage[1:])
blocks_per_stage[1:],
stagewise_downsample=stagewise_downsample,
encoder_decoder_shortcut=encoder_decoder_shortcut)
]
else:
self.inner_block = _make_repeated_residual_blocks(
......@@ -309,13 +329,13 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
self.decoder_block = _make_repeated_residual_blocks(
residual_channels=out_channels_downsampled,
out_channels=out_channels, num_blocks=blocks_per_stage[0])
self.upsample = tf.keras.layers.UpSampling2D(2)
self.merge_features = tf.keras.layers.Add()
self.upsample = tf.keras.layers.UpSampling2D(initial_stride)
def call(self, inputs):
encoded_outputs = _apply_blocks(inputs, self.encoder_block1)
if self.encoder_decoder_shortcut:
encoded_outputs = _apply_blocks(inputs, self.encoder_block1)
encoded_downsampled_outputs = _apply_blocks(inputs, self.encoder_block2)
inner_block_outputs = _apply_blocks(
encoded_downsampled_outputs, self.inner_block)
......@@ -323,45 +343,53 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
decoded_outputs = _apply_blocks(inner_block_outputs, self.decoder_block)
upsampled_outputs = self.upsample(decoded_outputs)
return self.merge_features([encoded_outputs, upsampled_outputs])
if self.encoder_decoder_shortcut:
return self.merge_features([encoded_outputs, upsampled_outputs])
else:
return upsampled_outputs
class HourglassNetwork(tf.keras.Model):
"""The hourglass network."""
def __init__(self, num_stages, channel_dims, blocks_per_stage,
num_hourglasses, downsample=True):
def __init__(self, num_stages, input_channel_dims, channel_dims_per_stage,
blocks_per_stage, num_hourglasses, initial_downsample=True,
stagewise_downsample=True, encoder_decoder_shortcut=True):
"""Intializes the feature extractor.
Args:
num_stages: int, Number of stages in the network. At each stage we have 2
encoder and 1 decoder blocks. The second encoder block downsamples the
input.
channel_dims: int list, the output channel dimensions of stages in
the network. `channel_dims[0]` and `channel_dims[1]` are used to define
the initial downsampling block. `channel_dims[1:]` is used to define
the hourglass network(s) which follow(s).
input_channel_dims: int, the number of channels in the input conv blocks.
channel_dims_per_stage: int list, the output channel dimensions of each
stage in the hourglass network.
blocks_per_stage: int list, number of residual blocks to use at each
stage in the hourglass network
num_hourglasses: int, number of hourglas networks to stack
sequentially.
downsample: bool, if set, downsamples the input by a factor of 4 before
applying the rest of the network.
initial_downsample: bool, if set, downsamples the input by a factor of 4
before applying the rest of the network. Downsampling is done with a 7x7
convolution kernel, otherwise a 3x3 kernel is used.
stagewise_downsample: bool, whether or not to downsample before passing
inputs to the next stage.
encoder_decoder_shortcut: bool, whether or not to use shortcut
connections between encoder and decoder.
"""
super(HourglassNetwork, self).__init__()
self.num_hourglasses = num_hourglasses
self.downsample = downsample
if downsample:
self.initial_downsample = initial_downsample
if initial_downsample:
self.downsample_input = InputDownsampleBlock(
out_channels_initial_conv=channel_dims[0],
out_channels_residual_block=channel_dims[1]
out_channels_initial_conv=input_channel_dims,
out_channels_residual_block=channel_dims_per_stage[0]
)
else:
self.conv_input = InputConvBlock(
out_channels_initial_conv=channel_dims[0],
out_channels_residual_block=channel_dims[1]
out_channels_initial_conv=input_channel_dims,
out_channels_residual_block=channel_dims_per_stage[0]
)
self.hourglass_network = []
......@@ -369,11 +397,14 @@ class HourglassNetwork(tf.keras.Model):
for _ in range(self.num_hourglasses):
self.hourglass_network.append(
EncoderDecoderBlock(
num_stages=num_stages, channel_dims=channel_dims[1:],
blocks_per_stage=blocks_per_stage)
num_stages=num_stages, channel_dims=channel_dims_per_stage,
blocks_per_stage=blocks_per_stage,
stagewise_downsample=stagewise_downsample,
encoder_decoder_shortcut=encoder_decoder_shortcut)
)
self.output_conv.append(
ConvolutionalBlock(kernel_size=3, out_channels=channel_dims[1])
ConvolutionalBlock(kernel_size=3,
out_channels=channel_dims_per_stage[0])
)
self.intermediate_conv1 = []
......@@ -383,21 +414,21 @@ class HourglassNetwork(tf.keras.Model):
for _ in range(self.num_hourglasses - 1):
self.intermediate_conv1.append(
ConvolutionalBlock(
kernel_size=1, out_channels=channel_dims[1], relu=False)
kernel_size=1, out_channels=channel_dims_per_stage[0], relu=False)
)
self.intermediate_conv2.append(
ConvolutionalBlock(
kernel_size=1, out_channels=channel_dims[1], relu=False)
kernel_size=1, out_channels=channel_dims_per_stage[0], relu=False)
)
self.intermediate_residual.append(
ResidualBlock(out_channels=channel_dims[1])
ResidualBlock(out_channels=channel_dims_per_stage[0])
)
self.intermediate_relu = tf.keras.layers.ReLU()
def call(self, inputs):
if self.downsample:
if self.initial_downsample:
inputs = self.downsample_input(inputs)
else:
inputs = self.conv_input(inputs)
......@@ -502,38 +533,92 @@ def hourglass_104():
"""
return HourglassNetwork(
channel_dims=[128, 256, 256, 384, 384, 384, 512],
input_channel_dims=128,
channel_dims_per_stage=[256, 256, 384, 384, 384, 512],
num_hourglasses=2,
num_stages=5,
blocks_per_stage=[2, 2, 2, 2, 2, 4],
)
def single_stage_hourglass(blocks_per_stage, num_channels, downsample=True):
nc = num_channels
channel_dims = [nc, nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc * 4]
num_stages = len(blocks_per_stage) - 1
channel_dims = channel_dims[:num_stages + 2]
def single_stage_hourglass(input_channel_dims, channel_dims_per_stage,
blocks_per_stage, initial_downsample=True,
stagewise_downsample=True,
encoder_decoder_shortcut=True):
assert len(channel_dims_per_stage) == len(blocks_per_stage)
return HourglassNetwork(
channel_dims=channel_dims,
input_channel_dims=input_channel_dims,
channel_dims_per_stage=channel_dims_per_stage,
num_hourglasses=1,
num_stages=num_stages,
num_stages=len(channel_dims_per_stage) - 1,
blocks_per_stage=blocks_per_stage,
downsample=downsample
initial_downsample=initial_downsample,
stagewise_downsample=stagewise_downsample,
encoder_decoder_shortcut=encoder_decoder_shortcut
)
def hourglass_10(num_channels, downsample=True):
return single_stage_hourglass([1, 1], num_channels, downsample)
def hourglass_10(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[1, 1],
channel_dims_per_stage=[nc * 2, nc * 2])
def hourglass_20(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[1, 2, 2],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3])
def hourglass_32(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[2, 2, 2, 2],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3])
def hourglass_20(num_channels, downsample=True):
return single_stage_hourglass([1, 2, 2], num_channels, downsample)
def hourglass_52(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[2, 2, 2, 2, 2, 4],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc*4])
def hourglass_32(num_channels, downsample=True):
return single_stage_hourglass([2, 2, 2, 2], num_channels, downsample)
def hourglass_100(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[4, 4, 4, 4, 4, 8],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc*4])
def hourglass_52(num_channels, downsample=True):
return single_stage_hourglass([2, 2, 2, 2, 2, 4], num_channels, downsample)
def hourglass_20_uniform_size(num_channels):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
blocks_per_stage=[1, 2, 2],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3],
initial_downsample=False,
stagewise_downsample=False)
def hourglass_20_no_shortcut(num_channels):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
blocks_per_stage=[1, 2, 2],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3],
initial_downsample=False,
encoder_decoder_shortcut=False)
......@@ -95,8 +95,8 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
def test_hourglass_feature_extractor(self):
model = hourglass.HourglassNetwork(
num_stages=4, blocks_per_stage=[2, 3, 4, 5, 6],
channel_dims=[4, 6, 8, 10, 12, 14], num_hourglasses=2)
num_stages=4, blocks_per_stage=[2, 3, 4, 5, 6], input_channel_dims=4,
channel_dims_per_stage=[6, 8, 10, 12, 14], num_hourglasses=2)
outputs = model(np.zeros((2, 64, 64, 3), dtype=np.float32))
self.assertEqual(outputs[0].shape, (2, 16, 16, 6))
self.assertEqual(outputs[1].shape, (2, 16, 16, 6))
......@@ -111,33 +111,47 @@ class HourglassDepthTest(tf.test.TestCase):
self.assertEqual(hourglass.hourglass_depth(net), 104)
def test_hourglass_10(self):
net = hourglass.hourglass_10(2, downsample=False)
net = hourglass.hourglass_10(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 10)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_20(self):
net = hourglass.hourglass_20(2, downsample=False)
net = hourglass.hourglass_20(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 20)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_32(self):
net = hourglass.hourglass_32(2, downsample=False)
net = hourglass.hourglass_32(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 32)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_52(self):
net = hourglass.hourglass_52(2, downsample=False)
net = hourglass.hourglass_52(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 52)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_20_uniform_size(self):
net = hourglass.hourglass_20_uniform_size(2)
self.assertEqual(hourglass.hourglass_depth(net), 20)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_100(self):
net = hourglass.hourglass_100(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 100)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
if __name__ == '__main__':
tf.test.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment