Commit ee4b011b authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Support for more hourglass network configurations.

PiperOrigin-RevId: 333576935
parent 056859fa
...@@ -30,7 +30,8 @@ class CenterNetHourglassFeatureExtractorTest(test_case.TestCase): ...@@ -30,7 +30,8 @@ class CenterNetHourglassFeatureExtractorTest(test_case.TestCase):
net = hourglass_network.HourglassNetwork( net = hourglass_network.HourglassNetwork(
num_stages=4, blocks_per_stage=[2, 3, 4, 5, 6], num_stages=4, blocks_per_stage=[2, 3, 4, 5, 6],
channel_dims=[4, 6, 8, 10, 12, 14], num_hourglasses=2) input_channel_dims=4, channel_dims_per_stage=[6, 8, 10, 12, 14],
num_hourglasses=2)
model = hourglass.CenterNetHourglassFeatureExtractor(net) model = hourglass.CenterNetHourglassFeatureExtractor(net)
def graph_fn(): def graph_fn():
......
...@@ -193,9 +193,8 @@ class InputConvBlock(tf.keras.layers.Layer): ...@@ -193,9 +193,8 @@ class InputConvBlock(tf.keras.layers.Layer):
super(InputConvBlock, self).__init__() super(InputConvBlock, self).__init__()
# TODO(vighneshb) explore if 3x3 works here.
self.conv_block = ConvolutionalBlock( self.conv_block = ConvolutionalBlock(
kernel_size=7, out_channels=out_channels_initial_conv, stride=1, kernel_size=3, out_channels=out_channels_initial_conv, stride=1,
padding='valid') padding='valid')
self.residual_block = ResidualBlock( self.residual_block = ResidualBlock(
out_channels=out_channels_residual_block, stride=1, skip_conv=True) out_channels=out_channels_residual_block, stride=1, skip_conv=True)
...@@ -205,7 +204,8 @@ class InputConvBlock(tf.keras.layers.Layer): ...@@ -205,7 +204,8 @@ class InputConvBlock(tf.keras.layers.Layer):
def _make_repeated_residual_blocks(out_channels, num_blocks, def _make_repeated_residual_blocks(out_channels, num_blocks,
initial_stride=1, residual_channels=None): initial_stride=1, residual_channels=None,
initial_skip_conv=False):
"""Stack Residual blocks one after the other. """Stack Residual blocks one after the other.
Args: Args:
...@@ -214,6 +214,9 @@ def _make_repeated_residual_blocks(out_channels, num_blocks, ...@@ -214,6 +214,9 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
initial_stride: int, the stride of the initial residual block. initial_stride: int, the stride of the initial residual block.
residual_channels: int, the desired number of output channels in the residual_channels: int, the desired number of output channels in the
intermediate residual blocks. If not specified, we use out_channels. intermediate residual blocks. If not specified, we use out_channels.
initial_skip_conv: bool, if set, the first residual block uses a skip
convolution. This is useful when the number of channels in the input
is not the same as residual_channels.
Returns: Returns:
blocks: A list of residual blocks to be applied in sequence. blocks: A list of residual blocks to be applied in sequence.
...@@ -234,6 +237,9 @@ def _make_repeated_residual_blocks(out_channels, num_blocks, ...@@ -234,6 +237,9 @@ def _make_repeated_residual_blocks(out_channels, num_blocks,
# skip connection and are forced to use a conv for the skip connection. # skip connection and are forced to use a conv for the skip connection.
skip_conv = stride > 1 skip_conv = stride > 1
if i == 0 and initial_skip_conv:
skip_conv = True
blocks.append( blocks.append(
ResidualBlock(out_channels=residual_channels, stride=stride, ResidualBlock(out_channels=residual_channels, stride=stride,
skip_conv=skip_conv) skip_conv=skip_conv)
...@@ -267,7 +273,8 @@ def _apply_blocks(inputs, blocks): ...@@ -267,7 +273,8 @@ def _apply_blocks(inputs, blocks):
class EncoderDecoderBlock(tf.keras.layers.Layer): class EncoderDecoderBlock(tf.keras.layers.Layer):
"""An encoder-decoder block which recursively defines the hourglass network.""" """An encoder-decoder block which recursively defines the hourglass network."""
def __init__(self, num_stages, channel_dims, blocks_per_stage): def __init__(self, num_stages, channel_dims, blocks_per_stage,
stagewise_downsample=True, encoder_decoder_shortcut=True):
"""Initializes the encoder-decoder block. """Initializes the encoder-decoder block.
Args: Args:
...@@ -282,6 +289,10 @@ class EncoderDecoderBlock(tf.keras.layers.Layer): ...@@ -282,6 +289,10 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
blocks_per_stage: int list, number of residual blocks to use at each blocks_per_stage: int list, number of residual blocks to use at each
stage. `blocks_per_stage[0]` defines the number of blocks at the stage. `blocks_per_stage[0]` defines the number of blocks at the
current stage and `blocks_per_stage[1:]` is used at further stages. current stage and `blocks_per_stage[1:]` is used at further stages.
stagewise_downsample: bool, whether or not to downsample before passing
inputs to the next stage.
encoder_decoder_shortcut: bool, whether or not to use shortcut
connections between encoder and decoder.
""" """
super(EncoderDecoderBlock, self).__init__() super(EncoderDecoderBlock, self).__init__()
...@@ -289,17 +300,26 @@ class EncoderDecoderBlock(tf.keras.layers.Layer): ...@@ -289,17 +300,26 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
out_channels = channel_dims[0] out_channels = channel_dims[0]
out_channels_downsampled = channel_dims[1] out_channels_downsampled = channel_dims[1]
self.encoder_block1 = _make_repeated_residual_blocks( self.encoder_decoder_shortcut = encoder_decoder_shortcut
out_channels=out_channels, num_blocks=blocks_per_stage[0],
initial_stride=1) if encoder_decoder_shortcut:
self.merge_features = tf.keras.layers.Add()
self.encoder_block1 = _make_repeated_residual_blocks(
out_channels=out_channels, num_blocks=blocks_per_stage[0],
initial_stride=1)
initial_stride = 2 if stagewise_downsample else 1
self.encoder_block2 = _make_repeated_residual_blocks( self.encoder_block2 = _make_repeated_residual_blocks(
out_channels=out_channels_downsampled, out_channels=out_channels_downsampled,
num_blocks=blocks_per_stage[0], initial_stride=2) num_blocks=blocks_per_stage[0], initial_stride=initial_stride,
initial_skip_conv=out_channels != out_channels_downsampled)
if num_stages > 1: if num_stages > 1:
self.inner_block = [ self.inner_block = [
EncoderDecoderBlock(num_stages - 1, channel_dims[1:], EncoderDecoderBlock(num_stages - 1, channel_dims[1:],
blocks_per_stage[1:]) blocks_per_stage[1:],
stagewise_downsample=stagewise_downsample,
encoder_decoder_shortcut=encoder_decoder_shortcut)
] ]
else: else:
self.inner_block = _make_repeated_residual_blocks( self.inner_block = _make_repeated_residual_blocks(
...@@ -309,13 +329,13 @@ class EncoderDecoderBlock(tf.keras.layers.Layer): ...@@ -309,13 +329,13 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
self.decoder_block = _make_repeated_residual_blocks( self.decoder_block = _make_repeated_residual_blocks(
residual_channels=out_channels_downsampled, residual_channels=out_channels_downsampled,
out_channels=out_channels, num_blocks=blocks_per_stage[0]) out_channels=out_channels, num_blocks=blocks_per_stage[0])
self.upsample = tf.keras.layers.UpSampling2D(2)
self.merge_features = tf.keras.layers.Add() self.upsample = tf.keras.layers.UpSampling2D(initial_stride)
def call(self, inputs): def call(self, inputs):
encoded_outputs = _apply_blocks(inputs, self.encoder_block1) if self.encoder_decoder_shortcut:
encoded_outputs = _apply_blocks(inputs, self.encoder_block1)
encoded_downsampled_outputs = _apply_blocks(inputs, self.encoder_block2) encoded_downsampled_outputs = _apply_blocks(inputs, self.encoder_block2)
inner_block_outputs = _apply_blocks( inner_block_outputs = _apply_blocks(
encoded_downsampled_outputs, self.inner_block) encoded_downsampled_outputs, self.inner_block)
...@@ -323,45 +343,53 @@ class EncoderDecoderBlock(tf.keras.layers.Layer): ...@@ -323,45 +343,53 @@ class EncoderDecoderBlock(tf.keras.layers.Layer):
decoded_outputs = _apply_blocks(inner_block_outputs, self.decoder_block) decoded_outputs = _apply_blocks(inner_block_outputs, self.decoder_block)
upsampled_outputs = self.upsample(decoded_outputs) upsampled_outputs = self.upsample(decoded_outputs)
return self.merge_features([encoded_outputs, upsampled_outputs]) if self.encoder_decoder_shortcut:
return self.merge_features([encoded_outputs, upsampled_outputs])
else:
return upsampled_outputs
class HourglassNetwork(tf.keras.Model): class HourglassNetwork(tf.keras.Model):
"""The hourglass network.""" """The hourglass network."""
def __init__(self, num_stages, channel_dims, blocks_per_stage, def __init__(self, num_stages, input_channel_dims, channel_dims_per_stage,
num_hourglasses, downsample=True): blocks_per_stage, num_hourglasses, initial_downsample=True,
stagewise_downsample=True, encoder_decoder_shortcut=True):
"""Initializes the feature extractor. """Initializes the feature extractor.
Args: Args:
num_stages: int, Number of stages in the network. At each stage we have 2 num_stages: int, Number of stages in the network. At each stage we have 2
encoder and 1 decoder blocks. The second encoder block downsamples the encoder and 1 decoder blocks. The second encoder block downsamples the
input. input.
channel_dims: int list, the output channel dimensions of stages in input_channel_dims: int, the number of channels in the input conv blocks.
the network. `channel_dims[0]` and `channel_dims[1]` are used to define channel_dims_per_stage: int list, the output channel dimensions of each
the initial downsampling block. `channel_dims[1:]` is used to define stage in the hourglass network.
the hourglass network(s) which follow(s).
blocks_per_stage: int list, number of residual blocks to use at each blocks_per_stage: int list, number of residual blocks to use at each
stage in the hourglass network stage in the hourglass network
num_hourglasses: int, number of hourglass networks to stack num_hourglasses: int, number of hourglass networks to stack
sequentially. sequentially.
downsample: bool, if set, downsamples the input by a factor of 4 before initial_downsample: bool, if set, downsamples the input by a factor of 4
applying the rest of the network. before applying the rest of the network. Downsampling is done with a 7x7
convolution kernel, otherwise a 3x3 kernel is used.
stagewise_downsample: bool, whether or not to downsample before passing
inputs to the next stage.
encoder_decoder_shortcut: bool, whether or not to use shortcut
connections between encoder and decoder.
""" """
super(HourglassNetwork, self).__init__() super(HourglassNetwork, self).__init__()
self.num_hourglasses = num_hourglasses self.num_hourglasses = num_hourglasses
self.downsample = downsample self.initial_downsample = initial_downsample
if downsample: if initial_downsample:
self.downsample_input = InputDownsampleBlock( self.downsample_input = InputDownsampleBlock(
out_channels_initial_conv=channel_dims[0], out_channels_initial_conv=input_channel_dims,
out_channels_residual_block=channel_dims[1] out_channels_residual_block=channel_dims_per_stage[0]
) )
else: else:
self.conv_input = InputConvBlock( self.conv_input = InputConvBlock(
out_channels_initial_conv=channel_dims[0], out_channels_initial_conv=input_channel_dims,
out_channels_residual_block=channel_dims[1] out_channels_residual_block=channel_dims_per_stage[0]
) )
self.hourglass_network = [] self.hourglass_network = []
...@@ -369,11 +397,14 @@ class HourglassNetwork(tf.keras.Model): ...@@ -369,11 +397,14 @@ class HourglassNetwork(tf.keras.Model):
for _ in range(self.num_hourglasses): for _ in range(self.num_hourglasses):
self.hourglass_network.append( self.hourglass_network.append(
EncoderDecoderBlock( EncoderDecoderBlock(
num_stages=num_stages, channel_dims=channel_dims[1:], num_stages=num_stages, channel_dims=channel_dims_per_stage,
blocks_per_stage=blocks_per_stage) blocks_per_stage=blocks_per_stage,
stagewise_downsample=stagewise_downsample,
encoder_decoder_shortcut=encoder_decoder_shortcut)
) )
self.output_conv.append( self.output_conv.append(
ConvolutionalBlock(kernel_size=3, out_channels=channel_dims[1]) ConvolutionalBlock(kernel_size=3,
out_channels=channel_dims_per_stage[0])
) )
self.intermediate_conv1 = [] self.intermediate_conv1 = []
...@@ -383,21 +414,21 @@ class HourglassNetwork(tf.keras.Model): ...@@ -383,21 +414,21 @@ class HourglassNetwork(tf.keras.Model):
for _ in range(self.num_hourglasses - 1): for _ in range(self.num_hourglasses - 1):
self.intermediate_conv1.append( self.intermediate_conv1.append(
ConvolutionalBlock( ConvolutionalBlock(
kernel_size=1, out_channels=channel_dims[1], relu=False) kernel_size=1, out_channels=channel_dims_per_stage[0], relu=False)
) )
self.intermediate_conv2.append( self.intermediate_conv2.append(
ConvolutionalBlock( ConvolutionalBlock(
kernel_size=1, out_channels=channel_dims[1], relu=False) kernel_size=1, out_channels=channel_dims_per_stage[0], relu=False)
) )
self.intermediate_residual.append( self.intermediate_residual.append(
ResidualBlock(out_channels=channel_dims[1]) ResidualBlock(out_channels=channel_dims_per_stage[0])
) )
self.intermediate_relu = tf.keras.layers.ReLU() self.intermediate_relu = tf.keras.layers.ReLU()
def call(self, inputs): def call(self, inputs):
if self.downsample: if self.initial_downsample:
inputs = self.downsample_input(inputs) inputs = self.downsample_input(inputs)
else: else:
inputs = self.conv_input(inputs) inputs = self.conv_input(inputs)
...@@ -502,38 +533,92 @@ def hourglass_104(): ...@@ -502,38 +533,92 @@ def hourglass_104():
""" """
return HourglassNetwork( return HourglassNetwork(
channel_dims=[128, 256, 256, 384, 384, 384, 512], input_channel_dims=128,
channel_dims_per_stage=[256, 256, 384, 384, 384, 512],
num_hourglasses=2, num_hourglasses=2,
num_stages=5, num_stages=5,
blocks_per_stage=[2, 2, 2, 2, 2, 4], blocks_per_stage=[2, 2, 2, 2, 2, 4],
) )
def single_stage_hourglass(blocks_per_stage, num_channels, downsample=True): def single_stage_hourglass(input_channel_dims, channel_dims_per_stage,
nc = num_channels blocks_per_stage, initial_downsample=True,
channel_dims = [nc, nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc * 4] stagewise_downsample=True,
num_stages = len(blocks_per_stage) - 1 encoder_decoder_shortcut=True):
channel_dims = channel_dims[:num_stages + 2] assert len(channel_dims_per_stage) == len(blocks_per_stage)
return HourglassNetwork( return HourglassNetwork(
channel_dims=channel_dims, input_channel_dims=input_channel_dims,
channel_dims_per_stage=channel_dims_per_stage,
num_hourglasses=1, num_hourglasses=1,
num_stages=num_stages, num_stages=len(channel_dims_per_stage) - 1,
blocks_per_stage=blocks_per_stage, blocks_per_stage=blocks_per_stage,
downsample=downsample initial_downsample=initial_downsample,
stagewise_downsample=stagewise_downsample,
encoder_decoder_shortcut=encoder_decoder_shortcut
) )
def hourglass_10(num_channels, downsample=True): def hourglass_10(num_channels, initial_downsample=True):
return single_stage_hourglass([1, 1], num_channels, downsample) nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[1, 1],
channel_dims_per_stage=[nc * 2, nc * 2])
def hourglass_20(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[1, 2, 2],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3])
def hourglass_32(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[2, 2, 2, 2],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3])
def hourglass_20(num_channels, downsample=True):
return single_stage_hourglass([1, 2, 2], num_channels, downsample)
def hourglass_52(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[2, 2, 2, 2, 2, 4],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc*4])
def hourglass_32(num_channels, downsample=True):
return single_stage_hourglass([2, 2, 2, 2], num_channels, downsample)
def hourglass_100(num_channels, initial_downsample=True):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
initial_downsample=initial_downsample,
blocks_per_stage=[4, 4, 4, 4, 4, 8],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3, nc * 3, nc * 3, nc*4])
def hourglass_52(num_channels, downsample=True):
return single_stage_hourglass([2, 2, 2, 2, 2, 4], num_channels, downsample) def hourglass_20_uniform_size(num_channels):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
blocks_per_stage=[1, 2, 2],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3],
initial_downsample=False,
stagewise_downsample=False)
def hourglass_20_no_shortcut(num_channels):
nc = num_channels
return single_stage_hourglass(
input_channel_dims=nc,
blocks_per_stage=[1, 2, 2],
channel_dims_per_stage=[nc * 2, nc * 2, nc * 3],
initial_downsample=False,
encoder_decoder_shortcut=False)
...@@ -95,8 +95,8 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase): ...@@ -95,8 +95,8 @@ class HourglassFeatureExtractorTest(tf.test.TestCase, parameterized.TestCase):
def test_hourglass_feature_extractor(self): def test_hourglass_feature_extractor(self):
model = hourglass.HourglassNetwork( model = hourglass.HourglassNetwork(
num_stages=4, blocks_per_stage=[2, 3, 4, 5, 6], num_stages=4, blocks_per_stage=[2, 3, 4, 5, 6], input_channel_dims=4,
channel_dims=[4, 6, 8, 10, 12, 14], num_hourglasses=2) channel_dims_per_stage=[6, 8, 10, 12, 14], num_hourglasses=2)
outputs = model(np.zeros((2, 64, 64, 3), dtype=np.float32)) outputs = model(np.zeros((2, 64, 64, 3), dtype=np.float32))
self.assertEqual(outputs[0].shape, (2, 16, 16, 6)) self.assertEqual(outputs[0].shape, (2, 16, 16, 6))
self.assertEqual(outputs[1].shape, (2, 16, 16, 6)) self.assertEqual(outputs[1].shape, (2, 16, 16, 6))
...@@ -111,33 +111,47 @@ class HourglassDepthTest(tf.test.TestCase): ...@@ -111,33 +111,47 @@ class HourglassDepthTest(tf.test.TestCase):
self.assertEqual(hourglass.hourglass_depth(net), 104) self.assertEqual(hourglass.hourglass_depth(net), 104)
def test_hourglass_10(self): def test_hourglass_10(self):
net = hourglass.hourglass_10(2, downsample=False) net = hourglass.hourglass_10(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 10) self.assertEqual(hourglass.hourglass_depth(net), 10)
outputs = net(tf.zeros((2, 32, 32, 3))) outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4)) self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_20(self): def test_hourglass_20(self):
net = hourglass.hourglass_20(2, downsample=False) net = hourglass.hourglass_20(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 20) self.assertEqual(hourglass.hourglass_depth(net), 20)
outputs = net(tf.zeros((2, 32, 32, 3))) outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4)) self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_32(self): def test_hourglass_32(self):
net = hourglass.hourglass_32(2, downsample=False) net = hourglass.hourglass_32(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 32) self.assertEqual(hourglass.hourglass_depth(net), 32)
outputs = net(tf.zeros((2, 32, 32, 3))) outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4)) self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_52(self): def test_hourglass_52(self):
net = hourglass.hourglass_52(2, downsample=False) net = hourglass.hourglass_52(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 52) self.assertEqual(hourglass.hourglass_depth(net), 52)
outputs = net(tf.zeros((2, 32, 32, 3))) outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4)) self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_20_uniform_size(self):
net = hourglass.hourglass_20_uniform_size(2)
self.assertEqual(hourglass.hourglass_depth(net), 20)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
def test_hourglass_100(self):
net = hourglass.hourglass_100(2, initial_downsample=False)
self.assertEqual(hourglass.hourglass_depth(net), 100)
outputs = net(tf.zeros((2, 32, 32, 3)))
self.assertEqual(outputs[0].shape, (2, 32, 32, 4))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment