Unverified Commit df1f94eb authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

[TFWav2Vec2Model] Fix input shapes in TFWav2Vec2WeightNormConv1D (#14319)

* Add paddings to input shapes

* Add padding comment
parent e30078b5
......@@ -525,7 +525,11 @@ class TFHubertWeightNormConv1D(tf.keras.layers.Conv1D):
def build(self, input_shape):
if not self.built:
input_shape = input_shape.as_list()
# Conv1D output shapes are checked at build time since TF 2.7, so we need to account for padding
input_shape[-2] += self.explicit_padding * 2
super().build(input_shape)
self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
self.weight_v = self.kernel
......
......@@ -524,7 +524,11 @@ class TFWav2Vec2WeightNormConv1D(tf.keras.layers.Conv1D):
def build(self, input_shape):
if not self.built:
input_shape = input_shape.as_list()
# Conv1D output shapes are checked at build time since TF 2.7, so we need to account for padding
input_shape[-2] += self.explicit_padding * 2
super().build(input_shape)
self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
self.weight_v = self.kernel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment