"vscode:/vscode.git/clone" did not exist on "db77dfe01f7ee7bff8799a103406828974136f22"
Commit fb41b048 authored by Kaushik Shivakumar

make minor changes

parent 6b18b06b
@@ -71,11 +71,11 @@ class AttentionBlock(tf.keras.layers.Layer):
   def build(self, input_shapes):
     pass
-  def call(self, input_features, context_features, valid_context_size):
+  def call(self, box_features, context_features, valid_context_size):
     """Handles a call by performing attention."""
     _, context_size, _ = context_features.shape
     valid_mask = compute_valid_mask(valid_context_size, context_size)
-    channels = input_features.shape[-1]
+    channels = box_features.shape[-1]
     # Build the feature projection layer
     if not self._output_dimension:
@@ -86,11 +86,11 @@ class AttentionBlock(tf.keras.layers.Layer):
     # Average pools over height and width dimension so that the shape of
     # box_features becomes [batch_size, max_num_proposals, channels].
-    input_features = tf.reduce_mean(input_features, [2, 3])
+    box_features = tf.reduce_mean(box_features, [2, 3])
     with tf.name_scope("AttentionBlock"):
       queries = project_features(
-          input_features, self._bottleneck_dimension, self._is_training,
+          box_features, self._bottleneck_dimension, self._is_training,
           self._query_proj, normalize=True)
       keys = project_features(
           context_features, self._bottleneck_dimension, self._is_training,
...
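Note on the pattern in these hunks: `call` runs dot-product attention between per-box features (the queries, after average-pooling over height and width) and a padded set of context features (the keys), masking out padded context rows via `compute_valid_mask`. The sketch below is a minimal, self-contained approximation of that flow in plain TensorFlow; the Dense projections are stand-ins for the repository's `project_features` helper and `tf.sequence_mask` stands in for `compute_valid_mask`, whose internals are not shown in this diff, and the `bottleneck_dim` default, scaling, and shape conventions are assumptions rather than the repository's exact behavior.

import tensorflow as tf


def attention_sketch(box_features, context_features, valid_context_size,
                     bottleneck_dim=128, output_dim=None):
  """Illustrative dot-product attention over padded context features.

  Assumed shapes:
    box_features:       [batch, max_num_proposals, height, width, channels]
    context_features:   [batch, context_size, context_channels]
    valid_context_size: [batch] count of non-padded context rows per example.
  """
  # Average-pool over height and width (axes 2 and 3), as in the diff, so
  # box_features becomes [batch, max_num_proposals, channels].
  box_features = tf.reduce_mean(box_features, [2, 3])

  # Stand-ins for project_features: dense bottleneck projections.
  queries = tf.keras.layers.Dense(bottleneck_dim)(box_features)
  keys = tf.keras.layers.Dense(bottleneck_dim)(context_features)
  values = tf.keras.layers.Dense(bottleneck_dim)(context_features)

  # Stand-in for compute_valid_mask: True for real context rows, False for
  # padding; shape [batch, context_size].
  context_size = tf.shape(context_features)[1]
  valid_mask = tf.sequence_mask(valid_context_size, maxlen=context_size)

  # Scaled dot-product scores, shape [batch, max_num_proposals, context_size].
  scores = tf.matmul(queries, keys, transpose_b=True)
  scores /= tf.sqrt(tf.cast(bottleneck_dim, scores.dtype))

  # Push padded positions toward -inf before the softmax so they get ~0 weight.
  scores = tf.where(valid_mask[:, tf.newaxis, :], scores,
                    tf.cast(-1e9, scores.dtype))
  weights = tf.nn.softmax(scores, axis=-1)

  # Attention output: weighted sum of context values, optionally re-projected.
  attended = tf.matmul(weights, values)
  if output_dim is not None:
    attended = tf.keras.layers.Dense(output_dim)(attended)
  return attended

For example, box features of shape [2, 8, 7, 7, 64] attended against 20 context rows of which only some are valid would return a [2, 8, bottleneck_dim] tensor (or [2, 8, output_dim] if a final projection is requested).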