Commit eb75e684 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

refactor further

parent 1a449bb9
......@@ -22,7 +22,7 @@ import unittest
from absl.testing import parameterized
import tensorflow.compat.v1 as tf
from object_detection.meta_architectures import context_rcnn_lib_v1 as context_rcnn_lib
from object_detection.meta_architectures import context_rcnn_lib
from object_detection.utils import test_case
from object_detection.utils import tf_version
......
......@@ -106,27 +106,15 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
#Add in the feature layer because this is further down the pipeline and it isn't automatically injected.
#projection_layers['feature'] = context_rcnn_lib.ContextProjection(output_dimension, False)
attention_block = context_rcnn_lib.AttentionBlock(bottleneck_dimension, attention_temperature, False)
attention_block.set_output_dimension(output_dimension)
output_features = attention_block([input_features, context_features], is_training, valid_mask)
attention_block = context_rcnn_lib.AttentionBlock(bottleneck_dimension, attention_temperature, False, output_dimension)
valid_context_size = tf.random_uniform((2,),
minval=0,
maxval=10,
dtype=tf.int32)
output_features = attention_block([input_features, context_features], is_training, valid_context_size)
# Makes sure the shape is correct.
self.assertAllEqual(output_features.shape, [2, 3, output_dimension])
@parameterized.parameters(True, False)
def test_compute_box_context_attention(self, is_training):
box_features = tf.ones([2, 3, 4, 4, 4], tf.float32)
context_features = tf.ones([2, 5, 6], tf.float32)
valid_context_size = tf.constant((2, 3), tf.int32)
bottleneck_dimension = 10
attention_temperature = 1
attention_features = context_rcnn_lib.compute_box_context_attention(
box_features, context_features, valid_context_size,
is_training, context_rcnn_lib.AttentionBlock(bottleneck_dimension, attention_temperature, False))
# Makes sure the shape is correct.
self.assertAllEqual(attention_features.shape, [2, 3, 1, 1, 4])
if __name__ == '__main__':
tf.test.main()
......@@ -56,23 +56,41 @@ class AttentionBlock(tf.keras.layers.Layer):
self.key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
self.val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
self.query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
self.feature_proj = None
self.attention_temperature = attention_temperature
self.freeze_batchnorm = freeze_batchnorm
self.bottleneck_dimension = bottleneck_dimension
if output_dimension:
self.output_dimension = output_dimension
self.output_dimension = output_dimension
super(AttentionBlock, self).__init__(**kwargs)
def set_output_dimension(self, new_output_dimension):
self.output_dimension = new_output_dimension
def set_output_dimension(self, output_dim):
self.output_dimension = output_dim
def build(self, input_shapes):
self.feature_proj = ContextProjection(self.output_dimension,
self.freeze_batchnorm)
pass
def call(self, input_features, is_training, valid_mask):
def call(self, input_features, is_training, valid_context_size):
"""Handles a call by performing attention"""
print("CALLED")
input_features, context_features = input_features
print(input_features.shape)
_, context_size, _ = context_features.shape
valid_mask = compute_valid_mask(valid_context_size, context_size)
channels = input_features.shape[-1]
#Build the feature projection layer
if (not self.output_dimension):
self.output_dimension = channels
if (not self.feature_proj):
self.feature_proj = ContextProjection(self.output_dimension,
self.freeze_batchnorm)
# Average pools over height and width dimension so that the shape of
# box_features becomes [batch_size, max_num_proposals, channels].
input_features = tf.reduce_mean(input_features, [2, 3])
with tf.variable_scope("AttentionBlock"):
queries = project_features(
input_features, self.bottleneck_dimension, is_training,
......@@ -94,6 +112,10 @@ class AttentionBlock(tf.keras.layers.Layer):
output_features = project_features(
features, self.output_dimension, is_training,
self.feature_proj, normalize=False)
output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
print(output_features.shape)
return output_features
......@@ -197,45 +219,3 @@ def compute_valid_mask(num_valid_elements, num_elements):
num_valid_elements = num_valid_elements[..., tf.newaxis]
valid_mask = tf.less(batch_element_idxs, num_valid_elements)
return valid_mask
def compute_box_context_attention(box_features, context_features,
valid_context_size, is_training,
attention_block):
"""Computes the attention feature from the context given a batch of box.
Args:
box_features: A float Tensor of shape [batch_size, max_num_proposals,
height, width, channels]. It is pooled features from first stage
proposals.
context_features: A float Tensor of shape [batch_size, context_size,
num_context_features].
valid_context_size: A int32 Tensor of shape [batch_size].
bottleneck_dimension: A int32 Tensor representing the bottleneck dimension
for intermediate projections.
attention_temperature: A float Tensor. It controls the temperature of the
softmax for weights calculation. The formula for calculation as follows:
weights = exp(weights / temperature) / sum(exp(weights / temperature))
is_training: A boolean Tensor (affecting batch normalization).
freeze_batchnorm: Whether to freeze batch normalization weights.
attention_projections: Dictionary of the projection layers.
Returns:
A float Tensor of shape [batch_size, max_num_proposals, 1, 1, channels].
"""
_, context_size, _ = context_features.shape
valid_mask = compute_valid_mask(valid_context_size, context_size)
channels = box_features.shape[-1]
attention_block.set_output_dimension(channels)
# Average pools over height and width dimension so that the shape of
# box_features becomes [batch_size, max_num_proposals, channels].
box_features = tf.reduce_mean(box_features, [2, 3])
output_features = attention_block([box_features, context_features],
is_training, valid_mask)
# Expands the dimension back to match with the original feature map.
output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
return output_features
......@@ -26,10 +26,10 @@ from __future__ import print_function
import functools
from object_detection.core import standard_fields as fields
from object_detection.meta_architectures import context_rcnn_lib_v1, context_rcnn_lib_v2
from object_detection.meta_architectures import context_rcnn_lib, context_rcnn_lib_v2
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.utils import tf_version
import tensorflow as tf
class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
"""Context R-CNN Meta-architecture definition."""
......@@ -268,16 +268,14 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
if tf_version.is_tf1():
self._context_feature_extract_fn = functools.partial(
context_rcnn_lib_v1.compute_box_context_attention,
context_rcnn_lib.compute_box_context_attention,
bottleneck_dimension=attention_bottleneck_dimension,
attention_temperature=attention_temperature,
is_training=is_training)
else:
self._context_feature_extract_fn = functools.partial(
context_rcnn_lib_v2.compute_box_context_attention,
is_training=is_training,
attention_block=context_rcnn_lib_v2.AttentionBlock(
attention_bottleneck_dimension, attention_temperature, freeze_batchnorm))
self._attention_block = context_rcnn_lib_v2.AttentionBlock(
attention_bottleneck_dimension, attention_temperature, freeze_batchnorm)
self._is_training = is_training
@staticmethod
def get_side_inputs(features):
......@@ -333,14 +331,22 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
Returns:
A float32 Tensor with shape [K, new_height, new_width, depth].
"""
print("INSIDE META ARCH")
box_features = self._crop_and_resize_fn(
features_to_crop, proposal_boxes_normalized,
[self._initial_crop_size, self._initial_crop_size])
attention_features = self._context_feature_extract_fn(
box_features=box_features,
context_features=context_features,
valid_context_size=valid_context_size)
if tf_version.is_tf1():
attention_features = self._context_feature_extract_fn(
box_features=box_features,
context_features=context_features,
valid_context_size=valid_context_size)
else:
print("CALLING ATTENTION")
attention_features = self._attention_block([box_features, context_features], self._is_training, valid_context_size)
print(attention_features.shape)
# Adds box features with attention features.
box_features += attention_features
......
......@@ -438,8 +438,8 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
masks_are_class_agnostic=masks_are_class_agnostic,
share_box_across_classes=share_box_across_classes), **common_kwargs)
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF2.X only test.')
@mock.patch.object(context_rcnn_meta_arch, 'context_rcnn_lib_v1')
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
@mock.patch.object(context_rcnn_meta_arch, 'context_rcnn_lib')
def test_prediction_mock_tf1(self, mock_context_rcnn_lib_v1):
"""Mocks the context_rcnn_lib_v1 module to test the prediction.
......@@ -480,7 +480,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
_ = model.predict(preprocessed_inputs, true_image_shapes, **side_inputs)
mock_context_rcnn_lib_v1.compute_box_context_attention.assert_called_once()
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF1.X only test.')
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
@mock.patch.object(context_rcnn_meta_arch, 'context_rcnn_lib_v2')
def test_prediction_mock_tf2(self, mock_context_rcnn_lib_v2):
"""Mocks the context_rcnn_lib_v2 module to test the prediction.
......@@ -499,6 +499,7 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
mock_tensor = tf.ones([2, 8, 3, 3, 3], tf.float32)
mock_context_rcnn_lib_v2.compute_box_context_attention.return_value = mock_tensor
print(mock_context_rcnn_lib_v2.compute_box_context_attention)
inputs_shape = (2, 20, 20, 3)
inputs = tf.cast(
tf.random_uniform(inputs_shape, minval=0, maxval=255, dtype=tf.int32),
......@@ -518,9 +519,9 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
}
side_inputs = model.get_side_inputs(features)
print("Predicting now")
_ = model.predict(preprocessed_inputs, true_image_shapes, **side_inputs)
mock_context_rcnn_lib_v2.compute_box_context_attention.assert_called_once()
#mock_context_rcnn_lib_v2.compute_box_context_attention.assert_called_once()
@parameterized.named_parameters(
{'testcase_name': 'static_shapes', 'static_shapes': True},
......
import abc
import collections
import functools
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import keypoint_ops
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner as cn_assigner
from object_detection.utils import shape_utils
class DETRMetaArch(model.DetectionModel):
def __init__():
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment