Commit aabe1231 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Context RCNN support for Tensorflow 2.x

parent e3f88e11
......@@ -23,13 +23,13 @@ _NEGATIVE_PADDING_VALUE = -100000
class ContextProjection(tf.keras.layers.Layer):
"""Custom layer to do batch normalization and projection."""
def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
def __init__(self, projection_dimension, **kwargs):
self.batch_norm = freezable_batch_norm.FreezableBatchNorm(
epsilon=0.001,
center=True,
scale=True,
momentum=0.97,
trainable=(not freeze_batchnorm))
trainable=True)
self.projection = tf.keras.layers.Dense(units=projection_dimension,
activation=tf.nn.relu6,
use_bias=True)
......@@ -45,33 +45,29 @@ class ContextProjection(tf.keras.layers.Layer):
class AttentionBlock(tf.keras.layers.Layer):
"""Custom layer to perform all attention."""
def __init__(self, bottleneck_dimension, attention_temperature,
freeze_batchnorm, output_dimension=None,
output_dimension=None,
is_training=False, name='AttentionBlock', **kwargs):
self._key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
self._val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
self._query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
self._key_proj = ContextProjection(bottleneck_dimension)
self._val_proj = ContextProjection(bottleneck_dimension)
self._query_proj = ContextProjection(bottleneck_dimension)
self._feature_proj = None
self._attention_temperature = attention_temperature
self._freeze_batchnorm = freeze_batchnorm
self._bottleneck_dimension = bottleneck_dimension
self._is_training = is_training
self._output_dimension = output_dimension
if self._output_dimension:
self._feature_proj = ContextProjection(self._output_dimension,
self._freeze_batchnorm)
self._feature_proj = ContextProjection(self._output_dimension)
super(AttentionBlock, self).__init__(name=name, **kwargs)
def build(self, input_shapes):
if not self._feature_proj:
self._output_dimension = input_shapes[-1]
self._feature_proj = ContextProjection(self._output_dimension,
self._freeze_batchnorm)
self._feature_proj = ContextProjection(self._output_dimension)
def call(self, box_features, context_features, valid_context_size):
"""Handles a call by performing attention."""
_, context_size, _ = context_features.shape
valid_mask = compute_valid_mask(valid_context_size, context_size)
channels = box_features.shape[-1]
# Average pools over height and width dimension so that the shape of
# box_features becomes [batch_size, max_num_proposals, channels].
......
......@@ -80,7 +80,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
features,
projection_dimension,
is_training,
context_rcnn_lib.ContextProjection(projection_dimension, False),
context_rcnn_lib.ContextProjection(projection_dimension),
normalize=normalize)
# Makes sure the shape is correct.
......@@ -97,10 +97,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
attention_temperature):
input_features = tf.ones([2, 8, 3, 3, 3], tf.float32)
context_features = tf.ones([2, 20, 10], tf.float32)
is_training = False
attention_block = context_rcnn_lib.AttentionBlock(bottleneck_dimension,
attention_temperature,
freeze_batchnorm=False,
attention_temperature,
output_dimension=output_dimension,
is_training=False)
valid_context_size = tf.random_uniform((2,),
......
......@@ -275,7 +275,6 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
self._context_feature_extract_fn = context_rcnn_lib_tf2.AttentionBlock(
bottleneck_dimension=attention_bottleneck_dimension,
attention_temperature=attention_temperature,
freeze_batchnorm=freeze_batchnorm,
is_training=is_training)
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment