Commit 65bd772d authored by Kaushik Shivakumar

style

parent 8e3eb8d4
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 import tensorflow.compat.v1 as tf
-import tf_slim as slim

 # The negative value used in padding the invalid weights.
 _NEGATIVE_PADDING_VALUE = -100000
@@ -30,8 +29,10 @@ QUERY_NAME = 'query'
 FEATURE_NAME = 'feature'

 class ContextProjection(tf.keras.layers.Layer):
+  """Custom layer to do batch normalization and projection."""
   def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
-    self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=0.001,
+    self.batch_norm = tf.keras.layers.BatchNormalization(
+        epsilon=0.001,
         center=True,
         scale=True,
         momentum=0.97,
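
For reference, the hunk cuts off the constructor call after `momentum`; a sketch of the full configuration, where the trailing `trainable` argument is an assumption inferred from the `freeze_batchnorm` flag rather than shown in this diff:

```python
import tensorflow as tf

freeze_batchnorm = False  # illustrative; passed in via ContextProjection.__init__

# Assumption: freezing batch norm maps to trainable=False, which in TF2 also
# stops the moving-statistics updates and runs the layer in inference mode.
batch_norm = tf.keras.layers.BatchNormalization(
    epsilon=0.001,
    center=True,
    scale=True,
    momentum=0.97,
    trainable=not freeze_batchnorm)
```
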
@@ -39,14 +40,14 @@ class ContextProjection(tf.keras.layers.Layer):
     self.projection = tf.keras.layers.Dense(units=projection_dimension,
                                             activation=tf.nn.relu6,
                                             use_bias=True)
-    super(ContextProjection,self).__init__(**kwargs)
+    super(ContextProjection, self).__init__(**kwargs)

   def build(self, input_shape):
     self.batch_norm.build(input_shape)
     self.projection.build(input_shape)

-  def call(self, input):
-    return self.projection(self.batch_norm(input))
+  def call(self, input_features, is_training=False):
+    return self.projection(self.batch_norm(input_features, is_training))


 def filter_weight_value(weights, values, valid_mask):
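
Stitching the two hunks above together, the whole layer reads as the self-contained sketch below (the `trainable` argument is again my assumption), followed by a usage line:

```python
import tensorflow as tf


class ContextProjection(tf.keras.layers.Layer):
  """Custom layer to do batch normalization and projection."""

  def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
    self.batch_norm = tf.keras.layers.BatchNormalization(
        epsilon=0.001,
        center=True,
        scale=True,
        momentum=0.97,
        trainable=not freeze_batchnorm)  # assumed; cut off by the hunk above
    self.projection = tf.keras.layers.Dense(units=projection_dimension,
                                            activation=tf.nn.relu6,
                                            use_bias=True)
    super(ContextProjection, self).__init__(**kwargs)

  def build(self, input_shape):
    self.batch_norm.build(input_shape)
    self.projection.build(input_shape)

  def call(self, input_features, is_training=False):
    # Normalize first, then project; is_training switches batch norm between
    # batch statistics (training) and moving averages (inference).
    return self.projection(self.batch_norm(input_features, is_training))


# Project 2048-d pooled features down to a 128-d bottleneck.
layer = ContextProjection(projection_dimension=128, freeze_batchnorm=False)
print(layer(tf.zeros([4, 2048]), True).shape)  # (4, 128)
```
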
@@ -119,7 +120,8 @@ def compute_valid_mask(num_valid_elements, num_elements):
   valid_mask = tf.less(batch_element_idxs, num_valid_elements)
   return valid_mask

-def project_features(features, projection_dimension, is_training, node, normalize=True):
+def project_features(features, projection_dimension,
+                     is_training, node, normalize=True):
   """Projects features to another feature space.

   Args:
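
The `tf.less` line above is the tail of `compute_valid_mask`; a hedged reconstruction of the whole helper, where everything above the shown line is my assumption from the variable names:

```python
import tensorflow as tf


def compute_valid_mask(num_valid_elements, num_elements):
  """Returns a [batch, num_elements] boolean mask marking non-padded slots."""
  batch_size = num_valid_elements.shape[0]
  # One row of positions [0, 1, ..., num_elements - 1] per batch element.
  element_idxs = tf.range(num_elements, dtype=tf.int32)
  batch_element_idxs = tf.tile(element_idxs[tf.newaxis, :], [batch_size, 1])
  # A slot is valid while its position is below the per-example count.
  return tf.less(batch_element_idxs, num_valid_elements[:, tf.newaxis])


print(compute_valid_mask(tf.constant([2, 3]), 4))
# [[ True  True False False]
#  [ True  True  True False]]
```
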
@@ -127,7 +129,8 @@ def project_features(features, projection_dimension, is_training, node, normalize=True):
       num_features].
     projection_dimension: A int32 Tensor.
     is_training: A boolean Tensor (affecting batch normalization).
-    node: Contains two layers (Batch Normalization and Dense) specific to the particular operation being performed (key, value, query, features)
+    node: Contains a custom layer specific to the particular operation
+      being performed (key, value, query, features)
     normalize: A boolean Tensor. If true, the output features will be l2
       normalized on the last dimension.
@@ -135,12 +138,10 @@ def project_features(features, projection_dimension, is_training, node, normalize=True):
     A float Tensor of shape [batch, features_size, projection_dimension].
   """
   shape_arr = features.shape
-  batch_size = shape_arr[0]
-  feature_size = shape_arr[1]
-  num_features = shape_arr[2]
+  batch_size, _, num_features = shape_arr

   features = tf.reshape(features, [-1, num_features])
-  projected_features = node(features)
+  projected_features = node(features, is_training)
   projected_features = tf.reshape(projected_features,
                                   [batch_size, -1, projection_dimension])
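
With the new tuple unpacking, the function body comes out as below; the l2-normalization branch at the end is an assumption taken from the `normalize` flag in the docstring, since this hunk stops before it:

```python
import tensorflow as tf


def project_features(features, projection_dimension,
                     is_training, node, normalize=True):
  """[batch, n, d] -> [batch, n, projection_dimension] via `node`."""
  batch_size, _, num_features = features.shape
  # Fold the batch and per-example dimensions together so the projection
  # layer sees a plain [batch * n, d] matrix.
  features = tf.reshape(features, [-1, num_features])
  projected_features = node(features, is_training)
  projected_features = tf.reshape(projected_features,
                                  [batch_size, -1, projection_dimension])
  if normalize:  # assumed branch, per the docstring's l2-normalization flag
    projected_features = tf.math.l2_normalize(projected_features, axis=-1)
  return projected_features


# Reusing the ContextProjection sketch above:
node = ContextProjection(projection_dimension=16, freeze_batchnorm=False)
print(project_features(tf.zeros([2, 5, 32]), 16, True, node).shape)  # (2, 5, 16)
```
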
@@ -219,8 +220,8 @@ def compute_box_context_attention(box_features, context_features,
       softmax for weights calculation. The formula for calculation as follows:
         weights = exp(weights / temperature) / sum(exp(weights / temperature))
     is_training: A boolean Tensor (affecting batch normalization).
-    freeze_batchnorm: A boolean indicating whether to freeze batch normalization weights.
-    attention_projections: Contains a dictionary of the batch norm and projection functions.
+    freeze_batchnorm: Whether to freeze batch normalization weights.
+    attention_projections: Dictionary of the projection layers.

   Returns:
     A float Tensor of shape [batch_size, max_num_proposals, 1, 1, channels].
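
The docstring formula is a temperature-scaled softmax; a quick worked example (values are illustrative) showing that a larger temperature flattens the attention distribution:

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.0]])

print(tf.nn.softmax(logits))         # [[0.665 0.245 0.090]] -- peaked
print(tf.nn.softmax(logits / 10.0))  # [[0.367 0.332 0.301]] -- flattened
```
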
@@ -231,7 +232,8 @@ def compute_box_context_attention(box_features, context_features,
   channels = box_features.shape[-1]
   if 'feature' not in attention_projections:
-    attention_projections[FEATURE_NAME] = ContextProjection(channels, freeze_batchnorm)
+    attention_projections[FEATURE_NAME] = ContextProjection(
+        channels, freeze_batchnorm)

   # Average pools over height and width dimension so that the shape of
   # box_features becomes [batch_size, max_num_proposals, channels].
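
The pooling step the comment describes amounts to a mean over the two spatial axes; a minimal sketch with illustrative shapes:

```python
import tensorflow as tf

# Illustrative shapes: [batch_size, max_num_proposals, height, width, channels].
box_features = tf.zeros([2, 100, 7, 7, 1024])

# Average over the spatial grid (axes 2 and 3) to get one feature vector per
# proposal: [batch_size, max_num_proposals, channels].
pooled = tf.reduce_mean(box_features, axis=[2, 3])
print(pooled.shape)  # (2, 100, 1024)
```
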
...
@@ -272,7 +272,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         is_training=is_training,
         freeze_batchnorm=freeze_batchnorm)
-    self._attention_projections = {'key': context_rcnn_lib.ContextProjection(
+    self._atten_projs = {'key': context_rcnn_lib.ContextProjection(
             attention_bottleneck_dimension, freeze_batchnorm),
         'val': context_rcnn_lib.ContextProjection(
             attention_bottleneck_dimension, freeze_batchnorm),
@@ -340,7 +340,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         box_features=box_features,
         context_features=context_features,
         valid_context_size=valid_context_size,
-        attention_projections=self._attention_projections)
+        attention_projections=self._atten_projs)

     # Adds box features with attention features.
     box_features += attention_features
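
For orientation, the `attention_features` added to `box_features` here come out of a dot-product attention core along the lines below. This is a sketch of the standard technique using the temperature softmax from the docstring, not code from this commit; shapes and the temperature value are illustrative:

```python
import tensorflow as tf

queries = tf.zeros([2, 100, 128])  # projected box features
keys = tf.zeros([2, 50, 128])      # projected context features
values = tf.zeros([2, 50, 128])    # projected context features

weights = tf.matmul(queries, keys, transpose_b=True)  # [2, 100, 50]
weights = tf.nn.softmax(weights / 0.01)               # temperature softmax
attention_features = tf.matmul(weights, values)       # [2, 100, 128]
```
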
...