Commit 1554a4d7 authored by Kaushik Shivakumar

fix context tf2 support

parent 33d05bec
@@ -29,6 +29,26 @@ class BatchNormAndProj():
   # The negative value used in padding the invalid weights.
   _NEGATIVE_PADDING_VALUE = -100000
 
+
+class ContextProjection(tf.keras.layers.Layer):
+  def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
+    self.batch_norm = tf.keras.layers.BatchNormalization(
+        epsilon=0.001,
+        center=True,
+        scale=True,
+        momentum=0.97,
+        trainable=(not freeze_batchnorm))
+    self.projection = tf.keras.layers.Dense(units=projection_dimension,
+                                            activation=tf.nn.relu6,
+                                            use_bias=True)
+    super(ContextProjection, self).__init__(**kwargs)
+
+  def build(self, input_shape):
+    self.batch_norm.build(input_shape)
+    self.projection.build(input_shape)
+
+  def call(self, input):
+    return self.projection(self.batch_norm(input))
+
+
 def filter_weight_value(weights, values, valid_mask):
   """Filters weights and values based on valid_mask.
@@ -115,33 +135,24 @@ def project_features(features, projection_dimension, is_training, freeze_batchnorm
   Returns:
     A float Tensor of shape [batch, features_size, projection_dimension].
   """
-  if node is None:
-    node = {}
-  if 'batch_norm' not in node:
-    node['batch_norm'] = tf.keras.layers.BatchNormalization(
-        epsilon=0.001, center=True, scale=True, momentum=0.97,
-        trainable=(not freeze_batchnorm))
-  if 'projection' not in node:
-    print("Creating new projection")
-    node['projection'] = tf.keras.layers.Dense(units=projection_dimension,
-                                               activation=tf.nn.relu6,
-                                               use_bias=True)
-  print("Called project")
   shape_arr = features.shape
   batch_size = shape_arr[0]
   feature_size = shape_arr[1]
   num_features = shape_arr[2]
   features = tf.reshape(features, [-1, num_features])
-  batch_norm_features = node['batch_norm'](features)
-  projected_features = node['projection'](batch_norm_features,
-                                          training=is_training)
-  print(projected_features.shape)
+  projected_features = node(features)
+  #print(projected_features.shape)
   projected_features = tf.reshape(projected_features,
                                   [batch_size, -1, projection_dimension])
   if normalize:
     projected_features = tf.math.l2_normalize(projected_features, axis=-1)
-  print("Projected", features.shape, projected_features.shape)
   return projected_features
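With the layers owned by a ContextProjection, project_features now just reshapes, applies the single node callable, and reshapes back, instead of rebuilding layers from a dict on every call. A hedged end-to-end sketch; the import path is an assumption based on the context_rcnn_lib usage in the test file below:

import tensorflow as tf
from object_detection.meta_architectures import context_rcnn_lib  # assumed path

features = tf.ones([2, 3, 4], tf.float32)  # [batch, features_size, num_features]
layer = context_rcnn_lib.ContextProjection(10, False)
projected = context_rcnn_lib.project_features(
    features, projection_dimension=10, is_training=False,
    freeze_batchnorm=False, node=layer, normalize=True)
# projected: [2, 3, 10], l2-normalized along the last axis.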
@@ -179,7 +190,6 @@ def attention_block(input_features, context_features, bottleneck_dimension,
   values = project_features(
       context_features, bottleneck_dimension, is_training, freeze_batchnorm,
       node=attention_projections["val"], normalize=True)
-  print(attention_projections['query'])
 
   weights = tf.matmul(queries, keys, transpose_b=True)
   weights, values = filter_weight_value(weights, values, valid_mask)
@@ -220,6 +230,10 @@ def compute_box_context_attention(box_features, context_features,
   valid_mask = compute_valid_mask(valid_context_size, context_size)
 
   channels = box_features.shape[-1]
+  if 'feature' not in attention_projections:
+    attention_projections['feature'] = ContextProjection(channels,
+                                                         freeze_batchnorm)
+
   # Average pools over height and width dimension so that the shape of
   # box_features becomes [batch_size, max_num_proposals, channels].
   box_features = tf.reduce_mean(box_features, [2, 3])
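The 'feature' projection is created lazily here because its width must match channels, which is only known once box_features is seen; after the first call the cached instance (and its variables) is reused. The same pattern in isolation, as a hypothetical helper that is not part of the source:

def _get_or_create_projection(projections, name, dimension, freeze_batchnorm):
  # Build the layer on first use, once the target dimension is known,
  # then return the cached instance so its variables are shared.
  if name not in projections:
    projections[name] = ContextProjection(dimension, freeze_batchnorm)
  return projections[name]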
@@ -82,7 +82,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
         projection_dimension,
         freeze_batchnorm=False,
         is_training=is_training,
-        normalize=normalize)
+        normalize=normalize,
+        node=context_rcnn_lib.ContextProjection(projection_dimension, False))
     # Makes sure the shape is correct.
     self.assertAllEqual(projected_features.shape, [2, 3, projection_dimension])
@@ -100,7 +101,9 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     context_features = tf.ones([2, 2, 3], tf.float32)
     valid_mask = tf.constant([[True, True], [False, False]], tf.bool)
     is_training = False
-    projection_layers = {"key": {}, "val": {}, "query": {}, "feature": {}}
+    projection_layers = {
+        "key": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+        "val": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+        "query": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+        "feature": context_rcnn_lib.ContextProjection(output_dimension, False)}
     output_features = context_rcnn_lib.attention_block(
         input_features, context_features, bottleneck_dimension,
         output_dimension, attention_temperature, valid_mask, is_training,
         False, projection_layers)
@@ -115,7 +118,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     valid_context_size = tf.constant((2, 3), tf.int32)
     bottleneck_dimension = 10
     attention_temperature = 1
-    projection_layers = {"key": {}, "val": {}, "query": {}, "feature": {}}
+    projection_layers = {
+        "key": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+        "val": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+        "query": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+        "feature": context_rcnn_lib.ContextProjection(box_features.shape[-1], False)}
     attention_features = context_rcnn_lib.compute_box_context_attention(
         box_features, context_features, valid_context_size,
         bottleneck_dimension, attention_temperature, is_training,
@@ -273,10 +273,9 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         is_training=is_training,
         freeze_batchnorm=freeze_batchnorm)
 
-    self._attention_projections = {"key": {},
-                                   "val": {},
-                                   "query": {},
-                                   "feature": {}}
+    self._attention_projections = {
+        "key": context_rcnn_lib.ContextProjection(
+            attention_bottleneck_dimension, freeze_batchnorm),
+        "val": context_rcnn_lib.ContextProjection(
+            attention_bottleneck_dimension, freeze_batchnorm),
+        "query": context_rcnn_lib.ContextProjection(
+            attention_bottleneck_dimension, freeze_batchnorm)}
 
   @staticmethod
   def get_side_inputs(features):
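Note that the new dictionary omits "feature": that projection is built lazily in compute_box_context_attention once the channel count is known. Holding real layer instances in an attribute matters in TF2 because Keras auto-tracks layers stored in dict attributes of a Model, so their variables appear in model.variables and in checkpoints. An illustrative toy model, not the source class:

import tensorflow as tf

class TinyModel(tf.keras.Model):
  def __init__(self):
    super(TinyModel, self).__init__()
    # Dicts of layers assigned to attributes are auto-tracked by Keras.
    self._projections = {"query": tf.keras.layers.Dense(4)}

  def call(self, inputs):
    return self._projections["query"](inputs)

model = TinyModel()
model(tf.zeros([1, 3]))                # builds the dense layer
print(len(model.trainable_variables))  # 2: kernel and bias are tracked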