Commit 33d05bec authored by Kaushik Shivakumar

clean slightly

parent cbd0576f
@@ -99,7 +99,7 @@ def compute_valid_mask(num_valid_elements, num_elements):
  valid_mask = tf.less(batch_element_idxs, num_valid_elements)
  return valid_mask

-def project_features(features, projection_dimension, is_training, node=None, normalize=True):
+def project_features(features, projection_dimension, is_training, freeze_batchnorm, node=None, normalize=True):
  """Projects features to another feature space.
  Args:
@@ -107,6 +107,8 @@ def project_features(features, projection_dimension, is_training, node=None, normalize=True):
      num_features].
    projection_dimension: A int32 Tensor.
    is_training: A boolean Tensor (affecting batch normalization).
+    freeze_batchnorm: A boolean indicating whether tunable parameters for batch normalization should be frozen.
+    node: Contains two layers (Batch Normalization and Dense) specific to the particular operation being performed (key, value, query, features)
    normalize: A boolean Tensor. If true, the output features will be l2
      normalized on the last dimension.
@@ -116,7 +118,7 @@ def project_features(features, projection_dimension, is_training, node=None, normalize=True):
  if node is None:
    node = {}
  if 'batch_norm' not in node:
-    node['batch_norm'] = tf.keras.layers.BatchNormalization(epsilon=0.001, center=True, scale=True, momentum=0.97)
+    node['batch_norm'] = tf.keras.layers.BatchNormalization(epsilon=0.001, center=True, scale=True, momentum=0.97, trainable=(not freeze_batchnorm))
  if 'projection' not in node:
    print("Creating new projection")
    node['projection'] = tf.keras.layers.Dense(units=projection_dimension,
@@ -124,15 +126,6 @@ def project_features(features, projection_dimension, is_training, node=None, normalize=True):
                                               use_bias=True)
-  # TODO(guanhangwu) Figure out a better way of specifying the batch norm
-  # params.
-  batch_norm_params = {
-      "is_training": is_training,
-      "decay": 0.97,
-      "epsilon": 0.001,
-      "center": True,
-      "scale": True
-  }
  shape_arr = features.shape
  batch_size = shape_arr[0]
  feature_size = shape_arr[1]
@@ -154,7 +147,7 @@ def project_features(features, projection_dimension, is_training, node=None, normalize=True):
def attention_block(input_features, context_features, bottleneck_dimension,
                    output_dimension, attention_temperature, valid_mask,
-                    is_training, attention_projections):
+                    is_training, freeze_batchnorm, attention_projections):
  """Generic attention block.
  Args:
@@ -171,6 +164,8 @@ def attention_block(input_features, context_features, bottleneck_dimension,
        weights = exp(weights / temperature) / sum(exp(weights / temperature))
    valid_mask: A boolean Tensor of shape [batch_size, context_size].
    is_training: A boolean Tensor (affecting batch normalization).
+    freeze_batchnorm: A boolean indicating whether to freeze Batch Normalization weights.
+    attention_projections: Contains a dictionary of the batch norm and projection functions.
  Returns:
    A float Tensor of shape [batch_size, input_size, output_dimension].
@@ -178,11 +173,11 @@ def attention_block(input_features, context_features, bottleneck_dimension,
  with tf.variable_scope("AttentionBlock"):
    queries = project_features(
-        input_features, bottleneck_dimension, is_training, node=attention_projections["query"], normalize=True)
+        input_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["query"], normalize=True)
    keys = project_features(
-        context_features, bottleneck_dimension, is_training, node=attention_projections["key"], normalize=True)
+        context_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["key"], normalize=True)
    values = project_features(
-        context_features, bottleneck_dimension, is_training, node=attention_projections["val"], normalize=True)
+        context_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["val"], normalize=True)
    print(attention_projections['query'])
    weights = tf.matmul(queries, keys, transpose_b=True)
@@ -193,13 +188,13 @@ def attention_block(input_features, context_features, bottleneck_dimension,
    features = tf.matmul(weights, values)
  output_features = project_features(
-      features, output_dimension, is_training, node=attention_projections["feature"], normalize=False)
+      features, output_dimension, is_training, freeze_batchnorm, node=attention_projections["feature"], normalize=False)
  return output_features

def compute_box_context_attention(box_features, context_features,
                                  valid_context_size, bottleneck_dimension,
-                                  attention_temperature, is_training, attention_projections):
+                                  attention_temperature, is_training, freeze_batchnorm, attention_projections):
  """Computes the attention feature from the context given a batch of box.
  Args:
@@ -215,6 +210,8 @@ def compute_box_context_attention(box_features, context_features,
      softmax for weights calculation. The formula for calculation as follows:
        weights = exp(weights / temperature) / sum(exp(weights / temperature))
    is_training: A boolean Tensor (affecting batch normalization).
+    freeze_batchnorm: A boolean indicating whether to freeze Batch Normalization weights.
+    attention_projections: Contains a dictionary of the batch norm and projection functions.
  Returns:
    A float Tensor of shape [batch_size, max_num_proposals, 1, 1, channels].
@@ -230,7 +227,7 @@ def compute_box_context_attention(box_features, context_features,
  output_features = attention_block(box_features, context_features,
                                    bottleneck_dimension, channels,
                                    attention_temperature, valid_mask,
-                                    is_training, attention_projections)
+                                    is_training, freeze_batchnorm, attention_projections)
  # Expands the dimension back to match with the original feature map.
  output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
...
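For orientation, a minimal sketch (not part of the commit; the tensor shape and the freeze_batchnorm value are assumed for illustration) of what the new trainable=(not freeze_batchnorm) flag does to a Keras BatchNormalization layer:

# Illustration only: behavior of trainable=(not freeze_batchnorm) on
# tf.keras.layers.BatchNormalization. Shapes/values are assumptions.
import tensorflow as tf

freeze_batchnorm = True  # hypothetical setting for this sketch

batch_norm = tf.keras.layers.BatchNormalization(
    epsilon=0.001, center=True, scale=True, momentum=0.97,
    trainable=(not freeze_batchnorm))

features = tf.random.normal([2, 8])  # assumed [batch_size, num_features]
outputs = batch_norm(features, training=False)

# With trainable=False the layer's gamma/beta are excluded from
# trainable_variables, so an optimizer will not update them.
print(len(batch_norm.trainable_variables))  # 0 when frozen, 2 otherwise

Freezing this way leaves the moving statistics and scale/offset out of the optimizer's reach, which is the usual reason to freeze batch norm when fine-tuning a detector with small batch sizes.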
@@ -80,6 +80,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
    projected_features = context_rcnn_lib.project_features(
        features,
        projection_dimension,
+        freeze_batchnorm=False,
        is_training=is_training,
        normalize=normalize)
@@ -102,7 +103,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
    projection_layers = {"key": {}, "val": {}, "query": {}, "feature": {}}
    output_features = context_rcnn_lib.attention_block(
        input_features, context_features, bottleneck_dimension,
-        output_dimension, attention_temperature, valid_mask, is_training, projection_layers)
+        output_dimension, attention_temperature, valid_mask, is_training, False, projection_layers)
    # Makes sure the shape is correct.
    self.assertAllEqual(output_features.shape, [2, 3, output_dimension])
@@ -117,7 +118,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
    projection_layers = {"key": {}, "val": {}, "query": {}, "feature": {}}
    attention_features = context_rcnn_lib.compute_box_context_attention(
        box_features, context_features, valid_context_size,
-        bottleneck_dimension, attention_temperature, is_training, projection_layers)
+        bottleneck_dimension, attention_temperature, is_training,
+        False, projection_layers)
    # Makes sure the shape is correct.
    self.assertAllEqual(attention_features.shape, [2, 3, 1, 1, 4])
...
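The empty dicts passed as projection_layers in these tests are filled lazily by project_features, so repeated calls reuse the same Keras layers. A small self-contained sketch of that caching pattern (the function and variable names here are illustrative, not the library's API):

# Sketch of the lazy layer-caching pattern: layers are created on first use and
# stored in the caller-provided dict so later calls reuse the same weights.
import tensorflow as tf

def get_projection(node, projection_dimension, freeze_batchnorm):
  # Create the per-operation layers only once, then cache them in `node`.
  if 'batch_norm' not in node:
    node['batch_norm'] = tf.keras.layers.BatchNormalization(
        momentum=0.97, trainable=(not freeze_batchnorm))
  if 'projection' not in node:
    node['projection'] = tf.keras.layers.Dense(units=projection_dimension)
  return node

projection_layers = {"key": {}, "val": {}, "query": {}, "feature": {}}
get_projection(projection_layers["query"], 8, freeze_batchnorm=False)
get_projection(projection_layers["query"], 8, freeze_batchnorm=False)  # reuses the cached layers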
@@ -270,7 +270,8 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
        context_rcnn_lib.compute_box_context_attention,
        bottleneck_dimension=attention_bottleneck_dimension,
        attention_temperature=attention_temperature,
-        is_training=is_training)
+        is_training=is_training,
+        freeze_batchnorm=freeze_batchnorm)
    self._attention_projections = {"key": {},
                                   "val": {},
...
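The hunk above threads freeze_batchnorm through the meta-arch's pre-bound attention function. A hedged, self-contained sketch of that keyword-binding pattern (the use of functools.partial, the stand-in function body, and the concrete values are assumptions for illustration, not the commit's code):

# Keyword-binding sketch: fixed hyperparameters are bound up front, per-batch
# tensors are supplied at call time.
import functools

def compute_box_context_attention(box_features, context_features,
                                  valid_context_size, bottleneck_dimension,
                                  attention_temperature, is_training,
                                  freeze_batchnorm, attention_projections):
  """Stand-in with the post-commit signature."""
  return box_features

attention_fn = functools.partial(
    compute_box_context_attention,
    bottleneck_dimension=16,      # hypothetical value
    attention_temperature=0.01,   # hypothetical value
    is_training=True,
    freeze_batchnorm=False)

# At call time only the per-batch tensors and the projection dict are supplied:
# attention_fn(box_features, context_features, valid_context_size,
#              attention_projections={"key": {}, "val": {}, "query": {}, "feature": {}})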