Commit 859b92b8 authored by Kaushik Shivakumar

move batch norm between dense and activation

parent 0aa87984
@@ -33,7 +33,6 @@ class ContextProjection(tf.keras.layers.Layer):
         momentum=0.97,
         trainable=True)
     self.projection = tf.keras.layers.Dense(units=projection_dimension,
-                                            activation=tf.nn.relu6,
                                             use_bias=True)
     self.projection_dimension = projection_dimension
     super(ContextProjection, self).__init__(**kwargs)
@@ -43,7 +42,8 @@ class ContextProjection(tf.keras.layers.Layer):
     self.batch_norm.build(input_shape[:1] + [self.projection_dimension])

   def call(self, input_features, is_training=False):
-    return self.batch_norm(self.projection(input_features), is_training)
+    return tf.nn.relu6(self.batch_norm(self.projection(input_features),
+                                       is_training))


 class AttentionBlock(tf.keras.layers.Layer):
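For reference, the net effect of this commit is to change the layer's forward pass from Dense(+ReLU6) -> BatchNorm to Dense -> BatchNorm -> ReLU6, so the nonlinearity is applied to normalized pre-activations instead of batch norm normalizing values already clipped to [0, 6]. Below is a minimal, self-contained sketch of the layer as it reads after the change. It is reconstructed only from the hunks above, so anything outside them (the remaining BatchNormalization arguments, the explicit build method, the rest of the file) is simplified or assumed, and the usage values at the bottom are hypothetical.

import tensorflow as tf


class ContextProjection(tf.keras.layers.Layer):
  """Projects features with Dense -> BatchNorm -> ReLU6 (post-commit order)."""

  def __init__(self, projection_dimension, **kwargs):
    # super().__init__ is called first here for idiomatic Keras; the original
    # file's exact __init__ layout is only partially visible in the hunks.
    super(ContextProjection, self).__init__(**kwargs)
    # The Dense layer is now linear: activation=tf.nn.relu6 was removed.
    self.projection = tf.keras.layers.Dense(units=projection_dimension,
                                            use_bias=True)
    self.batch_norm = tf.keras.layers.BatchNormalization(momentum=0.97,
                                                         trainable=True)

  def call(self, input_features, is_training=False):
    # ReLU6 now wraps the batch norm, matching the commit message:
    # batch norm sits between the dense projection and the activation.
    return tf.nn.relu6(self.batch_norm(self.projection(input_features),
                                       is_training))


# Hypothetical usage: project 64-d features for a batch of 4 with 10 boxes each.
layer = ContextProjection(projection_dimension=128)
features = tf.random.normal([4, 10, 64])
output = layer(features, is_training=True)  # shape: [4, 10, 128]

One consequence of the new ordering: batch norm's running statistics are now computed over the unbounded dense outputs rather than over ReLU6-clipped values.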