".github/workflows/show_link_to_built_docs.yml" did not exist on "10f8ddb4f429af2b65e278b87e06bd1074959234"
Commit 33d05bec authored by Kaushik Shivakumar

clean slightly

parent cbd0576f
@@ -99,7 +99,7 @@ def compute_valid_mask(num_valid_elements, num_elements):
valid_mask = tf.less(batch_element_idxs, num_valid_elements)
return valid_mask
-def project_features(features, projection_dimension, is_training, node=None, normalize=True):
+def project_features(features, projection_dimension, is_training, freeze_batchnorm, node=None, normalize=True):
"""Projects features to another feature space.
Args:
@@ -107,6 +107,8 @@ def project_features(features, projection_dimension, is_training, node=None, nor
num_features].
projection_dimension: A int32 Tensor.
is_training: A boolean Tensor (affecting batch normalization).
+freeze_batchnorm: A boolean indicating whether tunable parameters for batch normalization should be frozen.
+node: Contains two layers (Batch Normalization and Dense) specific to the particular operation being performed (key, value, query, features).
normalize: A boolean Tensor. If true, the output features will be l2
normalized on the last dimension.
@@ -116,7 +118,7 @@ def project_features(features, projection_dimension, is_training, node=None, nor
if node is None:
node = {}
if 'batch_norm' not in node:
-node['batch_norm'] = tf.keras.layers.BatchNormalization(epsilon=0.001, center=True, scale=True, momentum=0.97)
+node['batch_norm'] = tf.keras.layers.BatchNormalization(epsilon=0.001, center=True, scale=True, momentum=0.97, trainable=(not freeze_batchnorm))
if 'projection' not in node:
print("Creating new projection")
node['projection'] = tf.keras.layers.Dense(units=projection_dimension,
@@ -124,15 +126,6 @@ def project_features(features, projection_dimension, is_training, node=None, nor
use_bias=True)
-# TODO(guanhangwu) Figure out a better way of specifying the batch norm
-# params.
-batch_norm_params = {
-    "is_training": is_training,
-    "decay": 0.97,
-    "epsilon": 0.001,
-    "center": True,
-    "scale": True
-}
shape_arr = features.shape
batch_size = shape_arr[0]
feature_size = shape_arr[1]
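Below is a minimal, hedged sketch of what the batch-norm change above amounts to at the Keras level: trainable=(not freeze_batchnorm) freezes the layer's learned scale/offset, while the training flag passed at call time (assumed here, based on the is_training docstring) controls whether batch statistics or moving averages are used. This is illustrative code, not the library's project_features implementation.

import tensorflow as tf

# Illustrative sketch only. In TF 2 Keras, trainable=False also makes the
# layer run in inference mode (moving statistics) regardless of `training`.
freeze_batchnorm = True
is_training = True

dense = tf.keras.layers.Dense(units=4, use_bias=True)
batch_norm = tf.keras.layers.BatchNormalization(
    epsilon=0.001, center=True, scale=True, momentum=0.97,
    trainable=(not freeze_batchnorm))

features = tf.random.normal([2, 5, 8])        # [batch_size, features_size, num_features]
projected = dense(features)                   # [2, 5, 4]
output = batch_norm(projected, training=is_training)
output = tf.nn.l2_normalize(output, axis=-1)  # the normalize=True branch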
@@ -154,7 +147,7 @@ def project_features(features, projection_dimension, is_training, node=None, nor
def attention_block(input_features, context_features, bottleneck_dimension,
output_dimension, attention_temperature, valid_mask,
-is_training, attention_projections):
+is_training, freeze_batchnorm, attention_projections):
"""Generic attention block.
Args:
@@ -171,6 +164,8 @@ def attention_block(input_features, context_features, bottleneck_dimension,
weights = exp(weights / temperature) / sum(exp(weights / temperature))
valid_mask: A boolean Tensor of shape [batch_size, context_size].
is_training: A boolean Tensor (affecting batch normalization).
+freeze_batchnorm: A boolean indicating whether to freeze Batch Normalization weights.
+attention_projections: Contains a dictionary of the batch norm and projection functions.
Returns:
A float Tensor of shape [batch_size, input_size, output_dimension].
@@ -178,11 +173,11 @@ def attention_block(input_features, context_features, bottleneck_dimension,
with tf.variable_scope("AttentionBlock"):
queries = project_features(
-input_features, bottleneck_dimension, is_training, node=attention_projections["query"], normalize=True)
+input_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["query"], normalize=True)
keys = project_features(
-context_features, bottleneck_dimension, is_training, node=attention_projections["key"], normalize=True)
+context_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["key"], normalize=True)
values = project_features(
-context_features, bottleneck_dimension, is_training, node=attention_projections["val"], normalize=True)
+context_features, bottleneck_dimension, is_training, freeze_batchnorm, node=attention_projections["val"], normalize=True)
print(attention_projections['query'])
weights = tf.matmul(queries, keys, transpose_b=True)
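For context on the three projections above, here is a hedged sketch of the attention computation that attention_block performs with them; the masking step (a large negative bias on invalid context positions before the temperature-scaled softmax) is an assumption, since that part of the function lies outside this hunk.

import tensorflow as tf

def masked_temperature_attention(queries, keys, values, valid_mask, temperature):
  # Similarity scores: [batch_size, input_size, context_size].
  weights = tf.matmul(queries, keys, transpose_b=True)
  # Suppress padded context positions (assumed masking scheme).
  mask = tf.cast(valid_mask, weights.dtype)[:, tf.newaxis, :]
  weights = weights * mask - (1.0 - mask) * 1e9
  # weights = exp(weights / temperature) / sum(exp(weights / temperature))
  weights = tf.nn.softmax(weights / temperature, axis=-1)
  # Weighted sum over the projected context features.
  return tf.matmul(weights, values)

queries = tf.random.normal([2, 3, 16])
keys = tf.random.normal([2, 5, 16])
values = tf.random.normal([2, 5, 16])
valid_mask = tf.constant([[True, True, True, False, False],
                          [True, True, True, True, True]])
print(masked_temperature_attention(queries, keys, values, valid_mask, 0.2).shape)  # (2, 3, 16)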
@@ -193,13 +188,13 @@ def attention_block(input_features, context_features, bottleneck_dimension,
features = tf.matmul(weights, values)
output_features = project_features(
-features, output_dimension, is_training, node=attention_projections["feature"], normalize=False)
+features, output_dimension, is_training, freeze_batchnorm, node=attention_projections["feature"], normalize=False)
return output_features
def compute_box_context_attention(box_features, context_features,
valid_context_size, bottleneck_dimension,
-attention_temperature, is_training, attention_projections):
+attention_temperature, is_training, freeze_batchnorm, attention_projections):
"""Computes the attention feature from the context given a batch of box.
Args:
@@ -215,6 +210,8 @@ def compute_box_context_attention(box_features, context_features,
softmax for weights calculation. The formula for calculation as follows:
weights = exp(weights / temperature) / sum(exp(weights / temperature))
is_training: A boolean Tensor (affecting batch normalization).
+freeze_batchnorm: A boolean indicating whether to freeze Batch Normalization weights.
+attention_projections: Contains a dictionary of the batch norm and projection functions.
Returns:
A float Tensor of shape [batch_size, max_num_proposals, 1, 1, channels].
@@ -230,7 +227,7 @@ def compute_box_context_attention(box_features, context_features,
output_features = attention_block(box_features, context_features,
bottleneck_dimension, channels,
attention_temperature, valid_mask,
-is_training, attention_projections)
+is_training, freeze_batchnorm, attention_projections)
# Expands the dimension back to match with the original feature map.
output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
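A quick, self-contained illustration of the dimension expansion in the last line above; it matches the [batch_size, max_num_proposals, 1, 1, channels] output shape documented in the docstring and asserted in the test below.

import tensorflow as tf

# The attention block returns [batch_size, max_num_proposals, channels];
# two singleton axes are inserted so the result lines up with the original
# feature-map layout.
attention_features = tf.zeros([2, 3, 4])
expanded = attention_features[:, :, tf.newaxis, tf.newaxis, :]
print(expanded.shape)  # (2, 3, 1, 1, 4)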
@@ -80,6 +80,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
projected_features = context_rcnn_lib.project_features(
features,
projection_dimension,
+freeze_batchnorm=False,
is_training=is_training,
normalize=normalize)
@@ -102,7 +103,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
projection_layers = {"key": {}, "val": {}, "query": {}, "feature": {}}
output_features = context_rcnn_lib.attention_block(
input_features, context_features, bottleneck_dimension,
-output_dimension, attention_temperature, valid_mask, is_training, projection_layers)
+output_dimension, attention_temperature, valid_mask, is_training, False, projection_layers)
# Makes sure the shape is correct.
self.assertAllEqual(output_features.shape, [2, 3, output_dimension])
@@ -117,7 +118,8 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
projection_layers = {"key": {}, "val": {}, "query": {}, "feature": {}}
attention_features = context_rcnn_lib.compute_box_context_attention(
box_features, context_features, valid_context_size,
-bottleneck_dimension, attention_temperature, is_training, projection_layers)
+bottleneck_dimension, attention_temperature, is_training,
+False, projection_layers)
# Makes sure the shape is correct.
self.assertAllEqual(attention_features.shape, [2, 3, 1, 1, 4])
@@ -270,7 +270,8 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
context_rcnn_lib.compute_box_context_attention,
bottleneck_dimension=attention_bottleneck_dimension,
attention_temperature=attention_temperature,
-is_training=is_training)
+is_training=is_training,
+freeze_batchnorm=freeze_batchnorm)
self._attention_projections = {"key": {},
"val": {},
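The hunk above suggests the meta-architecture binds configuration-time arguments, now including freeze_batchnorm, onto compute_box_context_attention once (e.g. with functools.partial) and supplies the per-batch tensors later. The toy sketch below only demonstrates that binding pattern; the stand-in function body and the constant values are assumptions, not the real meta-arch code.

import functools

# Stand-in for context_rcnn_lib.compute_box_context_attention, kept trivial
# so the example is self-contained and runnable.
def compute_box_context_attention(box_features, context_features,
                                  valid_context_size, bottleneck_dimension,
                                  attention_temperature, is_training,
                                  freeze_batchnorm, attention_projections):
  return {"bottleneck": bottleneck_dimension, "temperature": attention_temperature,
          "is_training": is_training, "freeze_batchnorm": freeze_batchnorm}

attention_fn = functools.partial(
    compute_box_context_attention,
    bottleneck_dimension=8,
    attention_temperature=0.2,
    is_training=True,
    freeze_batchnorm=False)

# At run time only the data-dependent arguments remain to be supplied.
print(attention_fn(box_features=None, context_features=None,
                   valid_context_size=None,
                   attention_projections={"key": {}, "val": {},
                                          "query": {}, "feature": {}}))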