Commit 65bd772d authored by Kaushik Shivakumar

style

parent 8e3eb8d4
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 import tensorflow.compat.v1 as tf
-import tf_slim as slim

 # The negative value used in padding the invalid weights.
 _NEGATIVE_PADDING_VALUE = -100000
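A note on the constant kept here: padding invalid attention logits with a large negative value drives their softmax weight to effectively zero. A minimal standalone sketch of the effect (the tensor values are made up, not from this file):

```python
import tensorflow as tf

_NEGATIVE_PADDING_VALUE = -100000

# Two valid logits and one padded (invalid) slot.
logits = tf.constant([[2.0, 1.0, _NEGATIVE_PADDING_VALUE]])
weights = tf.nn.softmax(logits)
# exp(-100000) underflows to 0, so the padded slot gets ~0 attention weight.
print(weights.numpy())  # ~[[0.731, 0.269, 0.0]]
```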
@@ -30,8 +29,10 @@ QUERY_NAME = 'query'
 FEATURE_NAME = 'feature'

 class ContextProjection(tf.keras.layers.Layer):
   """Custom layer to do batch normalization and projection."""

   def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
-    self.batch_norm = tf.keras.layers.BatchNormalization(epsilon=0.001,
+    self.batch_norm = tf.keras.layers.BatchNormalization(
+        epsilon=0.001,
         center=True,
         scale=True,
         momentum=0.97,
@@ -39,14 +40,14 @@ class ContextProjection(tf.keras.layers.Layer):
     self.projection = tf.keras.layers.Dense(units=projection_dimension,
                                             activation=tf.nn.relu6,
                                             use_bias=True)
-    super(ContextProjection,self).__init__(**kwargs)
+    super(ContextProjection, self).__init__(**kwargs)

   def build(self, input_shape):
     self.batch_norm.build(input_shape)
     self.projection.build(input_shape)

-  def call(self, input):
-    return self.projection(self.batch_norm(input))
+  def call(self, input_features, is_training=False):
+    return self.projection(self.batch_norm(input_features, is_training))


def filter_weight_value(weights, values, valid_mask):
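For reviewers unfamiliar with the layer: the new `call` signature threads `is_training` through to batch normalization (the old one shadowed the built-in `input` and ignored training mode). A minimal usage sketch; the import path and shapes here are my assumptions, not part of this commit:

```python
import tensorflow.compat.v1 as tf

# Assumed module path for the file touched by this commit.
from object_detection.meta_architectures import context_rcnn_lib

# Project 64-d features down to a 32-d space.
layer = context_rcnn_lib.ContextProjection(
    projection_dimension=32, freeze_batchnorm=False)
features = tf.random.normal([8, 64])  # [num_features, depth], made-up shape

# Extra positional args to a Keras layer are forwarded to call(), so
# is_training now toggles batch-norm behavior.
train_out = layer(features, True)  # training mode
eval_out = layer(features)         # defaults to is_training=False
```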
@@ -119,7 +120,8 @@ def compute_valid_mask(num_valid_elements, num_elements):
   valid_mask = tf.less(batch_element_idxs, num_valid_elements)
   return valid_mask


-def project_features(features, projection_dimension, is_training, node, normalize=True):
+def project_features(features, projection_dimension,
+                     is_training, node, normalize=True):
   """Projects features to another feature space.

   Args:
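The `valid_mask` line above compares each element index against the per-example count of valid elements. A standalone sketch of that broadcasting pattern, with made-up values:

```python
import tensorflow as tf

num_elements = 4
# Each batch example declares how many of its context elements are valid.
num_valid_elements = tf.constant([[2], [4]])  # shape [batch_size, 1]

batch_element_idxs = tf.range(num_elements, dtype=num_valid_elements.dtype)
valid_mask = tf.less(batch_element_idxs, num_valid_elements)
# Broadcasts [4] against [2, 1] -> [2, 4]:
# [[ True,  True, False, False],
#  [ True,  True,  True,  True]]
```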
@@ -127,7 +129,8 @@ def project_features(features, projection_dimension, is_training, node, normaliz
       num_features].
     projection_dimension: A int32 Tensor.
     is_training: A boolean Tensor (affecting batch normalization).
-    node: Contains two layers (Batch Normalization and Dense) specific to the particular operation being performed (key, value, query, features)
+    node: Contains a custom layer specific to the particular operation
+      being performed (key, value, query, features)
     normalize: A boolean Tensor. If true, the output features will be l2
       normalized on the last dimension.
@@ -135,12 +138,10 @@ def project_features(features, projection_dimension, is_training, node, normaliz
     A float Tensor of shape [batch, features_size, projection_dimension].
   """
   shape_arr = features.shape
-  batch_size = shape_arr[0]
-  feature_size = shape_arr[1]
-  num_features = shape_arr[2]
+  batch_size, _, num_features = shape_arr

   features = tf.reshape(features, [-1, num_features])
-  projected_features = node(features)
+  projected_features = node(features, is_training)
   projected_features = tf.reshape(projected_features,
                                   [batch_size, -1, projection_dimension])
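The body above flattens the batch and per-example dimensions before projecting, then restores them. A minimal sketch of that reshape round-trip, using a plain `Dense` layer as a stand-in for the `ContextProjection` node and made-up shapes:

```python
import tensorflow as tf

batch_size, feature_size, num_features = 2, 5, 64
projection_dimension = 32
features = tf.random.normal([batch_size, feature_size, num_features])

# Stand-in for the ContextProjection node passed to project_features.
node = tf.keras.layers.Dense(projection_dimension)

flat = tf.reshape(features, [-1, num_features])  # [batch * feature_size, 64]
projected = node(flat)                           # [batch * feature_size, 32]
projected = tf.reshape(projected, [batch_size, -1, projection_dimension])
print(projected.shape)  # (2, 5, 32)
```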
@@ -219,8 +220,8 @@ def compute_box_context_attention(box_features, context_features,
       softmax for weights calculation. The formula for calculation as follows:
         weights = exp(weights / temperature) / sum(exp(weights / temperature))
     is_training: A boolean Tensor (affecting batch normalization).
-    freeze_batchnorm: A boolean indicating whether to freeze batch normalization weights.
-    attention_projections: Contains a dictionary of the batch norm and projection functions.
+    freeze_batchnorm: Whether to freeze batch normalization weights.
+    attention_projections: Dictionary of the projection layers.

   Returns:
     A float Tensor of shape [batch_size, max_num_proposals, 1, 1, channels].
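The temperature in the docstring formula above rescales logits before the softmax: values below 1 sharpen the attention distribution, values above 1 flatten it. A quick numeric sketch (the logits are made up):

```python
import tensorflow as tf

weights = tf.constant([[2.0, 1.0, 0.0]])

# weights = exp(weights / temperature) / sum(exp(weights / temperature))
sharp = tf.nn.softmax(weights / 0.5)  # ~[[0.867, 0.117, 0.016]]
soft = tf.nn.softmax(weights / 10.0)  # ~[[0.367, 0.332, 0.301]]
```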
@@ -231,7 +232,8 @@ def compute_box_context_attention(box_features, context_features,
   channels = box_features.shape[-1]
   if 'feature' not in attention_projections:
-    attention_projections[FEATURE_NAME] = ContextProjection(channels, freeze_batchnorm)
+    attention_projections[FEATURE_NAME] = ContextProjection(
+        channels, freeze_batchnorm)

   # Average pools over height and width dimension so that the shape of
   # box_features becomes [batch_size, max_num_proposals, channels].
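The pooling step described in that comment collapses each proposal's spatial grid to a single feature vector. A sketch of such a reduction; the use of `tf.reduce_mean` and the shapes are my assumption of how this pooling is typically written, not a quote of this file:

```python
import tensorflow as tf

# [batch_size, max_num_proposals, height, width, channels], made-up sizes.
box_features = tf.random.normal([2, 100, 7, 7, 512])

# Mean over the height and width axes.
pooled = tf.reduce_mean(box_features, axis=[2, 3])
print(pooled.shape)  # (2, 100, 512)
```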
......
@@ -272,7 +272,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         is_training=is_training,
         freeze_batchnorm=freeze_batchnorm)

-    self._attention_projections = {'key': context_rcnn_lib.ContextProjection(
+    self._atten_projs = {'key': context_rcnn_lib.ContextProjection(
         attention_bottleneck_dimension, freeze_batchnorm),
                          'val': context_rcnn_lib.ContextProjection(
                              attention_bottleneck_dimension, freeze_batchnorm),
@@ -340,7 +340,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
           box_features=box_features,
           context_features=context_features,
           valid_context_size=valid_context_size,
-          attention_projections=self._attention_projections)
+          attention_projections=self._atten_projs)

       # Adds box features with attention features.
       box_features += attention_features
......