Commit 06e15650 authored by Kaushik Shivakumar

context rcnn meta arch

parent cca97db0
@@ -52,7 +52,8 @@ class ContextProjection(tf.keras.layers.Layer):
 class AttentionBlock(tf.keras.layers.Layer):
   """Custom layer to perform all attention."""
   def __init__(self, bottleneck_dimension, attention_temperature,
-               freeze_batchnorm, output_dimension=None, is_training=False, **kwargs):
+               freeze_batchnorm, output_dimension=None,
+               is_training=False, **kwargs):
     self._key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self._val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self._query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
@@ -77,16 +78,16 @@ class AttentionBlock(tf.keras.layers.Layer):
     channels = input_features.shape[-1]
     #Build the feature projection layer
-    if (not self._output_dimension):
+    if not self._output_dimension:
       self._output_dimension = channels
-    if (not self._feature_proj):
+    if not self._feature_proj:
       self._feature_proj = ContextProjection(self._output_dimension,
-                                            self._freeze_batchnorm)
+                                             self._freeze_batchnorm)
     # Average pools over height and width dimension so that the shape of
     # box_features becomes [batch_size, max_num_proposals, channels].
     input_features = tf.reduce_mean(input_features, [2, 3])
     with tf.variable_scope("AttentionBlock"):
       queries = project_features(
           input_features, self._bottleneck_dimension, self._is_training,
...
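For readers skimming this hunk: the tf.reduce_mean call collapses each proposal's spatial feature map into a single vector per box before attention is applied. A minimal sketch of that pooling step, with hypothetical shapes that are not from this commit:

import tensorflow as tf

# Hypothetical input: [batch_size, max_num_proposals, height, width, channels].
input_features = tf.random.normal([2, 8, 7, 7, 1024])
# Mean over the height (axis 2) and width (axis 3) dimensions, as in the hunk.
pooled = tf.reduce_mean(input_features, [2, 3])
print(pooled.shape)  # (2, 8, 1024)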
@@ -337,7 +337,8 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         [self._initial_crop_size, self._initial_crop_size])
     attention_features = self._context_feature_extract_fn(
-        box_features, context_features,
+        box_features,
+        context_features,
         valid_context_size=valid_context_size)
     # Adds box features with attention features.
...
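The reformatted call hands per-box features and per-image context features to the attention machinery from the earlier hunks. A hedged sketch of that query/key/value pattern, using tf.keras.layers.Dense as a stand-in for ContextProjection; every name and shape below is an assumption for illustration, not this repository's implementation:

import tensorflow as tf

def attention_sketch(box_features, context_features,
                     bottleneck_dimension=16, attention_temperature=0.01):
  """box_features: [b, n, d_box]; context_features: [b, m, d_ctx]."""
  # Project boxes to queries and context to keys/values (stand-in projections).
  queries = tf.keras.layers.Dense(bottleneck_dimension)(box_features)
  keys = tf.keras.layers.Dense(bottleneck_dimension)(context_features)
  values = tf.keras.layers.Dense(bottleneck_dimension)(context_features)
  # Temperature-scaled softmax attention of each box over the context set.
  weights = tf.nn.softmax(
      tf.matmul(queries, keys, transpose_b=True) / attention_temperature)
  attended = tf.matmul(weights, values)  # [b, n, bottleneck_dimension]
  # Project back to the box feature dimension so, as the diff's comment says,
  # the attention features can be added to the box features.
  return tf.keras.layers.Dense(box_features.shape[-1])(attended)

# Example: 8 proposals attending over 20 context features.
boxes = tf.random.normal([2, 8, 1024])
context = tf.random.normal([2, 20, 512])
print(attention_sketch(boxes, context).shape)  # (2, 8, 1024)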