Commit 66e8a904 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

make fixes

parent 3475ebda
......@@ -61,10 +61,11 @@ class AttentionBlock(tf.keras.layers.Layer):
def build(self, input_shapes):
self.feature_proj = ContextProjection(input_shapes[0][-1], self.freeze_batchnorm)
self.key_proj.build(input_shapes[0])
self.val_proj.build(input_shapes[0])
self.query_proj.build(input_shapes[0])
self.feature_proj.build(input_shapes[0])
#self.key_proj.build(input_shapes[0])
#self.val_proj.build(input_shapes[0])
#self.query_proj.build(input_shapes[0])
#self.feature_proj.build(input_shapes[0])
pass
def filter_weight_value(self, weights, values, valid_mask):
"""Filters weights and values based on valid_mask.
......@@ -140,7 +141,7 @@ class AttentionBlock(tf.keras.layers.Layer):
projected_features = layer(features, is_training)
projected_features = tf.reshape((batch_size, -1, bottleneck_dimension))(projected_features)
projected_features = tf.reshape(projected_features, [batch_size, -1, bottleneck_dimension])
print(projected_features.shape)
if normalize:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment