Commit 12110e64 authored by Kaushik Shivakumar

finalize context rcnn tf2

parent b44de920
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Library functions for ContextRCNN."""
+"""Library functions for Context R-CNN."""
 import tensorflow as tf
 from object_detection.core import freezable_batch_norm
@@ -61,12 +61,11 @@ class AttentionBlock(tf.keras.layers.Layer):
   def build(self, input_shapes):
     if not self._feature_proj:
-      self._output_dimension = input_shapes[0][-1]
+      self._output_dimension = input_shapes[-1]
       self._feature_proj = ContextProjection(self._output_dimension)
-  def call(self, box_and_context_features, valid_context_size):
+  def call(self, box_features, context_features, valid_context_size):
     """Handles a call by performing attention."""
-    box_features, context_features = box_and_context_features
     _, context_size, _ = context_features.shape
     valid_mask = compute_valid_mask(valid_context_size, context_size)
...
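Note: the valid mask computed on the last line above gates attention so that padded context rows are ignored. A plausible standalone sketch of compute_valid_mask, inferred from its name and call site here rather than from the commit's actual body:

import tensorflow as tf

def compute_valid_mask(valid_context_size, context_size):
  """Sketch (not the commit's implementation): masks padded context rows.

  Args:
    valid_context_size: [batch] int32, number of real context rows per example.
    context_size: Python int, padded length of the context features.

  Returns:
    [batch, context_size] boolean mask, True for real rows.
  """
  return tf.sequence_mask(valid_context_size, maxlen=context_size)

print(compute_valid_mask(tf.constant([2, 0]), 3))
# [[ True  True False]
#  [False False False]]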
@@ -106,7 +106,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
         minval=0,
         maxval=10,
         dtype=tf.int32)
-    output_features = attention_block([input_features, context_features], valid_context_size)
+    output_features = attention_block(input_features, context_features, valid_context_size)
     # Makes sure the shape is correct.
     self.assertAllEqual(output_features.shape, [2, 8, 1, 1, (output_dimension or 3)])
...
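Note: the test change mirrors why build() could switch from input_shapes[0][-1] to input_shapes[-1]. When a Keras layer's call takes several positional tensors, build receives only the first argument's shape; with the old single-list call style, it received a list of shapes. A minimal illustrative layer (names here are hypothetical, not from the commit):

import tensorflow as tf

class TwoInputDemo(tf.keras.layers.Layer):
  """Toy layer showing what build() sees for multi-argument calls."""

  def build(self, input_shapes):
    # With call(self, a, b), Keras passes only a's shape here, so the
    # channel axis is input_shapes[-1]. With the old call(self, [a, b])
    # style, input_shapes was a list of shapes and the same axis had to
    # be read as input_shapes[0][-1].
    self.channels = input_shapes[-1]
    super().build(input_shapes)

  def call(self, a, b):
    return a + b

layer = TwoInputDemo()
layer(tf.ones([2, 4]), tf.ones([2, 4]))
print(layer.channels)  # 4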