Commit 0aa87984 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

context R-CNN bugfix and documentation fix

parent 88253ce5
......@@ -35,14 +35,15 @@ class ContextProjection(tf.keras.layers.Layer):
self.projection = tf.keras.layers.Dense(units=projection_dimension,
activation=tf.nn.relu6,
use_bias=True)
self.projection_dimension = projection_dimension
super(ContextProjection, self).__init__(**kwargs)
def build(self, input_shape):
self.batch_norm.build(input_shape)
self.projection.build(input_shape)
self.batch_norm.build(input_shape[:1] + [self.projection_dimension])
def call(self, input_features, is_training=False):
return self.projection(self.batch_norm(input_features, is_training))
return self.batch_norm(self.projection(input_features), is_training)
class AttentionBlock(tf.keras.layers.Layer):
......@@ -92,8 +93,8 @@ class AttentionBlock(tf.keras.layers.Layer):
"""Handles a call by performing attention.
Args:
box_features: A float Tensor of shape [batch_size, input_size,
num_input_features].
box_features: A float Tensor of shape [batch_size, input_size, height,
width, num_input_features].
context_features: A float Tensor of shape [batch_size, context_size,
num_context_features].
valid_context_size: A int32 Tensor of shape [batch_size].
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment