"vscode:/vscode.git/clone" did not exist on "6b3a644be8afa8a0cc4435b5494e994db3175e23"
Commit b44de920 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

finalize tf2 context rcnn

parent f72dfa8d
...@@ -61,11 +61,13 @@ class AttentionBlock(tf.keras.layers.Layer): ...@@ -61,11 +61,13 @@ class AttentionBlock(tf.keras.layers.Layer):
def build(self, input_shapes): def build(self, input_shapes):
if not self._feature_proj: if not self._feature_proj:
self._output_dimension = input_shapes[-1] self._output_dimension = input_shapes[0][-1]
self._feature_proj = ContextProjection(self._output_dimension) self._feature_proj = ContextProjection(self._output_dimension)
def call(self, box_features, context_features, valid_context_size): def call(self, box_and_context_features, valid_context_size):
"""Handles a call by performing attention.""" """Handles a call by performing attention."""
box_features, context_features = box_and_context_features
_, context_size, _ = context_features.shape _, context_size, _ = context_features.shape
valid_mask = compute_valid_mask(valid_context_size, context_size) valid_mask = compute_valid_mask(valid_context_size, context_size)
......
...@@ -90,9 +90,9 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase): ...@@ -90,9 +90,9 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
@parameterized.parameters( @parameterized.parameters(
(2, 10, 1), (2, 10, 1),
(3, 10, 2), (3, 10, 2),
(4, 20, 3), (4, None, 3),
(5, 20, 4), (5, 20, 4),
(7, 20, 5), (7, None, 5),
) )
def test_attention_block(self, bottleneck_dimension, output_dimension, def test_attention_block(self, bottleneck_dimension, output_dimension,
attention_temperature): attention_temperature):
...@@ -106,10 +106,10 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase): ...@@ -106,10 +106,10 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
minval=0, minval=0,
maxval=10, maxval=10,
dtype=tf.int32) dtype=tf.int32)
output_features = attention_block(input_features, context_features, valid_context_size) output_features = attention_block([input_features, context_features], valid_context_size)
# Makes sure the shape is correct. # Makes sure the shape is correct.
self.assertAllEqual(output_features.shape, [2, 8, 1, 1, output_dimension]) self.assertAllEqual(output_features.shape, [2, 8, 1, 1, (output_dimension or 3)])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment