Commit 12110e64 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

finalize context rcnn tf2

parent b44de920
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library functions for ContextRCNN."""
"""Library functions for Context R-CNN."""
import tensorflow as tf
from object_detection.core import freezable_batch_norm
......@@ -61,12 +61,11 @@ class AttentionBlock(tf.keras.layers.Layer):
def build(self, input_shapes):
if not self._feature_proj:
self._output_dimension = input_shapes[0][-1]
self._output_dimension = input_shapes[-1]
self._feature_proj = ContextProjection(self._output_dimension)
def call(self, box_and_context_features, valid_context_size):
def call(self, box_features, context_features, valid_context_size):
"""Handles a call by performing attention."""
box_features, context_features = box_and_context_features
_, context_size, _ = context_features.shape
valid_mask = compute_valid_mask(valid_context_size, context_size)
......
......@@ -106,7 +106,7 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
minval=0,
maxval=10,
dtype=tf.int32)
output_features = attention_block([input_features, context_features], valid_context_size)
output_features = attention_block(input_features, context_features, valid_context_size)
# Makes sure the shape is correct.
self.assertAllEqual(output_features.shape, [2, 8, 1, 1, (output_dimension or 3)])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment