Commit eef0a1e7 authored by Kaushik Shivakumar

finalize context rcnn tf2

parent 12110e64
@@ -45,8 +45,22 @@ class ContextProjection(tf.keras.layers.Layer):
 class AttentionBlock(tf.keras.layers.Layer):
   """Custom layer to perform all attention."""
 
   def __init__(self, bottleneck_dimension, attention_temperature,
-               output_dimension=None,
-               is_training=False, name='AttentionBlock', **kwargs):
+               output_dimension=None, is_training=False,
+               name='AttentionBlock', **kwargs):
"""Constructs an attention block.
Args:
bottleneck_dimension: A int32 Tensor representing the bottleneck dimension
for intermediate projections.
attention_temperature: A float Tensor. It controls the temperature of the
softmax for weights calculation. The formula for calculation as follows:
weights = exp(weights / temperature) / sum(exp(weights / temperature))
output_dimension: A int32 Tensor representing the last dimension of the
output feature.
is_training: A boolean Tensor (affecting batch normalization).
name: A string describing what to name the variables in this block.
"""
     self._key_proj = ContextProjection(bottleneck_dimension)
     self._val_proj = ContextProjection(bottleneck_dimension)
     self._query_proj = ContextProjection(bottleneck_dimension)
@@ -60,12 +74,29 @@ class AttentionBlock(tf.keras.layers.Layer):
     super(AttentionBlock, self).__init__(name=name, **kwargs)
 
   def build(self, input_shapes):
"""Finishes building the attention block.
Args:
input_shapes: the shape of the primary input box features.
"""
     if not self._feature_proj:
       self._output_dimension = input_shapes[-1]
       self._feature_proj = ContextProjection(self._output_dimension)
 
   def call(self, box_features, context_features, valid_context_size):
"""Handles a call by performing attention.""" """Handles a call by performing attention.
Args:
box_features: A float Tensor of shape [batch_size, input_size,
num_input_features].
context_features: A float Tensor of shape [batch_size, context_size,
num_context_features].
valid_context_size: A int32 Tensor of shape [batch_size].
Returns:
A float Tensor with shape [batch_size, input_size, num_input_features]
containing output features after attention with context features.
"""
     _, context_size, _ = context_features.shape
     valid_mask = compute_valid_mask(valid_context_size, context_size)
...
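
To make the new docstrings concrete, here is a minimal, self-contained sketch of the masked, temperature-scaled softmax attention they describe. This is not the layer from this commit: the real `AttentionBlock` also routes keys, values, queries, and outputs through `ContextProjection` (so `num_input_features` and `num_context_features` need not match), while the sketch assumes a shared feature dimension. `sketch_attention` is a hypothetical name, and `tf.sequence_mask` stands in for `compute_valid_mask`.

```python
import tensorflow as tf


def sketch_attention(box_features, context_features, valid_mask,
                     attention_temperature):
  """Masked, temperature-scaled dot-product attention (illustrative only)."""
  # Affinities between every box feature and every context feature:
  # [batch_size, input_size, context_size].
  weights = tf.matmul(box_features, context_features, transpose_b=True)
  # Drive padded context positions to ~zero weight before the softmax.
  mask = tf.expand_dims(valid_mask, axis=1)  # [batch_size, 1, context_size]
  weights = tf.where(mask, weights, tf.fill(tf.shape(weights), -1e9))
  # weights = exp(weights / temperature) / sum(exp(weights / temperature))
  weights = tf.nn.softmax(weights / attention_temperature, axis=-1)
  # Weighted sum of context features: [batch_size, input_size, features].
  return tf.matmul(weights, context_features)


# Batch of 2; 5 boxes; 7 context slots, of which 7 and 4 are valid.
box = tf.random.normal([2, 5, 8])
context = tf.random.normal([2, 7, 8])
valid_mask = tf.sequence_mask([7, 4], maxlen=7)  # stand-in for compute_valid_mask
out = sketch_attention(box, context, valid_mask, attention_temperature=0.01)
print(out.shape)  # (2, 5, 8)
```

Note the role of the temperature from the `__init__` docstring: a low `attention_temperature` sharpens the softmax so each box attends to only its best-matching context features, while a high one flattens the weights toward a uniform average.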