Commit aabe1231 authored by Kaushik Shivakumar

Context RCNN support for Tensorflow 2.x

parent e3f88e11
@@ -23,13 +23,13 @@ _NEGATIVE_PADDING_VALUE = -100000
 class ContextProjection(tf.keras.layers.Layer):
   """Custom layer to do batch normalization and projection."""
 
-  def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
+  def __init__(self, projection_dimension, **kwargs):
     self.batch_norm = freezable_batch_norm.FreezableBatchNorm(
         epsilon=0.001,
         center=True,
         scale=True,
         momentum=0.97,
-        trainable=(not freeze_batchnorm))
+        trainable=True)
     self.projection = tf.keras.layers.Dense(units=projection_dimension,
                                             activation=tf.nn.relu6,
                                             use_bias=True)
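Note: the `trainable=(not freeze_batchnorm)` constructor argument is dropped in favor of `trainable=True` because, in TF2 Keras, a batch-norm layer can be frozen after construction. A standalone sketch of that behavior, using the stock Keras layer rather than `FreezableBatchNorm` itself:

```python
import numpy as np
import tensorflow as tf

# Stock Keras batch norm, configured like the FreezableBatchNorm above.
bn = tf.keras.layers.BatchNormalization(momentum=0.97, epsilon=0.001)
x = tf.constant(np.random.rand(4, 8), tf.float32)

_ = bn(x, training=True)         # build and run once in training mode
bn.trainable = False             # freeze after construction
_ = bn(x, training=True)         # Keras now forces inference mode regardless
assert not bn.trainable_weights  # gamma/beta are no longer trainable
```

Since TF 2.0, setting `trainable = False` on a `BatchNormalization` layer both removes gamma/beta from the trainable weights and makes the layer normalize with its moving statistics, which is what the `freeze_batchnorm` flag used to achieve at construction time.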
@@ -45,33 +45,29 @@ class ContextProjection(tf.keras.layers.Layer):
 class AttentionBlock(tf.keras.layers.Layer):
   """Custom layer to perform all attention."""
 
   def __init__(self, bottleneck_dimension, attention_temperature,
-               freeze_batchnorm, output_dimension=None,
+               output_dimension=None,
                is_training=False, name='AttentionBlock', **kwargs):
-    self._key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
-    self._val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
-    self._query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
+    self._key_proj = ContextProjection(bottleneck_dimension)
+    self._val_proj = ContextProjection(bottleneck_dimension)
+    self._query_proj = ContextProjection(bottleneck_dimension)
     self._feature_proj = None
     self._attention_temperature = attention_temperature
-    self._freeze_batchnorm = freeze_batchnorm
     self._bottleneck_dimension = bottleneck_dimension
     self._is_training = is_training
     self._output_dimension = output_dimension
     if self._output_dimension:
-      self._feature_proj = ContextProjection(self._output_dimension,
-                                             self._freeze_batchnorm)
+      self._feature_proj = ContextProjection(self._output_dimension)
     super(AttentionBlock, self).__init__(name=name, **kwargs)
 
   def build(self, input_shapes):
     if not self._feature_proj:
       self._output_dimension = input_shapes[-1]
-      self._feature_proj = ContextProjection(self._output_dimension,
-                                             self._freeze_batchnorm)
+      self._feature_proj = ContextProjection(self._output_dimension)
 
   def call(self, box_features, context_features, valid_context_size):
     """Handles a call by performing attention."""
     _, context_size, _ = context_features.shape
     valid_mask = compute_valid_mask(valid_context_size, context_size)
-    channels = box_features.shape[-1]
     # Average pools over height and width dimension so that the shape of
     # box_features becomes [batch_size, max_num_proposals, channels].
...
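For reference, a minimal sketch of how the updated `AttentionBlock` constructor is exercised, mirroring the shapes in the unit test below. The import path and the bottleneck/temperature values are assumptions for illustration; the module name comes from the `context_rcnn_lib_tf2` reference in the meta-arch hunk further down.

```python
import tensorflow as tf
from object_detection.meta_architectures import context_rcnn_lib_tf2

# Shapes mirror the test: 2 images, 8 proposals with 3x3x3 box features,
# and 20 context features of dimension 10 per image.
box_features = tf.ones([2, 8, 3, 3, 3], tf.float32)
context_features = tf.ones([2, 20, 10], tf.float32)
valid_context_size = tf.constant([15, 20], tf.int32)  # valid slots per image

# No freeze_batchnorm argument anymore; freezing is a Keras-level concern.
attention_block = context_rcnn_lib_tf2.AttentionBlock(
    bottleneck_dimension=16,
    attention_temperature=0.1,
    output_dimension=None,  # build() then falls back to the input channel count
    is_training=False)
attended = attention_block(box_features, context_features, valid_context_size)
```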
@@ -80,7 +80,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
         features,
         projection_dimension,
         is_training,
-        context_rcnn_lib.ContextProjection(projection_dimension, False),
+        context_rcnn_lib.ContextProjection(projection_dimension),
         normalize=normalize)
 
     # Makes sure the shape is correct.
@@ -97,10 +97,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
                          attention_temperature):
     input_features = tf.ones([2, 8, 3, 3, 3], tf.float32)
     context_features = tf.ones([2, 20, 10], tf.float32)
-    is_training = False
     attention_block = context_rcnn_lib.AttentionBlock(bottleneck_dimension,
                                                       attention_temperature,
-                                                      freeze_batchnorm=False,
                                                       output_dimension=output_dimension,
                                                       is_training=False)
     valid_context_size = tf.random_uniform((2,),
...
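`compute_valid_mask`, which is unchanged by this commit, is what turns `valid_context_size` into the boolean `valid_mask` used in `AttentionBlock.call`. A minimal sketch of that idea, illustrative only and not necessarily the library's exact implementation:

```python
import tensorflow as tf

def compute_valid_mask_sketch(num_valid_elements, num_elements):
  """Marks the first num_valid_elements[i] slots of row i as True."""
  element_idxs = tf.range(num_elements, dtype=tf.int32)
  # Broadcast [num_elements] indices against [batch_size, 1] counts.
  return element_idxs[tf.newaxis, :] < num_valid_elements[:, tf.newaxis]

print(compute_valid_mask_sketch(tf.constant([2, 4]), 5))
# [[ True  True False False False]
#  [ True  True  True  True False]]
```

Padded context slots beyond each image's valid count are masked out (the library fills them with `_NEGATIVE_PADDING_VALUE` before the softmax) so attention only attends to real context features.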
@@ -275,7 +275,6 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
     self._context_feature_extract_fn = context_rcnn_lib_tf2.AttentionBlock(
         bottleneck_dimension=attention_bottleneck_dimension,
         attention_temperature=attention_temperature,
-        freeze_batchnorm=freeze_batchnorm,
         is_training=is_training)
 
   @staticmethod
...
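With `freeze_batchnorm` gone from the meta-arch wiring, a caller that still needs frozen batch norm can apply it externally after the block is built. One hypothetical way, using `tf.Module.submodules` to locate the nested batch-norm layers (illustrative only, not part of this commit):

```python
import tensorflow as tf

# attention_block is a built context_rcnn_lib_tf2.AttentionBlock instance.
# FreezableBatchNorm subclasses tf.keras.layers.BatchNormalization, so an
# isinstance check finds every nested batch-norm layer in the block.
for module in attention_block.submodules:
  if isinstance(module, tf.keras.layers.BatchNormalization):
    module.trainable = False  # inference mode + no gamma/beta updates
```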