Commit 8b6a4628 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

Finalize Context R-CNN TF2 support

parent 6d39f5e6
@@ -73,16 +73,15 @@ class AttentionBlock(tf.keras.layers.Layer):
     # box_features becomes [batch_size, max_num_proposals, channels].
     box_features = tf.reduce_mean(box_features, [2, 3])
-    with tf.name_scope("AttentionBlock"):
-      queries = project_features(
-          box_features, self._bottleneck_dimension, self._is_training,
-          self._query_proj, normalize=True)
-      keys = project_features(
-          context_features, self._bottleneck_dimension, self._is_training,
-          self._key_proj, normalize=True)
-      values = project_features(
-          context_features, self._bottleneck_dimension, self._is_training,
-          self._val_proj, normalize=True)
+    queries = project_features(
+        box_features, self._bottleneck_dimension, self._is_training,
+        self._query_proj, normalize=True)
+    keys = project_features(
+        context_features, self._bottleneck_dimension, self._is_training,
+        self._key_proj, normalize=True)
+    values = project_features(
+        context_features, self._bottleneck_dimension, self._is_training,
+        self._val_proj, normalize=True)
     weights = tf.matmul(queries, keys, transpose_b=True)
     weights, values = filter_weight_value(weights, values, valid_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment