Commit f61357cd authored by Kaushik Shivakumar

make suggested fixes

parent cbd607ab
@@ -99,12 +99,16 @@ class ContextRcnnLibTest(parameterized.TestCase, test_case.TestCase,
     input_features = tf.ones([2, 8, 3, 3, 3], tf.float32)
     context_features = tf.ones([2, 20, 10], tf.float32)
     is_training = False
-    attention_block = context_rcnn_lib.AttentionBlock(bottleneck_dimension, attention_temperature, False, output_dimension)
+    attention_block = context_rcnn_lib.AttentionBlock(bottleneck_dimension,
+                                                      attention_temperature,
+                                                      freeze_batchnorm=False,
+                                                      output_dimension=output_dimension,
+                                                      is_training=False)
     valid_context_size = tf.random_uniform((2,),
                                            minval=0,
                                            maxval=10,
                                            dtype=tf.int32)
-    output_features = attention_block([input_features, context_features], is_training, valid_context_size)
+    output_features = attention_block(input_features, context_features, valid_context_size)
     # Makes sure the shape is correct.
     self.assertAllEqual(output_features.shape, [2, 8, 1, 1, output_dimension])
...
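For reference, a minimal usage sketch of the AttentionBlock API after this change, mirroring the test above. The import path and the concrete values for bottleneck_dimension, attention_temperature and output_dimension are illustrative assumptions, not taken from this commit.

# Sketch only: the import path and hyperparameter values below are assumptions.
import tensorflow as tf
from object_detection.meta_architectures import context_rcnn_lib

box_features = tf.ones([2, 8, 3, 3, 3], tf.float32)   # per-box feature crops
context_features = tf.ones([2, 20, 10], tf.float32)   # context memory bank
valid_context_size = tf.constant([10, 3], tf.int32)   # valid rows per image

attention_block = context_rcnn_lib.AttentionBlock(
    bottleneck_dimension=16,      # assumed value
    attention_temperature=0.01,   # assumed value
    freeze_batchnorm=False,
    output_dimension=8,           # assumed value
    is_training=False)

# The layer now takes plain tensors; is_training is no longer a call argument.
output_features = attention_block(box_features, context_features,
                                  valid_context_size)
print(output_features.shape)  # per the test above: (2, 8, 1, 1, output_dimension)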
@@ -52,7 +52,7 @@ class ContextProjection(tf.keras.layers.Layer):
 class AttentionBlock(tf.keras.layers.Layer):
   """Custom layer to perform all attention."""
   def __init__(self, bottleneck_dimension, attention_temperature,
-               freeze_batchnorm, output_dimension=None, **kwargs):
+               freeze_batchnorm, output_dimension=None, is_training=False, **kwargs):
     self._key_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self._val_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
     self._query_proj = ContextProjection(bottleneck_dimension, freeze_batchnorm)
...
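The signature change above follows a common Keras pattern: capture the training flag when the layer is constructed and consult it inside call(), so callers pass only tensors. A generic sketch of that pattern (not this repository's code; all names below are invented for illustration):

import tensorflow as tf

class TinyBlock(tf.keras.layers.Layer):
  """Illustration only: a layer that stores is_training at construction."""

  def __init__(self, units, is_training=False, **kwargs):
    super(TinyBlock, self).__init__(**kwargs)
    self._is_training = is_training
    self._dense = tf.keras.layers.Dense(units)
    self._bn = tf.keras.layers.BatchNormalization()

  def call(self, inputs):
    # Batch norm reads the stored flag, so callers do not have to thread a
    # training argument through every call.
    return self._bn(self._dense(inputs), training=self._is_training)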
@@ -301,7 +301,6 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         "Please make sure context_features and valid_context_size are in the "
         "features")
-    print("In get side inputs, returning side features.")
     return {
         fields.InputDataFields.context_features:
             features[fields.InputDataFields.context_features],
@@ -338,8 +337,7 @@ class ContextRCNNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
         [self._initial_crop_size, self._initial_crop_size])
     attention_features = self._context_feature_extract_fn(
-        box_features=box_features,
-        context_features=context_features,
+        box_features, context_features,
         valid_context_size=valid_context_size)
     # Adds box features with attention features.
...
@@ -518,8 +518,6 @@ class ContextRCNNMetaArchTest(test_case.TestCase, parameterized.TestCase):
       }
       side_inputs = model.get_side_inputs(features)
-      print('preprocessed', preprocessed_inputs.shape)
-      print('context', context_features.shape)
       prediction_dict = model.predict(preprocessed_inputs, true_image_shapes,
                                       **side_inputs)
       return (prediction_dict['rpn_box_predictor_features'],
...