Commit 859b92b8 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

move batch norm between dense and activation

parent 0aa87984
......@@ -33,7 +33,6 @@ class ContextProjection(tf.keras.layers.Layer):
momentum=0.97,
trainable=True)
self.projection = tf.keras.layers.Dense(units=projection_dimension,
activation=tf.nn.relu6,
use_bias=True)
self.projection_dimension = projection_dimension
super(ContextProjection, self).__init__(**kwargs)
......@@ -43,7 +42,8 @@ class ContextProjection(tf.keras.layers.Layer):
self.batch_norm.build(input_shape[:1] + [self.projection_dimension])
def call(self, input_features, is_training=False):
    """Project features, batch-normalize, then apply ReLU6.

    Resolves the diff residue in this span (the pre-change return made the
    post-change lines unreachable) to the post-commit behavior: the ReLU6
    activation was moved out of the Dense layer so batch norm now runs
    between the dense projection and the activation.

    Args:
        input_features: Input tensor to project. Shape/dtype are not
            shown here — presumably a float tensor whose last dimension
            matches the Dense layer's input; TODO confirm against callers.
        is_training: Whether batch normalization should run in training
            mode (update moving statistics) or inference mode.

    Returns:
        tf.nn.relu6(batch_norm(projection(input_features))).
    """
    projected = self.projection(input_features)
    normalized = self.batch_norm(projected, is_training)
    return tf.nn.relu6(normalized)
class AttentionBlock(tf.keras.layers.Layer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment