Commit 8e3eb8d4 authored by Kaushik Shivakumar

clean up

parent 1554a4d7
@@ -21,14 +21,14 @@ from __future__ import print_function
 import tensorflow.compat.v1 as tf
 import tf_slim as slim
-class BatchNormAndProj():
-  def __init__(self):
-    self.batch_norm = None
-    self.projection = None
 # The negative value used in padding the invalid weights.
 _NEGATIVE_PADDING_VALUE = -100000
+KEY_NAME = 'key'
+VALUE_NAME = 'val'
+QUERY_NAME = 'query'
+FEATURE_NAME = 'feature'
 class ContextProjection(tf.keras.layers.Layer):
   def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
     self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=0.001,
@@ -119,7 +119,7 @@ def compute_valid_mask(num_valid_elements, num_elements):
   valid_mask = tf.less(batch_element_idxs, num_valid_elements)
   return valid_mask
-def project_features(features, projection_dimension, is_training, freeze_batchnorm, node=None, normalize=True):
+def project_features(features, projection_dimension, is_training, node, normalize=True):
   """Projects features to another feature space.
   Args:
@@ -127,7 +127,6 @@ def project_features(features, projection_dimension, is_training, freeze_batchno
       num_features].
     projection_dimension: A int32 Tensor.
     is_training: A boolean Tensor (affecting batch normalization).
-    freeze_batchnorm: A boolean indicating whether tunable parameters for batch normalization should be frozen.
     node: Contains two layers (Batch Normalization and Dense) specific to the particular operation being performed (key, value, query, features)
     normalize: A boolean Tensor. If true, the output features will be l2
       normalized on the last dimension.
@@ -135,7 +134,6 @@ def project_features(features, projection_dimension, is_training, freeze_batchno
   Returns:
     A float Tensor of shape [batch, features_size, projection_dimension].
   """
-  print("Called project")
   shape_arr = features.shape
   batch_size = shape_arr[0]
   feature_size = shape_arr[1]
@@ -144,21 +142,18 @@ def project_features(features, projection_dimension, is_training, freeze_batchno
   projected_features = node(features)
-  #print(projected_features.shape)
   projected_features = tf.reshape(projected_features,
                                   [batch_size, -1, projection_dimension])
   if normalize:
     projected_features = tf.math.l2_normalize(projected_features, axis=-1)
-  print("Projected", features.shape, projected_features.shape)
   return projected_features
 def attention_block(input_features, context_features, bottleneck_dimension,
                     output_dimension, attention_temperature, valid_mask,
-                    is_training, freeze_batchnorm, attention_projections):
+                    is_training, attention_projections):
   """Generic attention block.
   Args:
@@ -175,8 +170,7 @@ def attention_block(input_features, context_features, bottleneck_dimension,
       weights = exp(weights / temperature) / sum(exp(weights / temperature))
     valid_mask: A boolean Tensor of shape [batch_size, context_size].
     is_training: A boolean Tensor (affecting batch normalization).
-    freeze_batchnorm: A boolean indicating whether to freeze Batch Normalization weights.
-    attention_projections: Contains a dictionary of the batch norm and projection functions.
+    attention_projections: Contains a dictionary of the projection objects.
   Returns:
     A float Tensor of shape [batch_size, input_size, output_dimension].
@@ -184,11 +178,14 @@ def attention_block(input_features, context_features, bottleneck_dimension,
   with tf.variable_scope("AttentionBlock"):
     queries = project_features(
-        input_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["query"], normalize=True)
+        input_features, bottleneck_dimension, is_training,
+        attention_projections[QUERY_NAME], normalize=True)
     keys = project_features(
-        context_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["key"], normalize=True)
+        context_features, bottleneck_dimension, is_training,
+        attention_projections[KEY_NAME], normalize=True)
     values = project_features(
-        context_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["val"], normalize=True)
+        context_features, bottleneck_dimension, is_training,
+        attention_projections[VALUE_NAME], normalize=True)
     weights = tf.matmul(queries, keys, transpose_b=True)
@@ -198,13 +195,15 @@ def attention_block(input_features, context_features, bottleneck_dimension,
     features = tf.matmul(weights, values)
     output_features = project_features(
-        features, output_dimension, is_training, freeze_batchnorm, node=attention_projections["feature"], normalize=False)
+        features, output_dimension, is_training,
+        attention_projections[FEATURE_NAME], normalize=False)
     return output_features
 def compute_box_context_attention(box_features, context_features,
                                   valid_context_size, bottleneck_dimension,
-                                  attention_temperature, is_training, freeze_batchnorm, attention_projections):
+                                  attention_temperature, is_training,
+                                  freeze_batchnorm, attention_projections):
   """Computes the attention feature from the context given a batch of box.
   Args:
@@ -220,7 +219,7 @@ def compute_box_context_attention(box_features, context_features,
       softmax for weights calculation. The formula for calculation as follows:
       weights = exp(weights / temperature) / sum(exp(weights / temperature))
     is_training: A boolean Tensor (affecting batch normalization).
-    freeze_batchnorm: A boolean indicating whether to freeze Batch Normalization weights.
+    freeze_batchnorm: A boolean indicating whether to freeze batch normalization weights.
     attention_projections: Contains a dictionary of the batch norm and projection functions.
   Returns:
@@ -232,7 +231,7 @@ def compute_box_context_attention(box_features, context_features,
   channels = box_features.shape[-1]
   if 'feature' not in attention_projections:
-    attention_projections['feature'] = ContextProjection(channels, freeze_batchnorm)
+    attention_projections[FEATURE_NAME] = ContextProjection(channels, freeze_batchnorm)
   # Average pools over height and width dimension so that the shape of
   # box_features becomes [batch_size, max_num_proposals, channels].
@@ -241,7 +240,7 @@ def compute_box_context_attention(box_features, context_features,
   output_features = attention_block(box_features, context_features,
                                     bottleneck_dimension, channels,
                                     attention_temperature, valid_mask,
-                                    is_training, freeze_batchnorm, attention_projections)
+                                    is_training, attention_projections)
   # Expands the dimension back to match with the original feature map.
   output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
......
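As a quick usage sketch of the updated `attention_block` interface above (not part of this commit; the import path and tensor shapes are assumptions, loosely mirroring the unit-test changes below): callers now build the projection layers up front, keyed by the new module-level name constants, and no longer pass `freeze_batchnorm` through.

```python
import tensorflow.compat.v1 as tf
# Assumed import path for the modified module.
from object_detection.meta_architectures import context_rcnn_lib

bottleneck_dimension, output_dimension = 10, 12

# One ContextProjection per role; the second argument is freeze_batchnorm.
projection_layers = {
    context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
    context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
    context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
    # attention_block does not inject the feature projection itself, so it is
    # supplied here; compute_box_context_attention adds it automatically.
    context_rcnn_lib.FEATURE_NAME: context_rcnn_lib.ContextProjection(output_dimension, False),
}

# Illustrative shapes: [batch, input_size, num_features] and
# [batch, context_size, num_features].
input_features = tf.ones([2, 3, 4], tf.float32)
context_features = tf.ones([2, 2, 3], tf.float32)
valid_mask = tf.constant([[True, True], [False, False]], tf.bool)

# New signature: freeze_batchnorm is gone and the projection dict comes last.
output_features = context_rcnn_lib.attention_block(
    input_features, context_features, bottleneck_dimension,
    output_dimension, attention_temperature=1.0, valid_mask=valid_mask,
    is_training=False, attention_projections=projection_layers)
# output_features has shape [2, 3, output_dimension].
```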
@@ -80,7 +80,6 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     projected_features = context_rcnn_lib.project_features(
         features,
         projection_dimension,
-        freeze_batchnorm=False,
         is_training=is_training,
         normalize=normalize,
         node=context_rcnn_lib.ContextProjection(projection_dimension, False))
@@ -101,12 +100,15 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     context_features = tf.ones([2, 2, 3], tf.float32)
     valid_mask = tf.constant([[True, True], [False, False]], tf.bool)
     is_training = False
-    projection_layers = {"key": context_rcnn_lib.ContextProjection(bottleneck_dimension, False), "val": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
-                         "query": context_rcnn_lib.ContextProjection(bottleneck_dimension, False), "feature": context_rcnn_lib.ContextProjection(output_dimension, False)}
+    projection_layers = {context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False), context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+                         context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False)}
+    #Add in the feature layer because this is further down the pipeline and it isn't automatically injected.
+    projection_layers['feature'] = context_rcnn_lib.ContextProjection(output_dimension, False)
     output_features = context_rcnn_lib.attention_block(
         input_features, context_features, bottleneck_dimension,
-        output_dimension, attention_temperature, valid_mask, is_training, False, projection_layers)
+        output_dimension, attention_temperature, valid_mask, is_training, projection_layers)
     # Makes sure the shape is correct.
     self.assertAllEqual(output_features.shape, [2, 3, output_dimension])
@@ -118,8 +120,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     valid_context_size = tf.constant((2, 3), tf.int32)
     bottleneck_dimension = 10
     attention_temperature = 1
-    projection_layers = {"key": context_rcnn_lib.ContextProjection(bottleneck_dimension, False), "val": context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
-                         "query": context_rcnn_lib.ContextProjection(bottleneck_dimension, False), "feature": context_rcnn_lib.ContextProjection(box_features.shape[-1], False)}
+    projection_layers = {context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False), context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+                         context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False)}
     attention_features = context_rcnn_lib.compute_box_context_attention(
         box_features, context_features, valid_context_size,
         bottleneck_dimension, attention_temperature, is_training,
......
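Along the same lines, a minimal sketch of calling `compute_box_context_attention` with the reordered arguments (the import path and shapes are assumptions; only the signature comes from this diff). Here `freeze_batchnorm` is still a parameter, and the feature projection is created inside the function when it is missing from the dictionary.

```python
import tensorflow.compat.v1 as tf
# Assumed import path for the modified module.
from object_detection.meta_architectures import context_rcnn_lib

# Assumed shape: [batch, max_num_proposals, height, width, channels];
# the function mean-pools over height and width internally.
box_features = tf.ones([2, 3, 4, 4, 8], tf.float32)
context_features = tf.ones([2, 5, 6], tf.float32)
valid_context_size = tf.constant((2, 3), tf.int32)

projection_layers = {
    context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(10, False),
    context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(10, False),
    context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(10, False),
}

attention_features = context_rcnn_lib.compute_box_context_attention(
    box_features, context_features, valid_context_size,
    bottleneck_dimension=10, attention_temperature=1.0, is_training=False,
    freeze_batchnorm=False, attention_projections=projection_layers)
# attention_features has shape [2, 3, 1, 1, 8] after the trailing expand_dims.
```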
@@ -75,8 +75,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
                return_raw_detections_during_predict=False,
                output_final_box_features=False,
                attention_bottleneck_dimension=None,
-               attention_temperature=None,
-               attention_projections=None):
+               attention_temperature=None):
     """ContextRCNNMetaArch Constructor.
     Args:
@@ -273,9 +272,12 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         is_training=is_training,
         freeze_batchnorm=freeze_batchnorm)
-    self._attention_projections = {"key": context_rcnn_lib.ContextProjection(attention_bottleneck_dimension, freeze_batchnorm),
-                                   "val": context_rcnn_lib.ContextProjection(attention_bottleneck_dimension, freeze_batchnorm),
-                                   "query": context_rcnn_lib.ContextProjection(attention_bottleneck_dimension, freeze_batchnorm)}
+    self._attention_projections = {'key': context_rcnn_lib.ContextProjection(
+                                       attention_bottleneck_dimension, freeze_batchnorm),
+                                   'val': context_rcnn_lib.ContextProjection(
+                                       attention_bottleneck_dimension, freeze_batchnorm),
+                                   'query': context_rcnn_lib.ContextProjection(
+                                       attention_bottleneck_dimension, freeze_batchnorm)}
   @staticmethod
   def get_side_inputs(features):
@@ -334,13 +336,11 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         features_to_crop, proposal_boxes_normalized,
         [self._initial_crop_size, self._initial_crop_size])
-    print(self._attention_projections)
     attention_features = self._context_feature_extract_fn(
         box_features=box_features,
         context_features=context_features,
         valid_context_size=valid_context_size,
         attention_projections=self._attention_projections)
-    print(self._attention_projections)
     # Adds box features with attention features.
     box_features += attention_features
......
@@ -515,7 +515,6 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
         fields.InputDataFields.valid_context_size: valid_context_size
     }
-    print("sep")
     side_inputs = model.get_side_inputs(features)
     prediction_dict = model.predict(preprocessed_inputs, true_image_shapes,
......