"configs/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "e7d24ac8b87a76d36c1f0e022d450db633e00017"
Unverified Commit e86faecf authored by cloudhan's avatar cloudhan Committed by GitHub
Browse files

Fix obvious typos in flax decoder impl (#17279)

Change config.encoder_ffn_dim -> config.decoder_ffn_dim for decoder.
parent ee393c00
...@@ -537,7 +537,7 @@ class FlaxBartDecoderLayer(nn.Module): ...@@ -537,7 +537,7 @@ class FlaxBartDecoderLayer(nn.Module):
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.decoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
......
...@@ -528,7 +528,7 @@ class FlaxBlenderbotDecoderLayer(nn.Module): ...@@ -528,7 +528,7 @@ class FlaxBlenderbotDecoderLayer(nn.Module):
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.decoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
......
...@@ -541,7 +541,7 @@ class FlaxBlenderbotSmallDecoderLayer(nn.Module): ...@@ -541,7 +541,7 @@ class FlaxBlenderbotSmallDecoderLayer(nn.Module):
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.decoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
......
...@@ -551,7 +551,7 @@ class FlaxMarianDecoderLayer(nn.Module): ...@@ -551,7 +551,7 @@ class FlaxMarianDecoderLayer(nn.Module):
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.decoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
......
...@@ -550,7 +550,7 @@ class FlaxMBartDecoderLayer(nn.Module): ...@@ -550,7 +550,7 @@ class FlaxMBartDecoderLayer(nn.Module):
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.decoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
......
...@@ -544,7 +544,7 @@ class FlaxPegasusDecoderLayer(nn.Module): ...@@ -544,7 +544,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.decoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
......
...@@ -1996,7 +1996,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module): ...@@ -1996,7 +1996,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
) )
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense( self.fc1 = nn.Dense(
self.config.encoder_ffn_dim, self.config.decoder_ffn_dim,
dtype=self.dtype, dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std), kernel_init=jax.nn.initializers.normal(self.config.init_std),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment