"examples/vscode:/vscode.git/clone" did not exist on "5105b5a83d04323dc583846a12be054e3701c4ed"
Commit 3475ebda authored by Kaushik Shivakumar

progress on integrating better

parent f8df8742
@@ -49,162 +49,153 @@ class ContextProjection(tf.keras.layers.Layer):
  def call(self, input_features, is_training=False):
    return self.projection(self.batch_norm(input_features, is_training))

-def filter_weight_value(weights, values, valid_mask):
-  """Filters weights and values based on valid_mask.
-
-  _NEGATIVE_PADDING_VALUE will be added to invalid elements in the weights to
-  avoid their contribution in softmax. 0 will be set for the invalid elements in
-  the values.
-
-  Args:
-    weights: A float Tensor of shape [batch_size, input_size, context_size].
-    values: A float Tensor of shape [batch_size, context_size,
-      projected_dimension].
-    valid_mask: A boolean Tensor of shape [batch_size, context_size]. True means
-      valid and False means invalid.
-
-  Returns:
-    weights: A float Tensor of shape [batch_size, input_size, context_size].
-    values: A float Tensor of shape [batch_size, context_size,
-      projected_dimension].
-
-  Raises:
-    ValueError: If shape of doesn't match.
-  """
-  w_batch_size, _, w_context_size = weights.shape
-  v_batch_size, v_context_size, _ = values.shape
-  m_batch_size, m_context_size = valid_mask.shape
-  if w_batch_size != v_batch_size or v_batch_size != m_batch_size:
-    raise ValueError("Please make sure the first dimension of the input"
-                     " tensors are the same.")
-
-  if w_context_size != v_context_size:
-    raise ValueError("Please make sure the third dimension of weights matches"
-                     " the second dimension of values.")
-
-  if w_context_size != m_context_size:
-    raise ValueError("Please make sure the third dimension of the weights"
-                     " matches the second dimension of the valid_mask.")
-
-  valid_mask = valid_mask[..., tf.newaxis]
-
-  # Force the invalid weights to be very negative so it won't contribute to
-  # the softmax.
-  weights += tf.transpose(
-      tf.cast(tf.math.logical_not(valid_mask), weights.dtype) *
-      _NEGATIVE_PADDING_VALUE,
-      perm=[0, 2, 1])
-
-  # Force the invalid values to be 0.
-  values *= tf.cast(valid_mask, values.dtype)
-  return weights, values
+class AttentionBlock(tf.keras.layers.Layer):
+
+  def __init__(self, bottleneck_dimension, attention_temperature,
+               freeze_batchnorm, **kwargs):
+    self.key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
+    self.val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
+    self.query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
+    self.attention_temperature = attention_temperature
+    self.freeze_batchnorm = freeze_batchnorm
+    self.bottleneck_dimension = bottleneck_dimension
+    super(AttentionBlock, self).__init__(**kwargs)
+
+  def build(self, input_shapes):
+    self.feature_proj = ContextProjection(input_shapes[0][-1],
+                                          self.freeze_batchnorm)
+    self.key_proj.build(input_shapes[0])
+    self.val_proj.build(input_shapes[0])
+    self.query_proj.build(input_shapes[0])
+    self.feature_proj.build(input_shapes[0])
+
+  def filter_weight_value(self, weights, values, valid_mask):
+    """Filters weights and values based on valid_mask.
+
+    _NEGATIVE_PADDING_VALUE will be added to invalid elements in the weights
+    to avoid their contribution in softmax. 0 will be set for the invalid
+    elements in the values.
+
+    Args:
+      weights: A float Tensor of shape [batch_size, input_size, context_size].
+      values: A float Tensor of shape [batch_size, context_size,
+        projected_dimension].
+      valid_mask: A boolean Tensor of shape [batch_size, context_size]. True
+        means valid and False means invalid.
+
+    Returns:
+      weights: A float Tensor of shape [batch_size, input_size, context_size].
+      values: A float Tensor of shape [batch_size, context_size,
+        projected_dimension].
+
+    Raises:
+      ValueError: If the shapes don't match.
+    """
+    w_batch_size, _, w_context_size = weights.shape
+    v_batch_size, v_context_size, _ = values.shape
+    m_batch_size, m_context_size = valid_mask.shape
+    if w_batch_size != v_batch_size or v_batch_size != m_batch_size:
+      raise ValueError("Please make sure the first dimension of the input"
+                       " tensors are the same.")
+
+    if w_context_size != v_context_size:
+      raise ValueError("Please make sure the third dimension of weights"
+                       " matches the second dimension of values.")
+
+    if w_context_size != m_context_size:
+      raise ValueError("Please make sure the third dimension of the weights"
+                       " matches the second dimension of the valid_mask.")
+
+    valid_mask = valid_mask[..., tf.newaxis]
+
+    # Force the invalid weights to be very negative so they won't contribute
+    # to the softmax.
+    weights += tf.transpose(
+        tf.cast(tf.math.logical_not(valid_mask), weights.dtype) *
+        _NEGATIVE_PADDING_VALUE,
+        perm=[0, 2, 1])
+
+    # Force the invalid values to be 0.
+    values *= tf.cast(valid_mask, values.dtype)
+    return weights, values
+
+  def run_projection(self, features, bottleneck_dimension, is_training,
+                     layer, normalize=True):
+    """Projects features to another feature space.
+
+    Args:
+      features: A float Tensor of shape [batch_size, features_size,
+        num_features].
+      bottleneck_dimension: An int32 Tensor.
+      is_training: A boolean Tensor (affecting batch normalization).
+      layer: Contains a custom layer specific to the particular operation
+        being performed (key, value, query, features).
+      normalize: A boolean Tensor. If true, the output features will be l2
+        normalized on the last dimension.
+
+    Returns:
+      A float Tensor of shape [batch, features_size, bottleneck_dimension].
+    """
+    shape_arr = features.shape
+    batch_size, _, num_features = shape_arr
+    features = tf.reshape(features, [-1, num_features])
+    projected_features = layer(features, is_training)
+    projected_features = tf.reshape(projected_features,
+                                    [batch_size, -1, bottleneck_dimension])
+    if normalize:
+      projected_features = tf.keras.backend.l2_normalize(projected_features,
+                                                         axis=-1)
+    return projected_features
+
+  def call(self, input_features, is_training, valid_mask):
+    input_features, context_features = input_features
+    with tf.variable_scope("AttentionBlock"):
+      queries = self.run_projection(
+          input_features, self.bottleneck_dimension, is_training,
+          self.query_proj, normalize=True)
+      keys = self.run_projection(
+          context_features, self.bottleneck_dimension, is_training,
+          self.key_proj, normalize=True)
+      values = self.run_projection(
+          context_features, self.bottleneck_dimension, is_training,
+          self.val_proj, normalize=True)
+      weights = tf.matmul(queries, keys, transpose_b=True)
+      weights, values = self.filter_weight_value(weights, values, valid_mask)
+      weights = tf.nn.softmax(weights / self.attention_temperature)
+      features = tf.matmul(weights, values)
+      output_features = self.run_projection(
+          features, input_features.shape[-1], is_training,
+          self.feature_proj, normalize=False)
+    return output_features

def compute_valid_mask(num_valid_elements, num_elements):
  """Computes mask of valid entries within padded context feature.

  Args:
    num_valid_elements: A int32 Tensor of shape [batch_size].
    num_elements: An int32 Tensor.

  Returns:
    A boolean Tensor of the shape [batch_size, num_elements]. True means
    valid and False means invalid.
  """
  batch_size = num_valid_elements.shape[0]
  element_idxs = tf.range(num_elements, dtype=tf.int32)
  batch_element_idxs = tf.tile(element_idxs[tf.newaxis, ...], [batch_size, 1])
  num_valid_elements = num_valid_elements[..., tf.newaxis]
  valid_mask = tf.less(batch_element_idxs, num_valid_elements)
  return valid_mask

-def project_features(features, projection_dimension,
-                     is_training, node, normalize=True):
-  """Projects features to another feature space.
-
-  Args:
-    features: A float Tensor of shape [batch_size, features_size,
-      num_features].
-    projection_dimension: A int32 Tensor.
-    is_training: A boolean Tensor (affecting batch normalization).
-    node: Contains a custom layer specific to the particular operation
-      being performed (key, value, query, features)
-    normalize: A boolean Tensor. If true, the output features will be l2
-      normalized on the last dimension.
-
-  Returns:
-    A float Tensor of shape [batch, features_size, projection_dimension].
-  """
-  shape_arr = features.shape
-  batch_size, _, num_features = shape_arr
-  features = tf.reshape(features, [-1, num_features])
-  projected_features = node(features, is_training)
-  projected_features = tf.reshape(projected_features,
-                                  [batch_size, -1, projection_dimension])
-  if normalize:
-    projected_features = tf.math.l2_normalize(projected_features, axis=-1)
-  return projected_features

-def attention_block(input_features, context_features, bottleneck_dimension,
-                    output_dimension, attention_temperature, valid_mask,
-                    is_training, attention_projections):
-  """Generic attention block.
-
-  Args:
-    input_features: A float Tensor of shape [batch_size, input_size,
-      num_input_features].
-    context_features: A float Tensor of shape [batch_size, context_size,
-      num_context_features].
-    bottleneck_dimension: A int32 Tensor representing the bottleneck dimension
-      for intermediate projections.
-    output_dimension: A int32 Tensor representing the last dimension of the
-      output feature.
-    attention_temperature: A float Tensor. It controls the temperature of the
-      softmax for weights calculation. The formula for calculation as follows:
-        weights = exp(weights / temperature) / sum(exp(weights / temperature))
-    valid_mask: A boolean Tensor of shape [batch_size, context_size].
-    is_training: A boolean Tensor (affecting batch normalization).
-    attention_projections: Contains a dictionary of the projection objects.
-
-  Returns:
-    A float Tensor of shape [batch_size, input_size, output_dimension].
-  """
-  with tf.variable_scope("AttentionBlock"):
-    queries = project_features(
-        input_features, bottleneck_dimension, is_training,
-        attention_projections[QUERY_NAME], normalize=True)
-    keys = project_features(
-        context_features, bottleneck_dimension, is_training,
-        attention_projections[KEY_NAME], normalize=True)
-    values = project_features(
-        context_features, bottleneck_dimension, is_training,
-        attention_projections[VALUE_NAME], normalize=True)
-    weights = tf.matmul(queries, keys, transpose_b=True)
-    weights, values = filter_weight_value(weights, values, valid_mask)
-    weights = tf.nn.softmax(weights / attention_temperature)
-    features = tf.matmul(weights, values)
-    output_features = project_features(
-        features, output_dimension, is_training,
-        attention_projections[FEATURE_NAME], normalize=False)
-  return output_features

def compute_box_context_attention(box_features, context_features,
                                  valid_context_size, bottleneck_dimension,
                                  attention_temperature, is_training,
-                                 freeze_batchnorm, attention_projections):
+                                 freeze_batchnorm, attention_block):
  """Computes the attention feature from the context given a batch of box.

  Args:
@@ -230,19 +221,12 @@ def compute_box_context_attention(box_features, context_features,
  valid_mask = compute_valid_mask(valid_context_size, context_size)

  channels = box_features.shape[-1]

-  if 'feature' not in attention_projections:
-    attention_projections[FEATURE_NAME] = ContextProjection(
-        channels, freeze_batchnorm)

  # Average pools over height and width dimension so that the shape of
  # box_features becomes [batch_size, max_num_proposals, channels].
  box_features = tf.reduce_mean(box_features, [2, 3])

-  output_features = attention_block(box_features, context_features,
-                                    bottleneck_dimension, channels,
-                                    attention_temperature, valid_mask,
-                                    is_training, attention_projections)
+  output_features = attention_block([box_features, context_features],
+                                    is_training, valid_mask)

  # Expands the dimension back to match with the original feature map.
  output_features = output_features[:, :, tf.newaxis, tf.newaxis, :]
...
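As a side note, the masked softmax that filter_weight_value implements can be checked in isolation. A minimal, self-contained sketch in TF2 eager mode; the shapes are illustrative, and the large negative constant below merely stands in for the module's _NEGATIVE_PADDING_VALUE:

import tensorflow as tf

_NEGATIVE_PADDING_VALUE = -100000  # stand-in for the module constant

# 2 examples, 3 input boxes, 4 padded context slots; example 0 has 2 valid
# context entries and example 1 has 3.
weights = tf.zeros([2, 3, 4])    # raw attention logits
values = tf.ones([2, 4, 8])      # projected context features
valid_mask = tf.constant([[True, True, False, False],
                          [True, True, True, False]])

mask = valid_mask[..., tf.newaxis]                 # [2, 4, 1]
# Push invalid logits far negative so softmax drives their weight to ~0.
weights += tf.transpose(
    tf.cast(tf.math.logical_not(mask), weights.dtype) *
    _NEGATIVE_PADDING_VALUE,
    perm=[0, 2, 1])                                # broadcasts to [2, 3, 4]
# Zero out invalid values so padding cannot leak into the weighted sum.
values *= tf.cast(mask, values.dtype)

print(tf.nn.softmax(weights)[0, 0].numpy())        # ~[0.5, 0.5, 0.0, 0.0]

The attention weights over the two padded slots come out as (numerically) zero, which is why the subsequent tf.matmul(weights, values) ignores padding entirely.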
@@ -272,12 +272,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
        is_training=is_training,
        freeze_batchnorm=freeze_batchnorm)

-    self._atten_projs = {'key': context_rcnn_lib.ContextProjection(
-        attention_bottleneck_dimension, freeze_batchnorm),
-                         'val': context_rcnn_lib.ContextProjection(
-                             attention_bottleneck_dimension, freeze_batchnorm),
-                         'query': context_rcnn_lib.ContextProjection(
-                             attention_bottleneck_dimension, freeze_batchnorm)}
+    self._attention_block = context_rcnn_lib.AttentionBlock(
+        attention_bottleneck_dimension, attention_temperature,
+        freeze_batchnorm)

  @staticmethod
  def get_side_inputs(features):

@@ -340,7 +335,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):

          box_features=box_features,
          context_features=context_features,
          valid_context_size=valid_context_size,
-         attention_projections=self._atten_projs)
+         attention_block=self._attention_block)

      # Adds box features with attention features.
      box_features += attention_features
...
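For reference, compute_valid_mask — the helper that turns valid_context_size into the valid_mask handed to self._attention_block — is easy to sanity-check standalone. A sketch with illustrative sizes, TF2 eager; the logic mirrors the function body:

import tensorflow as tf

num_valid_elements = tf.constant([3, 1], dtype=tf.int32)  # per-example counts
num_elements = 4                                          # padded context size

element_idxs = tf.range(num_elements, dtype=tf.int32)
batch_element_idxs = tf.tile(element_idxs[tf.newaxis, ...], [2, 1])
valid_mask = tf.less(batch_element_idxs, num_valid_elements[..., tf.newaxis])

print(valid_mask.numpy())
# [[ True  True  True False]
#  [ True False False False]]

This is the boolean mask that compute_box_context_attention passes as valid_mask when it invokes the new AttentionBlock with the packed [box_features, context_features] list.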