Commit 1a449bb9 authored by Kaushik Shivakumar

add back context_rcnn_lib

parent c81c01b2
@@ -19,85 +19,12 @@ from __future__ import division
 from __future__ import print_function

 import tensorflow.compat.v1 as tf
+import tf_slim as slim

 # The negative value used in padding the invalid weights.
 _NEGATIVE_PADDING_VALUE = -100000
-KEY_NAME = 'key'
-VALUE_NAME = 'val'
-QUERY_NAME = 'query'
-FEATURE_NAME = 'feature'
-
-
-class ContextProjection(tf.keras.layers.Layer):
-  """Custom layer to do batch normalization and projection."""
-
-  def __init__(self, projection_dimension, freeze_batchnorm, **kwargs):
-    self.batch_norm = tf.keras.layers.BatchNormalization(
-        epsilon=0.001,
-        center=True,
-        scale=True,
-        momentum=0.97,
-        trainable=(not freeze_batchnorm))
-    self.projection = tf.keras.layers.Dense(units=projection_dimension,
-                                            activation=tf.nn.relu6,
-                                            use_bias=True)
-    super(ContextProjection, self).__init__(**kwargs)
-
-  def build(self, input_shape):
-    self.batch_norm.build(input_shape)
-    self.projection.build(input_shape)
-
-  def call(self, input_features, is_training=False):
-    return self.projection(self.batch_norm(input_features, is_training))
-
-
-class AttentionBlock(tf.keras.layers.Layer):
-
-  def __init__(self, bottleneck_dimension, attention_temperature,
-               freeze_batchnorm, output_dimension=None, **kwargs):
-    self.key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
-    self.val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
-    self.query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
-    self.attention_temperature = attention_temperature
-    self.freeze_batchnorm = freeze_batchnorm
-    self.bottleneck_dimension = bottleneck_dimension
-    if output_dimension:
-      self.output_dimension = output_dimension
-    super(AttentionBlock, self).__init__(**kwargs)
-
-  def set_output_dimension(self, new_output_dimension):
-    self.output_dimension = new_output_dimension
-
-  def build(self, input_shapes):
-    print(input_shapes)
-    self.feature_proj = ContextProjection(self.output_dimension,
-                                          self.freeze_batchnorm)
-    # self.key_proj.build(input_shapes[0])
-    # self.val_proj.build(input_shapes[0])
-    # self.query_proj.build(input_shapes[0])
-    # self.feature_proj.build(input_shapes[0])
-    pass
-
-  def call(self, input_features, is_training, valid_mask):
-    input_features, context_features = input_features
-    with tf.variable_scope("AttentionBlock"):
-      queries = project_features(
-          input_features, self.bottleneck_dimension, is_training,
-          self.query_proj, normalize=True)
-      keys = project_features(
-          context_features, self.bottleneck_dimension, is_training,
-          self.key_proj, normalize=True)
-      values = project_features(
-          context_features, self.bottleneck_dimension, is_training,
-          self.val_proj, normalize=True)
-      weights = tf.matmul(queries, keys, transpose_b=True)
-      weights, values = filter_weight_value(weights, values, valid_mask)
-      weights = tf.nn.softmax(weights / self.attention_temperature)
-      features = tf.matmul(weights, values)
-      output_features = project_features(
-          features, self.output_dimension, is_training,
-          self.feature_proj, normalize=False)
-    return output_features
 def filter_weight_value(weights, values, valid_mask):
   """Filters weights and values based on valid_mask.
@@ -126,15 +53,15 @@ def filter_weight_value(weights, values, valid_mask):
   m_batch_size, m_context_size = valid_mask.shape
   if w_batch_size != v_batch_size or v_batch_size != m_batch_size:
     raise ValueError("Please make sure the first dimension of the input"
                      " tensors are the same.")

   if w_context_size != v_context_size:
     raise ValueError("Please make sure the third dimension of weights matches"
                      " the second dimension of values.")

   if w_context_size != m_context_size:
     raise ValueError("Please make sure the third dimension of the weights"
                      " matches the second dimension of the valid_mask.")

   valid_mask = valid_mask[..., tf.newaxis]
@@ -150,7 +77,27 @@ def filter_weight_value(weights, values, valid_mask):
   return weights, values
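
A quick numeric illustration of what this masking accomplishes: once an invalid slot's logit is pushed down to `_NEGATIVE_PADDING_VALUE`, the softmax that follows in `attention_block` gives that slot essentially zero attention. The sketch below rebuilds the masking by hand on toy tensors; it is not code from this commit.

```python
import tensorflow.compat.v1 as tf

# One query attending over three context slots; the last slot is padding.
weights = tf.constant([[[2.0, 1.0, 3.0]]])       # [batch, input_size, context_size]
valid_mask = tf.constant([[True, True, False]])  # [batch, context_size]

# Same idea as filter_weight_value: force invalid logits to a very
# negative value before the softmax.
very_negative = tf.fill(tf.shape(weights), -100000.0)
mask = tf.broadcast_to(valid_mask[:, tf.newaxis, :], tf.shape(weights))
masked_weights = tf.where(mask, weights, very_negative)

attention = tf.nn.softmax(masked_weights)
# attention ~= [[[0.73, 0.27, 0.00]]] -- the padded slot gets no weight.
```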
-def project_features(features, bottleneck_dimension, is_training, layer, normalize=True):
+def compute_valid_mask(num_valid_elements, num_elements):
+  """Computes mask of valid entries within padded context feature.
+
+  Args:
+    num_valid_elements: A int32 Tensor of shape [batch_size].
+    num_elements: An int32 Tensor.
+
+  Returns:
+    A boolean Tensor of the shape [batch_size, num_elements]. True means
+      valid and False means invalid.
+  """
+  batch_size = num_valid_elements.shape[0]
+  element_idxs = tf.range(num_elements, dtype=tf.int32)
+  batch_element_idxs = tf.tile(element_idxs[tf.newaxis, ...], [batch_size, 1])
+  num_valid_elements = num_valid_elements[..., tf.newaxis]
+  valid_mask = tf.less(batch_element_idxs, num_valid_elements)
+  return valid_mask
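
A worked example of the mask this produces, assuming eager execution; the commented values are what `tf.less` yields for a toy batch.

```python
import tensorflow.compat.v1 as tf

# Two images; the padded context has 4 rows, of which only the first
# 1 and 3 rows respectively are real.
num_valid_elements = tf.constant([1, 3], dtype=tf.int32)
element_idxs = tf.range(4, dtype=tf.int32)                      # [0, 1, 2, 3]
batch_element_idxs = tf.tile(element_idxs[tf.newaxis, ...], [2, 1])
valid_mask = tf.less(batch_element_idxs, num_valid_elements[..., tf.newaxis])
# valid_mask == [[True, False, False, False],
#                [True, True,  True,  False]]
```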
+def project_features(features, projection_dimension, is_training, normalize):
   """Projects features to another feature space.

   Args:
@@ -158,51 +105,87 @@ def project_features(features, bottleneck_dimension, is_training, layer, normalize=True):
       num_features].
     projection_dimension: A int32 Tensor.
     is_training: A boolean Tensor (affecting batch normalization).
-    node: Contains a custom layer specific to the particular operation
-      being performed (key, value, query, features)
     normalize: A boolean Tensor. If true, the output features will be l2
       normalized on the last dimension.

   Returns:
     A float Tensor of shape [batch, features_size, projection_dimension].
   """
-  shape_arr = features.shape
-  batch_size, _, num_features = shape_arr
-  print("Orig", features.shape)
+  # TODO(guanhangwu) Figure out a better way of specifying the batch norm
+  # params.
+  batch_norm_params = {
+      "is_training": is_training,
+      "decay": 0.97,
+      "epsilon": 0.001,
+      "center": True,
+      "scale": True
+  }
+
+  batch_size, _, num_features = features.shape
   features = tf.reshape(features, [-1, num_features])
-  projected_features = layer(features, is_training)
-  projected_features = tf.reshape(projected_features,
-                                  [batch_size, -1, bottleneck_dimension])
-  print(projected_features.shape)
+  projected_features = slim.fully_connected(
+      features,
+      num_outputs=projection_dimension,
+      activation_fn=tf.nn.relu6,
+      normalizer_fn=slim.batch_norm,
+      normalizer_params=batch_norm_params)
+  projected_features = tf.reshape(projected_features,
+                                  [batch_size, -1, projection_dimension])

   if normalize:
-    projected_features = tf.keras.backend.l2_normalize(projected_features, axis=-1)
+    projected_features = tf.math.l2_normalize(projected_features, axis=-1)

   return projected_features
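
The reshapes around `slim.fully_connected` are pure shape bookkeeping: flatten `[batch, n, f]` to `[batch * n, f]`, project, then restore the batch structure. A standalone sketch with the projection replaced by a plain `tf.matmul` so it runs without `tf_slim` (the sizes are made up):

```python
import tensorflow.compat.v1 as tf

# Hypothetical sizes: batch 2, 5 elements, 8 input features projected to 16.
features = tf.random.uniform([2, 5, 8])
kernel = tf.random.uniform([8, 16])              # stand-in for the FC weights
flat = tf.reshape(features, [-1, 8])             # [10, 8]
projected = tf.matmul(flat, kernel)              # [10, 16]
projected = tf.reshape(projected, [2, -1, 16])   # back to [2, 5, 16]
normalized = tf.math.l2_normalize(projected, axis=-1)  # normalize=True branch
```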
-def compute_valid_mask(num_valid_elements, num_elements):
-  """Computes mask of valid entries within padded context feature.
-
-  Args:
-    num_valid_elements: A int32 Tensor of shape [batch_size].
-    num_elements: An int32 Tensor.
-
-  Returns:
-    A boolean Tensor of the shape [batch_size, num_elements]. True means
-      valid and False means invalid.
-  """
-  batch_size = num_valid_elements.shape[0]
-  element_idxs = tf.range(num_elements, dtype=tf.int32)
-  batch_element_idxs = tf.tile(element_idxs[tf.newaxis, ...], [batch_size, 1])
-  num_valid_elements = num_valid_elements[..., tf.newaxis]
-  valid_mask = tf.less(batch_element_idxs, num_valid_elements)
-  return valid_mask
+def attention_block(input_features, context_features, bottleneck_dimension,
+                    output_dimension, attention_temperature, valid_mask,
+                    is_training):
+  """Generic attention block.
+
+  Args:
+    input_features: A float Tensor of shape [batch_size, input_size,
+      num_input_features].
+    context_features: A float Tensor of shape [batch_size, context_size,
+      num_context_features].
+    bottleneck_dimension: A int32 Tensor representing the bottleneck dimension
+      for intermediate projections.
+    output_dimension: A int32 Tensor representing the last dimension of the
+      output feature.
+    attention_temperature: A float Tensor. It controls the temperature of the
+      softmax for weights calculation. The formula for calculation as follows:
+        weights = exp(weights / temperature) / sum(exp(weights / temperature))
+    valid_mask: A boolean Tensor of shape [batch_size, context_size].
+    is_training: A boolean Tensor (affecting batch normalization).
+
+  Returns:
+    A float Tensor of shape [batch_size, input_size, output_dimension].
+  """
+  with tf.variable_scope("AttentionBlock"):
+    queries = project_features(
+        input_features, bottleneck_dimension, is_training, normalize=True)
+    keys = project_features(
+        context_features, bottleneck_dimension, is_training, normalize=True)
+    values = project_features(
+        context_features, bottleneck_dimension, is_training, normalize=True)

+    weights = tf.matmul(queries, keys, transpose_b=True)
+    weights, values = filter_weight_value(weights, values, valid_mask)
+    weights = tf.nn.softmax(weights / attention_temperature)

+    features = tf.matmul(weights, values)
+    output_features = project_features(
+        features, output_dimension, is_training, normalize=False)
+  return output_features
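
One detail worth spelling out from the docstring's formula: `attention_temperature` only rescales the logits before `tf.nn.softmax`, so smaller temperatures sharpen the attention distribution toward the best-matching context row. An illustration with arbitrary logits:

```python
import tensorflow.compat.v1 as tf

# attention_temperature rescales the logits fed to the softmax.
logits = tf.constant([[2.0, 1.0, 0.0]])
soft = tf.nn.softmax(logits / 1.0)   # ~[0.67, 0.24, 0.09]
sharp = tf.nn.softmax(logits / 0.1)  # ~[1.00, 0.00, 0.00]
```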
 def compute_box_context_attention(box_features, context_features,
                                   valid_context_size, bottleneck_dimension,
-                                  attention_temperature, is_training,
-                                  freeze_batchnorm, attention_block):
+                                  attention_temperature, is_training):
   """Computes the attention feature from the context given a batch of box.

   Args:
@@ -218,8 +201,6 @@ def compute_box_context_attention(box_features, context_features,
       softmax for weights calculation. The formula for calculation as follows:
         weights = exp(weights / temperature) / sum(exp(weights / temperature))
     is_training: A boolean Tensor (affecting batch normalization).
-    freeze_batchnorm: Whether to freeze batch normalization weights.
-    attention_projections: Dictionary of the projection layers.

   Returns:
     A float Tensor of shape [batch_size, max_num_proposals, 1, 1, channels].
@@ -228,15 +209,17 @@ def compute_box_context_attention(box_features, context_features,
   valid_mask = compute_valid_mask(valid_context_size, context_size)

   channels = box_features.shape[-1]
-  attention_block.set_output_dimension(channels)
   # Average pools over height and width dimension so that the shape of
   # box_features becomes [batch_size, max_num_proposals, channels].
   box_features = tf.reduce_mean(box_features, [2, 3])
-  output_features = attention_block([box_features, context_features],
-                                    is_training, valid_mask)
+  output_features = attention_block(box_features, context_features,
+                                    bottleneck_dimension, channels.value,
+                                    attention_temperature, valid_mask,
+                                    is_training)

   # Expands the dimension back to match with the original feature map.
   output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]

   return output_features
\ No newline at end of file
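
Finally, a shape trace for `compute_box_context_attention` under assumed sizes (illustrative only; the `attention_block` call itself is elided since it needs `tf_slim` variables to exist):

```python
import tensorflow.compat.v1 as tf

# Assumed sizes, for illustration only.
batch_size, max_num_proposals, channels = 2, 3, 4

box_features = tf.zeros([batch_size, max_num_proposals, 7, 7, channels])
pooled = tf.reduce_mean(box_features, [2, 3])        # [2, 3, 4]
# attention_block(pooled, context_features, ...) also returns [2, 3, 4];
# the trailing expansion restores the singleton spatial dims:
output = pooled[:, :, tf.newaxis, tf.newaxis, :]     # [2, 3, 1, 1, 4]
```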