Unverified Commit a1a67a3c authored by abhishek thakur's avatar abhishek thakur Committed by GitHub
Browse files

Fix GroupedLinearLayer in TF ConvBERT (#9972)

parent 71bdc076
...@@ -435,9 +435,10 @@ class GroupedLinearLayer(tf.keras.layers.Layer): ...@@ -435,9 +435,10 @@ class GroupedLinearLayer(tf.keras.layers.Layer):
) )
def call(self, hidden_states): def call(self, hidden_states):
batch_size = shape_list(tensor=hidden_states)[1] batch_size = shape_list(hidden_states)[0]
x = tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]) x = tf.transpose(tf.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim]), [1, 0, 2])
x = tf.matmul(a=x, b=self.kernel, transpose_b=True) x = tf.matmul(x, self.kernel)
x = tf.transpose(x, [1, 0, 2])
x = tf.reshape(x, [batch_size, -1, self.output_size]) x = tf.reshape(x, [batch_size, -1, self.output_size])
x = tf.nn.bias_add(value=x, bias=self.bias) x = tf.nn.bias_add(value=x, bias=self.bias)
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment