Unverified Commit 3fbb55c7 authored by Bhadresh Savani, committed by GitHub

[Flax] Fixes typo in Bart based Flax Models (#13565)

parent 7bd16b87
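
The diff below is the same one-token rename applied in five places: the misspelled attribute acticvation_dropout_layer becomes activation_dropout_layer in the encoder and decoder layers of the Flax Bart, Marian, MBart, and Pegasus models, plus the cookiecutter template. Since the misspelling was used consistently at both the definition and the call sites, the rename is purely cosmetic, and because nn.Dropout holds no parameters, saved checkpoints are unaffected. For context, here is a minimal, self-contained sketch (not the transformers implementation; the field names are hypothetical stand-ins for BartConfig values) of the feed-forward sub-block these layers share, using the corrected attribute name:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class FeedForwardBlock(nn.Module):
        """Sketch of the Bart-style FFN sub-block touched by this commit."""

        hidden_dim: int = 8              # stand-in for config.d_model
        ffn_dim: int = 16                # stand-in for config.encoder_ffn_dim
        dropout: float = 0.1             # stand-in for config.dropout
        activation_dropout: float = 0.0  # stand-in for config.activation_dropout

        def setup(self):
            self.activation_fn = nn.gelu
            # The attribute renamed by this commit (was "acticvation_dropout_layer").
            self.activation_dropout_layer = nn.Dropout(rate=self.activation_dropout)
            self.fc1 = nn.Dense(self.ffn_dim)
            self.fc2 = nn.Dense(self.hidden_dim)
            self.dropout_layer = nn.Dropout(rate=self.dropout)

        def __call__(self, hidden_states, deterministic: bool = True):
            residual = hidden_states
            hidden_states = self.activation_fn(self.fc1(hidden_states))
            # Dropout applied right after the activation, as in the diff below.
            hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
            hidden_states = self.fc2(hidden_states)
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            return residual + hidden_states

    # Quick smoke test of the sketch:
    x = jnp.ones((1, 4, 8))
    module = FeedForwardBlock()
    params = module.init(jax.random.PRNGKey(0), x)
    out = module.apply(params, x)  # shape (1, 4, 8)
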
@@ -406,7 +406,7 @@ class FlaxBartEncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,
@@ -433,7 +433,7 @@ class FlaxBartEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -515,7 +515,7 @@ class FlaxBartDecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = FlaxBartAttention(
@@ -572,7 +572,7 @@ class FlaxBartDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
...
@@ -411,7 +411,7 @@ class FlaxMarianEncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,
@@ -438,7 +438,7 @@ class FlaxMarianEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -523,7 +523,7 @@ class FlaxMarianDecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = FlaxMarianAttention(
@@ -580,7 +580,7 @@ class FlaxMarianDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
...
@@ -417,7 +417,7 @@ class FlaxMBartEncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,
@@ -444,7 +444,7 @@ class FlaxMBartEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -527,7 +527,7 @@ class FlaxMBartDecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = FlaxMBartAttention(
@@ -585,7 +585,7 @@ class FlaxMBartDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
...
@@ -411,7 +411,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,
@@ -438,7 +438,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -522,7 +522,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = FlaxPegasusAttention(
@@ -580,7 +580,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
...
@@ -1432,7 +1432,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,
@@ -1459,7 +1459,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
@@ -1541,7 +1541,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = Flax{{cookiecutter.camelcase_modelname}}Attention(
@@ -1598,7 +1598,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
...