Commit 06e15650 authored by Kaushik Shivakumar

context rcnn meta arch

parent cca97db0
@@ -52,7 +52,8 @@ class ContextProjection(tf.keras.layers.Layer):
 class AttentionBlock(tf.keras.layers.Layer):
   """Custom layer to perform all attention."""
   def __init__(self, bottleneck_dimension, attention_temperature,
-               freeze_batchnorm, output_dimension=None, is_training=False, **kwargs):
+               freeze_batchnorm, output_dimension=None,
+               is_training=False, **kwargs):
     self._key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self._val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self._query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
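For reference, a minimal instantiation sketch following the constructor signature above; the argument values are illustrative placeholders, not taken from this commit:

attention_block = AttentionBlock(
    bottleneck_dimension=16,       # placeholder value
    attention_temperature=0.01,    # placeholder value
    freeze_batchnorm=False,
    output_dimension=None,         # None lets build() fall back to the input channel count
    is_training=True)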
@@ -77,9 +78,9 @@ class AttentionBlock(tf.keras.layers.Layer):
     channels = input_features.shape[-1]
     #Build the feature projection layer
-    if (not self._output_dimension):
+    if not self._output_dimension:
       self._output_dimension = channels
-    if (not self._feature_proj):
+    if not self._feature_proj:
       self._feature_proj = ContextProjection(self._output_dimension,
                                              self._freeze_batchnorm)
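The build() logic above relies on Keras's deferred construction: build() runs on the first call with a concrete input shape, so the projection can default its output size to the input channel count. A self-contained sketch of that pattern with a generic layer (an assumed example, not code from this repository):

import tensorflow as tf

class LazyProjection(tf.keras.layers.Layer):
  """Creates its projection only once the input channel count is known."""

  def __init__(self, output_dimension=None, **kwargs):
    super().__init__(**kwargs)
    self._output_dimension = output_dimension
    self._proj = None

  def build(self, input_shape):
    channels = input_shape[-1]
    if not self._output_dimension:
      self._output_dimension = channels
    if not self._proj:
      self._proj = tf.keras.layers.Dense(self._output_dimension)

  def call(self, inputs):
    return self._proj(inputs)

layer = LazyProjection()              # output_dimension defaults to input channels
out = layer(tf.zeros([2, 7, 7, 32]))  # build() runs here; out.shape[-1] == 32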
@@ -337,7 +337,8 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         [self._initial_crop_size, self._initial_crop_size])
     attention_features = self._context_feature_extract_fn(
-        box_features, context_features,
+        box_features,
+        context_features,
         valid_context_size=valid_context_size)
     # Adds box features with attention features.
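Based only on the trailing comment in this hunk, the step that follows presumably combines the two tensors elementwise. A hedged one-line sketch of that addition (inferred, not verified against the rest of the file):

# Hypothetical continuation, inferred from the "Adds box features with attention features." comment:
box_features += attention_features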