"magic_pdf/vscode:/vscode.git/clone" did not exist on "6b6f40f3501b60d35dc82f42a2169f50e1132ac2"
Commit 35fda973 authored by Kaushik Shivakumar

work on fixing

parent 66e8a904
@@ -50,24 +50,56 @@ class ContextProjection(tf.keras.layers.Layer):
     return self.projection(self.batch_norm(input_features, is_training))


 class AttentionBlock(tf.keras.layers.Layer):

-  def __init__(self, bottleneck_dimension, attention_temperature, freeze_batchnorm, **kwargs):
+  def __init__(self, bottleneck_dimension, attention_temperature, freeze_batchnorm, output_dimension=None, **kwargs):
     self.key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self.val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self.query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self.attention_temperature = attention_temperature
     self.freeze_batchnorm = freeze_batchnorm
     self.bottleneck_dimension = bottleneck_dimension
+    if output_dimension:
+      self.output_dimension = output_dimension
     super(AttentionBlock, self).__init__(**kwargs)

+  def set_output_dimension(self, new_output_dimension):
+    self.output_dimension = new_output_dimension
+
   def build(self, input_shapes):
-    self.feature_proj = ContextProjection(input_shapes[0][-1], self.freeze_batchnorm)
+    print(input_shapes)
+    self.feature_proj = ContextProjection(self.output_dimension, self.freeze_batchnorm)
+    #self.key_proj.build(input_shapes[0])
+    #self.val_proj.build(input_shapes[0])
+    #self.query_proj.build(input_shapes[0])
+    #self.feature_proj.build(input_shapes[0])
     pass

-  def filter_weight_value(self, weights, values, valid_mask):
+  def call(self, input_features, is_training, valid_mask):
+    input_features, context_features = input_features
+    with tf.variable_scope("AttentionBlock"):
+      queries = project_features(
+          input_features, self.bottleneck_dimension, is_training,
+          self.query_proj, normalize=True)
+      keys = project_features(
+          context_features, self.bottleneck_dimension, is_training,
+          self.key_proj, normalize=True)
+      values = project_features(
+          context_features, self.bottleneck_dimension, is_training,
+          self.val_proj, normalize=True)
+      weights = tf.matmul(queries, keys, transpose_b=True)
+      weights, values = filter_weight_value(weights, values, valid_mask)
+      weights = tf.nn.softmax(weights / self.attention_temperature)
+      features = tf.matmul(weights, values)
+      output_features = project_features(
+          features, self.output_dimension, is_training,
+          self.feature_proj, normalize=False)
+      return output_features
+

+def filter_weight_value(weights, values, valid_mask):
   """Filters weights and values based on valid_mask.

   _NEGATIVE_PADDING_VALUE will be added to invalid elements in the weights to
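The hunk above converts the attention computation into a stateful Keras layer: the key/value/query projections are created in __init__, the output projection is deferred (its width may only be known later, hence the optional output_dimension constructor argument and set_output_dimension), and call() runs scaled dot-product attention over a [input_features, context_features] pair. A minimal usage sketch of the new interface, mirroring the updated tests further down in this commit; the shapes and the import path are illustrative assumptions:

    import tensorflow as tf
    from object_detection.meta_architectures import context_rcnn_lib

    input_features = tf.ones([2, 3, 4], tf.float32)    # [batch, num_proposals, channels]
    context_features = tf.ones([2, 2, 3], tf.float32)  # [batch, context_size, channels]
    valid_mask = tf.constant([[True, True], [False, False]], tf.bool)

    block = context_rcnn_lib.AttentionBlock(
        bottleneck_dimension=10, attention_temperature=1.0, freeze_batchnorm=False)
    block.set_output_dimension(4)  # must be set before the first call triggers build()
    output_features = block([input_features, context_features], False, valid_mask)
    # output_features: [2, 3, 4]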
@@ -118,7 +150,7 @@ class AttentionBlock(tf.keras.layers.Layer):
   return weights, values
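The body of filter_weight_value is folded out of this hunk; per its docstring above, invalid attention logits get _NEGATIVE_PADDING_VALUE added so softmax sends them to roughly zero, and the corresponding values are masked out. A hedged sketch of that masking (the exact broadcasting is an assumption, not the committed code):

    # valid_mask: [batch, context_size]; weights: [batch, input_size, context_size]
    invalid = tf.cast(tf.logical_not(valid_mask), weights.dtype)
    weights += invalid[:, tf.newaxis, :] * _NEGATIVE_PADDING_VALUE
    # Zero out the values at invalid context slots as well.
    values *= tf.cast(valid_mask, values.dtype)[:, :, tf.newaxis]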
-  def run_projection(self, features, bottleneck_dimension, is_training, layer, normalize=True):
+def project_features(features, bottleneck_dimension, is_training, layer, normalize=True):
   """Projects features to another feature space.

   Args:
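run_projection likewise drops its self and becomes the module-level project_features; its body is folded out of this view. A plausible sketch consistent with the signature and the call() sites above (the reshape details are assumptions, not the committed code):

    def project_features(features, bottleneck_dimension, is_training, layer, normalize=True):
      # Flatten [batch, n, channels] to [batch * n, channels] so the dense
      # projection layer sees a rank-2 tensor, then restore the batch layout.
      batch_size, _, num_features = features.shape
      features = tf.reshape(features, [-1, num_features])
      projected_features = layer(features, is_training)
      projected_features = tf.reshape(projected_features, [batch_size, -1, bottleneck_dimension])
      if normalize:
        projected_features = tf.math.l2_normalize(projected_features, axis=-1)
      return projected_features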
@@ -149,32 +181,6 @@ class AttentionBlock(tf.keras.layers.Layer):
   return projected_features


-  def call(self, input_features, is_training, valid_mask):
-    input_features, context_features = input_features
-    with tf.variable_scope("AttentionBlock"):
-      queries = self.run_projection(
-          input_features, self.bottleneck_dimension, is_training,
-          self.query_proj, normalize=True)
-      keys = self.run_projection(
-          context_features, self.bottleneck_dimension, is_training,
-          self.key_proj, normalize=True)
-      values = self.run_projection(
-          context_features, self.bottleneck_dimension, is_training,
-          self.val_proj, normalize=True)
-      weights = tf.matmul(queries, keys, transpose_b=True)
-      weights, values = self.filter_weight_value(weights, values, valid_mask)
-      weights = tf.nn.softmax(weights / self.attention_temperature)
-      features = tf.matmul(weights, values)
-      output_features = self.run_projection(
-          features, input_features.shape[-1], is_training,
-          self.feature_proj, normalize=False)
-      return output_features
-

 def compute_valid_mask(num_valid_elements, num_elements):
   """Computes mask of valid entries within padded context feature.
@@ -222,6 +228,7 @@ def compute_box_context_attention(box_features, context_features,
   valid_mask = compute_valid_mask(valid_context_size, context_size)

   channels = box_features.shape[-1]
+  attention_block.set_output_dimension(channels)

   # Average pools over height and width dimension so that the shape of
   # box_features becomes [batch_size, max_num_proposals, channels].
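The single added line threads the box feature width into the injected AttentionBlock before Keras builds it, so the attended output is projected back to the same number of channels as the box features. A short sketch of the surrounding steps, with the five-dimensional box_features shape inferred from the comment above:

    # box_features assumed [batch_size, max_num_proposals, height, width, channels].
    channels = box_features.shape[-1]
    attention_block.set_output_dimension(channels)  # before the block's first call
    # Average-pool the spatial dimensions: [b, n, h, w, c] -> [b, n, c].
    box_features = tf.reduce_mean(box_features, [2, 3])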
@@ -80,9 +80,9 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     projected_features = context_rcnn_lib.project_features(
         features,
         projection_dimension,
-        is_training=is_training,
-        normalize=normalize,
-        node=context_rcnn_lib.ContextProjection(projection_dimension, False))
+        is_training,
+        context_rcnn_lib.ContextProjection(projection_dimension, False),
+        normalize=normalize)

     # Makes sure the shape is correct.
     self.assertAllEqual(projected_features.shape, [2, 3, projection_dimension])
@@ -100,15 +100,15 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     context_features = tf.ones([2, 2, 3], tf.float32)
     valid_mask = tf.constant([[True, True], [False, False]], tf.bool)
     is_training = False
-    projection_layers = {context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False), context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
-                         context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False)}
+    #projection_layers = {context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False), context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
+    #                     context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False)}
     #Add in the feature layer because this is further down the pipeline and it isn't automatically injected.
-    projection_layers['feature'] = context_rcnn_lib.ContextProjection(output_dimension, False)
+    #projection_layers['feature'] = context_rcnn_lib.ContextProjection(output_dimension, False)
-    output_features = context_rcnn_lib.attention_block(
-        input_features, context_features, bottleneck_dimension,
-        output_dimension, attention_temperature, valid_mask, is_training, projection_layers)
+    attention_block = context_rcnn_lib.AttentionBlock(bottleneck_dimension, attention_temperature, False)
+    attention_block.set_output_dimension(output_dimension)
+    output_features = attention_block([input_features, context_features], is_training, valid_mask)

     # Makes sure the shape is correct.
     self.assertAllEqual(output_features.shape, [2, 3, output_dimension])
@@ -120,12 +120,11 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     valid_context_size = tf.constant((2, 3), tf.int32)
     bottleneck_dimension = 10
     attention_temperature = 1
-    projection_layers = {context_rcnn_lib.KEY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False), context_rcnn_lib.VALUE_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False),
-                         context_rcnn_lib.QUERY_NAME: context_rcnn_lib.ContextProjection(bottleneck_dimension, False)}
     attention_features = context_rcnn_lib.compute_box_context_attention(
         box_features, context_features, valid_context_size,
         bottleneck_dimension, attention_temperature, is_training,
-        False, projection_layers)
+        False, context_rcnn_lib.AttentionBlock(bottleneck_dimension, attention_temperature, False))

     # Makes sure the shape is correct.
     self.assertAllEqual(attention_features.shape, [2, 3, 1, 1, 4])
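The expected shape [2, 3, 1, 1, 4] suggests compute_box_context_attention re-expands the attended features with two singleton spatial dimensions so they broadcast against the unpooled box features; a plausible final step, inferred from this assertion rather than shown in the diff:

    # [batch, max_num_proposals, channels] -> [batch, max_num_proposals, 1, 1, channels]
    output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]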