"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c796b6dea6cd3a887b4da0aaa54026f6472b0a98"
Unverified commit 8565d38f, authored by Kamal Raj and committed by GitHub

Update modeling_flax_wav2vec2.py (#13680)

Convert conv kernel_size to a tuple,
a breaking change introduced in Flax 0.3.5: https://github.com/google/flax/releases/tag/v0.3.5
parent d16bec95
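Context for the change, as a minimal sketch (not part of the patch): as of Flax 0.3.5, flax.linen.Conv expects kernel_size to be a sequence of ints, so a 1-D convolution takes (k,) instead of a bare int k, and the width is read back via kernel_size[0]. The module and field names below (TinyConv, kernel_width) are hypothetical, chosen only to illustrate the API.

# A minimal sketch (not part of this patch) of the Flax >= 0.3.5 behavior
# that motivates the change: flax.linen.Conv expects kernel_size as a
# sequence of ints, so 1-D convs take (k,) rather than a bare int k.
# TinyConv and kernel_width are hypothetical names for illustration only.
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyConv(nn.Module):
    features: int = 4
    kernel_width: int = 3  # stands in for config values like conv_kernel[layer_id]

    @nn.compact
    def __call__(self, x):
        # kernel_size=(self.kernel_width,) is required here;
        # kernel_size=self.kernel_width is rejected by Flax >= 0.3.5.
        return nn.Conv(features=self.features, kernel_size=(self.kernel_width,))(x)


x = jnp.ones((1, 16, 8))  # (batch, length, channels) for a 1-D convolution
params = TinyConv().init(jax.random.PRNGKey(0), x)
y = TinyConv().apply(params, x)
print(y.shape)  # (1, 16, 4): default "SAME" padding preserves the length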
@@ -286,7 +286,7 @@ class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
         self.conv = nn.Conv(
             features=self.config.conv_dim[self.layer_id],
-            kernel_size=self.config.conv_kernel[self.layer_id],
+            kernel_size=(self.config.conv_kernel[self.layer_id],),
             strides=(self.config.conv_stride[self.layer_id],),
             use_bias=self.config.conv_bias,
             kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
@@ -310,7 +310,7 @@ class FlaxConvWithWeightNorm(nn.Module):
     def setup(self):
         self.conv = nn.Conv(
             features=self.config.hidden_size,
-            kernel_size=self.config.num_conv_pos_embeddings,
+            kernel_size=(self.config.num_conv_pos_embeddings,),
             kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
             padding="VALID",
             feature_group_count=self.config.num_conv_pos_embedding_groups,
@@ -319,12 +319,12 @@ class FlaxConvWithWeightNorm(nn.Module):
         weight_shape = (
             self.conv.features,
             self.conv.features // self.conv.feature_group_count,
-            self.conv.kernel_size,
+            self.conv.kernel_size[0],
         )
         self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(dtype=self.dtype), weight_shape)
         self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
         self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
-        self.prev_padding = self.conv.kernel_size // 2
+        self.prev_padding = self.conv.kernel_size[0] // 2

     def _get_normed_weights(self):
         weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
...