Commit e3f88e11 authored by Kaushik Shivakumar

make fixes

parent e9f620af
@@ -16,18 +16,15 @@
 """Library functions for ContextRCNN."""
 import tensorflow as tf
+from object_detection.core import freezable_batch_norm
 # The negative value used in padding the invalid weights.
 _NEGATIVE_PADDING_VALUE = -100000
 KEY_NAME = 'key'
 VALUE_NAME = 'val'
 QUERY_NAME = 'query'
 FEATURE_NAME = 'feature'
 class ContextProjection(tf.keras.layers.Layer):
   """Custom layer to do batch normalization and projection."""
   def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
-    self.batch_norm = tf.keras.layers.BatchNormalization(
+    self.batch_norm = freezable_batch_norm.FreezableBatchNorm(
         epsilon=0.001,
         center=True,
         scale=True,
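This hunk swaps the stock tf.keras.layers.BatchNormalization for the project's FreezableBatchNorm, which is what lets the freeze_batchnorm flag actually pin the layer to inference mode during fine-tuning. A minimal sketch of that pattern, assuming (not quoting) the library's implementation:

import tensorflow as tf

class FreezableBatchNorm(tf.keras.layers.BatchNormalization):
  """Sketch of the freezable batch-norm pattern (hypothetical re-creation).

  Passing training=False at construction pins the layer to inference
  mode (moving-average statistics), no matter what `training` flag the
  surrounding model passes at call time.
  """

  def __init__(self, training=None, **kwargs):
    super(FreezableBatchNorm, self).__init__(**kwargs)
    self._frozen_training = training

  def call(self, inputs, training=None):
    # When frozen, ignore the call-time flag and run in inference mode.
    if self._frozen_training is False:
      training = False
    return super(FreezableBatchNorm, self).call(inputs, training=training)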
@@ -49,7 +46,7 @@ class AttentionBlock(tf.keras.layers.Layer):
   """Custom layer to perform all attention."""
   def __init__(self, bottleneck_dimension, attention_temperature,
                freeze_batchnorm, output_dimension=None,
-               is_training=False, **kwargs):
+               is_training=False, name='AttentionBlock', **kwargs):
     self._key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self._val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self._query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
@@ -57,15 +54,18 @@ class AttentionBlock(tf.keras.layers.Layer):
     self._attention_temperature = attention_temperature
     self._freeze_batchnorm = freeze_batchnorm
     self._bottleneck_dimension = bottleneck_dimension
-    self._output_dimension = output_dimension
     self._is_training = is_training
-    super(AttentionBlock, self).__init__(**kwargs)
-  def set_output_dimension(self, output_dim):
-    self._output_dimension = output_dim
+    self._output_dimension = output_dimension
+    if self._output_dimension:
+      self._feature_proj = ContextProjection(self._output_dimension,
+                                             self._freeze_batchnorm)
+    super(AttentionBlock, self).__init__(name=name, **kwargs)
   def build(self, input_shapes):
-    pass
+    if not self._feature_proj:
+      self._output_dimension = input_shapes[-1]
+      self._feature_proj = ContextProjection(self._output_dimension,
+                                             self._freeze_batchnorm)
   def call(self, box_features, context_features, valid_context_size):
     """Handles a call by performing attention."""
@@ -73,13 +73,6 @@
     valid_mask = compute_valid_mask(valid_context_size, context_size)
     channels = box_features.shape[-1]
-    #Build the feature projection layer
-    if not self._output_dimension:
-      self._output_dimension = channels
-    if not self._feature_proj:
-      self._feature_proj = ContextProjection(self._output_dimension,
-                                             self._freeze_batchnorm)
     # Average pools over height and width dimension so that the shape of
     # box_features becomes [batch_size, max_num_proposals, channels].
     box_features = tf.reduce_mean(box_features, [2, 3])
......
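For orientation, the call() that this last hunk slims down still implements the attention that _NEGATIVE_PADDING_VALUE at the top of the file supports: padded context slots receive a large negative logit so the softmax effectively zeroes them out. A self-contained, hedged sketch of that masked attention; the function name and exact shapes are illustrative, not this library's API:

import tensorflow as tf

_NEGATIVE_PADDING_VALUE = -100000

def masked_context_attention(queries, keys, values, valid_mask, temperature):
  """Dot-product attention over context, with invalid slots masked out.

  queries: [batch, num_boxes, depth]; keys, values: [batch, context_size,
  depth]; valid_mask: [batch, context_size] booleans marking real
  (non-padded) context entries.
  """
  weights = tf.matmul(queries, keys, transpose_b=True)  # [b, boxes, ctx]
  # Push padded slots' logits far negative so softmax sends them to ~0.
  mask = tf.cast(valid_mask, weights.dtype)[:, tf.newaxis, :]
  weights += (1.0 - mask) * _NEGATIVE_PADDING_VALUE
  weights = tf.nn.softmax(weights / temperature)
  return tf.matmul(weights, values)  # [b, boxes, depth]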
@@ -28,8 +28,7 @@ from object_detection.utils import tf_version
 _NEGATIVE_PADDING_VALUE = -100000
-class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
-                         tf.test.TestCase):
+class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase):
   """Tests for the functions in context_rcnn_lib."""
   def test_compute_valid_mask(self):
......
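The test-file hunk drops tf.test.TestCase from the base-class list. Assuming object_detection.utils.test_case.TestCase already derives from tf.test.TestCase (as the utility modules in this repo do), naming both bases was redundant; a toy illustration:

import tensorflow as tf

class TestCase(tf.test.TestCase):  # stand-in for test_case.TestCase
  pass

# Redundant: TestCase already brings in tf.test.TestCase.
class RedundantTest(TestCase, tf.test.TestCase):
  pass

# Equivalent and cleaner, mirroring the change above.
class CleanTest(TestCase):
  def test_identity(self):
    self.assertAllEqual([1], [1])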