Commit 8e3eb8d4 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar

clean up

parent 1554a4d7
......@@ -21,14 +21,14 @@ from __future__ import print_function
import tensorflow.compat.v1 as tf
import tf_slim as slim
class BatchNormAndProj():
def __init__(self):
self.batch_norm = None
self.projection = None
# The negative value used in padding the invalid weights.
_NEGATIVE_PADDING_VALUE = -100000
KEY_NAME = 'key'
VALUE_NAME = 'val'
QUERY_NAME = 'query'
FEATURE_NAME = 'feature'
class ContextProjection(tf.keras.layers.Layer):
def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=0.001,
......@@ -119,7 +119,7 @@ def compute_valid_mask(num_valid_elements, num_elements):
valid_mask = tf.less(batch_element_idxs, num_valid_elements)
return valid_mask
def project_features(features, projection_dimension, is_training, freeze_batchnorm, node=None, normalize=True):
def project_features(features, projection_dimension, is_training, node, normalize=True):
"""Projects features to another feature space.
Args:
......@@ -127,7 +127,6 @@ def project_features(features, projection_dimension, is_training, freeze_batchno
num_features].
projection_dimension: An int32 Tensor.
is_training: A boolean Tensor (affecting batch normalization).
freeze_batchnorm: A boolean indicating whether tunable parameters for batch normalization should be frozen.
node: A ContextProjection containing the two layers (batch normalization and dense) specific to the particular operation being performed (key, value, query, or feature).
normalize: A boolean Tensor. If true, the output features will be l2
normalized on the last dimension.
......@@ -135,7 +134,6 @@ def project_features(features, projection_dimension, is_training, freeze_batchno
Returns:
A float Tensor of shape [batch, features_size, projection_dimension].
"""
print("Called project")
shape_arr = features.shape
batch_size = shape_arr[0]
feature_size = shape_arr[1]
......@@ -144,21 +142,18 @@ def project_features(features, projection_dimension, is_training, freeze_batchno
projected_features = node(features)
#print(projected_features.shape)
projected_features = tf.reshape(projected_features,
[batch_size, -1, projection_dimension])
if normalize:
projected_features = tf.math.l2_normalize(projected_features, axis=-1)
print("Projected", features.shape, projected_features.shape)
return projected_features
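
With this change, callers build the projection layers up front and pass a ContextProjection as `node`; `freeze_batchnorm` is handled when the layer is constructed rather than at call time. A minimal usage sketch (the import path and shapes are assumptions, patterned after the test file):

```python
# Sketch of the new project_features calling convention (import path assumed).
import tensorflow.compat.v1 as tf
from object_detection.meta_architectures import context_rcnn_lib

features = tf.ones([2, 8, 100], tf.float32)  # [batch, num_elements, num_features]
node = context_rcnn_lib.ContextProjection(   # prebuilt Dense + BatchNorm pair
    projection_dimension=16, freeze_batchnorm=False)
projected = context_rcnn_lib.project_features(
    features, 16, is_training=False, node=node, normalize=True)
# projected has shape [2, 8, 16] and is l2-normalized along the last axis.
```
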
def attention_block(input_features, context_features, bottleneck_dimension,
output_dimension, attention_temperature, valid_mask,
is_training, freeze_batchnorm, attention_projections):
is_training, attention_projections):
"""Generic attention block.
Args:
......@@ -175,8 +170,7 @@ def attention_block(input_features, context_features, bottleneck_dimension,
weights = exp(weights / temperature) / sum(exp(weights / temperature))
valid_mask: A boolean Tensor of shape [batch_size, context_size].
is_training: A boolean Tensor (affecting batch normalization).
freeze_batchnorm: A boolean indicating whether to freeze Batch Normalization weights.
attention_projections: Contains a dictionary of the batch norm and projection functions.
attention_projections: Contains a dictionary of the projection objects.
Returns:
A float Tensor of shape [batch_size, input_size, output_dimension].
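
The weights formula above is a masked temperature softmax over the context dimension. A rough sketch of that step, using the names from the surrounding code (the exact implementation is elided from this diff):

```python
# Sketch of the masked temperature softmax described in the docstring.
# Invalid context slots receive a large negative logit (cf.
# _NEGATIVE_PADDING_VALUE) so their softmax weight is effectively zero.
weights = tf.matmul(queries, keys, transpose_b=True)  # [batch, input_size, context_size]
invalid = 1.0 - tf.cast(valid_mask, tf.float32)       # valid_mask: [batch, context_size]
weights += invalid[:, tf.newaxis, :] * _NEGATIVE_PADDING_VALUE
weights = tf.nn.softmax(weights / attention_temperature, axis=-1)
features = tf.matmul(weights, values)                 # weighted sum of context values
```
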
......@@ -184,11 +178,14 @@ def attention_block(input_features, context_features, bottleneck_dimension,
with tf.variable_scope("AttentionBlock"):
queries = project_features(
input_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["query"], normalize=True)
input_features, bottleneck_dimension, is_training,
attention_projections[QUERY_NAME], normalize=True)
keys = project_features(
context_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["key"], normalize=True)
context_features, bottleneck_dimension, is_training,
attention_projections[KEY_NAME], normalize=True)
values = project_features(
context_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["val"], normalize=True)
context_features, bottleneck_dimension, is_training,
attention_projections[VALUE_NAME], normalize=True)
weights = tf.matmul(queries, keys, transpose_b=True)
......@@ -198,13 +195,15 @@ def attention_block(input_features, context_features, bottleneck_dimension,
features = tf.matmul(weights, values)
output_features = project_features(
features, output_dimension, is_training, freeze_batchnorm, node=attention_projections["feature"], normalize=False)
features, output_dimension, is_training,
attention_projections[FEATURE_NAME], normalize=False)
return output_features
def compute_box_context_attention(box_features, context_features,
valid_context_size, bottleneck_dimension,
attention_temperature, is_training, freeze_batchnorm, attention_projections):
attention_temperature, is_training,
freeze_batchnorm, attention_projections):
"""Computes the attention feature from the context given a batch of box.
Args:
......@@ -220,7 +219,7 @@ def compute_box_context_attention(box_features, context_features,
softmax for weights calculation. The formula for the calculation is as follows:
weights = exp(weights / temperature) / sum(exp(weights / temperature))
is_training: A boolean Tensor (affecting batch normalization).
freeze_batchnorm: A boolean indicating whether to freeze Batch Normalization weights.
freeze_batchnorm: A boolean indicating whether to freeze batch normalization weights.
attention_projections: Contains a dictionary of the batch norm and projection functions.
Returns:
......@@ -232,7 +231,7 @@ def compute_box_context_attention(box_features, context_features,
channels = box_features.shape[-1]
if 'feature' not in attention_projections:
attention_projections['feature'] = ContextProjection(channels, freeze_batchnorm)
attention_projections[FEATURE_NAME] = ContextProjection(channels, freeze_batchnorm)
# Average pools over height and width dimension so that the shape of
# box_features becomes [batch_size, max_num_proposals, channels].
......@@ -241,7 +240,7 @@ def compute_box_context_attention(box_features, context_features,
output_features = attention_block(box_features, context_features,
bottleneck_dimension, channels,
attention_temperature, valid_mask,
is_training, freeze_batchnorm, attention_projections)
is_training, attention_projections)
# Expands the dimension back to match with the original feature map.
output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
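
Callers now pass a dictionary of projection layers keyed by the module constants; the FEATURE_NAME entry is injected lazily above, sized to the box feature channels. A calling sketch (shapes and import path are assumptions, patterned after the test):

```python
import tensorflow.compat.v1 as tf
# Assumed import path for the library under change.
from object_detection.meta_architectures import context_rcnn_lib

box_features = tf.ones([2, 3, 4, 4, 8], tf.float32)  # [batch, boxes, h, w, channels]
context_features = tf.ones([2, 5, 10], tf.float32)   # [batch, context_size, features]
valid_context_size = tf.constant([5, 2], tf.int32)

projections = {
    context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(10, False),
    context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(10, False),
    context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(10, False),
}
# The FEATURE_NAME projection is added by compute_box_context_attention
# itself, sized to box_features.shape[-1] (8 here).
attention = context_rcnn_lib.compute_box_context_attention(
    box_features, context_features, valid_context_size,
    bottleneck_dimension=10, attention_temperature=1.0,
    is_training=False, freeze_batchnorm=False,
    attention_projections=projections)
# attention has shape [2, 3, 1, 1, 8] and is added to the box feature map.
```
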
......
......@@ -80,7 +80,6 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
projected_features = context_rcnn_lib.project_features(
features,
projection_dimension,
freeze_batchnorm=False,
is_training=is_training,
normalize=normalize,
node=context_rcnn_lib.ContextProjection(projection_dimension, False))
......@@ -101,12 +100,15 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
context_features = tf.ones([2, 2, 3], tf.float32)
valid_mask = tf.constant([[True, True], [False, False]], tf.bool)
is_training = False
projection_layers = {"key": context_rcnn_lib.ContextProjection(bottleneck_dimension, False), "val": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
"query": context_rcnn_lib.ContextProjection(bottleneck_dimension, False), "feature": context_rcnn_lib.ContextProjection(output_dimension, False)}
projection_layers = {context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False), context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False)}
# Add the feature projection explicitly; it is normally injected further down the pipeline (by compute_box_context_attention) and is not created automatically here.
projection_layers['feature'] = context_rcnn_lib.ContextProjection(output_dimension, False)
output_features = context_rcnn_lib.attention_block(
input_features, context_features, bottleneck_dimension,
output_dimension, attention_temperature, valid_mask, is_training, False, projection_layers)
output_dimension, attention_temperature, valid_mask, is_training, projection_layers)
# Makes sure the shape is correct.
self.assertAllEqual(output_features.shape, [2, 3, output_dimension])
......@@ -118,8 +120,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
valid_context_size = tf.constant((2, 3), tf.int32)
bottleneck_dimension = 10
attention_temperature = 1
projection_layers = {"key": context_rcnn_lib.ContextProjection(bottleneck_dimension, False), "val": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
"query": context_rcnn_lib.ContextProjection(bottleneck_dimension, False), "feature": context_rcnn_lib.ContextProjection(box_features.shape[-1], False)}
projection_layers = {context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False), context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False)}
attention_features = context_rcnn_lib.compute_box_context_attention(
box_features, context_features, valid_context_size,
bottleneck_dimension, attention_temperature, is_training,
......
......@@ -75,8 +75,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
return_raw_detections_during_predict=False,
output_final_box_features=False,
attention_bottleneck_dimension=None,
attention_temperature=None,
attention_projections=None):
attention_temperature=None):
"""ContextRCNNMetaArch Constructor.
Args:
......@@ -273,9 +272,12 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
is_training=is_training,
freeze_batchnorm=freeze_batchnorm)
self._attention_projections = {"key": context_rcnn_lib.ContextProjection(attention_bottleneck_dimension, freeze_batchnorm),
"val": context_rcnn_lib.ContextProjection(attention_bottleneck_dimension, freeze_batchnorm),
"query": context_rcnn_lib.ContextProjection(attention_bottleneck_dimension, freeze_batchnorm)}
self._attention_projections = {'key': context_rcnn_lib.ContextProjection(
attention_bottleneck_dimension, freeze_batchnorm),
'val': context_rcnn_lib.ContextProjection(
attention_bottleneck_dimension, freeze_batchnorm),
'query': context_rcnn_lib.ContextProjection(
attention_bottleneck_dimension, freeze_batchnorm)}
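
These string keys must line up with the names attention_block looks up. A sketch of the same dictionary written with the constants exposed by context_rcnn_lib (equivalent, since KEY_NAME == 'key', VALUE_NAME == 'val', QUERY_NAME == 'query'; illustrative only, not the committed code):

```python
# Equivalent construction using the module-level constants (sketch only).
self._attention_projections = {
    context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(
        attention_bottleneck_dimension, freeze_batchnorm),
    context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(
        attention_bottleneck_dimension, freeze_batchnorm),
    context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(
        attention_bottleneck_dimension, freeze_batchnorm),
}
```
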
@staticmethod
def get_side_inputs(features):
......@@ -334,13 +336,11 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
features_to_crop, proposal_boxes_normalized,
[self._initial_crop_size, self._initial_crop_size])
print(self._attention_projections)
attention_features = self._context_feature_extract_fn(
box_features=box_features,
context_features=context_features,
valid_context_size=valid_context_size,
attention_projections=self._attention_projections)
print(self._attention_projections)
# Adds box features with attention features.
box_features += attention_features
......
......@@ -515,7 +515,6 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
fields.InputDataFields.valid_context_size: valid_context_size
}
print("sep")
side_inputs = model.get_side_inputs(features)
prediction_dict = model.predict(preprocessed_inputs, true_image_shapes,
......