chenpangpang / transformers · Commits · 3fbb55c7

Unverified commit 3fbb55c7, authored Sep 15, 2021 by Bhadresh Savani; committed by GitHub on Sep 15, 2021.

[Flax] Fixes typo in Bart based Flax Models (#13565)

parent 7bd16b87
Showing 5 changed files with 20 additions and 20 deletions (+20 -20):

src/transformers/models/bart/modeling_flax_bart.py (+4 -4)
src/transformers/models/marian/modeling_flax_marian.py (+4 -4)
src/transformers/models/mbart/modeling_flax_mbart.py (+4 -4)
src/transformers/models/pegasus/modeling_flax_pegasus.py (+4 -4)
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py (+4 -4)
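Every hunk below applies the same two-line fix: the attribute acticvation_dropout_layer (note the stray "c") is renamed to activation_dropout_layer, both where it is defined in setup() and where it is called in the feed-forward block. For context, here is a minimal stand-alone sketch of that feed-forward pattern — simplified and not the library code: the class name FFNBlock and the plain fields ffn_dim, dropout, and activation_dropout are stand-ins for values the real layers read from self.config.

import jax
import jax.numpy as jnp
import flax.linen as nn

class FFNBlock(nn.Module):
    # Simplified stand-ins for the values the real layers read from self.config.
    ffn_dim: int = 16
    dropout: float = 0.1
    activation_dropout: float = 0.0

    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.dropout)
        # The attribute fixed by this commit (previously "acticvation_dropout_layer"):
        self.activation_dropout_layer = nn.Dropout(rate=self.activation_dropout)
        self.fc1 = nn.Dense(self.ffn_dim)
        self.fc2 = nn.Dense(8)

    def __call__(self, hidden_states, deterministic: bool = True):
        residual = hidden_states
        hidden_states = jax.nn.gelu(self.fc1(hidden_states))
        # ...and its matching call site, renamed in the same commit:
        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
        return residual + hidden_states

x = jnp.ones((1, 4, 8))
params = FFNBlock().init(jax.random.PRNGKey(0), x)
print(FFNBlock().apply(params, x).shape)  # (1, 4, 8)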
src/transformers/models/bart/modeling_flax_bart.py

@@ -406,7 +406,7 @@ class FlaxBartEncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,

@@ -433,7 +433,7 @@ class FlaxBartEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states

@@ -515,7 +515,7 @@ class FlaxBartDecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = FlaxBartAttention(

@@ -572,7 +572,7 @@ class FlaxBartDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
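Every call site above threads a deterministic flag into nn.Dropout. As a side note (this snippet is illustrative, not taken from the diff), that flag is how Flax switches dropout between training and inference:

import jax
import jax.numpy as jnp
import flax.linen as nn

drop = nn.Dropout(rate=0.5)
x = jnp.ones((2, 4))

# deterministic=True: dropout is the identity and needs no RNG (inference path).
y_eval = drop.apply({}, x, deterministic=True)

# deterministic=False: entries are dropped, which requires a "dropout" RNG (training path).
y_train = drop.apply({}, x, deterministic=False, rngs={"dropout": jax.random.PRNGKey(0)})

print(y_eval)   # all ones
print(y_train)  # roughly half the entries zeroed, survivors scaled by 1 / (1 - rate)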
src/transformers/models/marian/modeling_flax_marian.py

@@ -411,7 +411,7 @@ class FlaxMarianEncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,

@@ -438,7 +438,7 @@ class FlaxMarianEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states

@@ -523,7 +523,7 @@ class FlaxMarianDecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = FlaxMarianAttention(

@@ -580,7 +580,7 @@ class FlaxMarianDecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
src/transformers/models/mbart/modeling_flax_mbart.py

@@ -417,7 +417,7 @@ class FlaxMBartEncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,

@@ -444,7 +444,7 @@ class FlaxMBartEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states

@@ -527,7 +527,7 @@ class FlaxMBartDecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = FlaxMBartAttention(

@@ -585,7 +585,7 @@ class FlaxMBartDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
src/transformers/models/pegasus/modeling_flax_pegasus.py

@@ -411,7 +411,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,

@@ -438,7 +438,7 @@ class FlaxPegasusEncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states

@@ -522,7 +522,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = FlaxPegasusAttention(

@@ -580,7 +580,7 @@ class FlaxPegasusDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
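One structural difference is visible across the hunks: in the MBart and Pegasus feed-forward blocks, final_layer_norm is applied before fc1 (a pre-LN layout), while the Bart and Marian hunks show no norm inside the block (their norm sits after the residual add, outside the quoted context). Schematically (a sketch with placeholder callables, not the library code):

def ffn_post_ln(x, fc1, fc2, act, act_drop, drop, norm):
    # Bart / Marian style: MLP first, normalize after the residual add.
    residual = x
    x = act_drop(act(fc1(x)))   # activation dropout (the layer renamed by this commit)
    x = drop(fc2(x))
    return norm(residual + x)

def ffn_pre_ln(x, fc1, fc2, act, act_drop, drop, norm):
    # MBart / Pegasus style: normalize before the MLP, as in the hunks above.
    residual = x
    x = norm(x)
    x = act_drop(act(fc1(x)))
    x = drop(fc2(x))
    return residual + x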
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py

@@ -1432,7 +1432,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.fc1 = nn.Dense(
             self.config.encoder_ffn_dim,
             dtype=self.dtype,

@@ -1459,7 +1459,7 @@ class Flax{{cookiecutter.camelcase_modelname}}EncoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states

@@ -1541,7 +1541,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
         self.activation_fn = ACT2FN[self.config.activation_function]
-        self.acticvation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
+        self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
         self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.encoder_attn = Flax{{cookiecutter.camelcase_modelname}}Attention(

@@ -1598,7 +1598,7 @@ class Flax{{cookiecutter.camelcase_modelname}}DecoderLayer(nn.Module):
         # Fully Connected
         residual = hidden_states
         hidden_states = self.activation_fn(self.fc1(hidden_states))
-        hidden_states = self.acticvation_dropout_layer(hidden_states, deterministic=deterministic)
+        hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = self.fc2(hidden_states)
         hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
         hidden_states = residual + hidden_states
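The last file is the add-a-new-model cookiecutter template, so without this fix the typo would be stamped into every newly generated Flax model. The {{cookiecutter.*}} placeholders expand at generation time, roughly like this (illustrative only; "MyNewModel" is a made-up name):

# Simplified rendering of one template line from the diff above.
template = "self.encoder_attn = Flax{{cookiecutter.camelcase_modelname}}Attention("
print(template.replace("{{cookiecutter.camelcase_modelname}}", "MyNewModel"))
# -> self.encoder_attn = FlaxMyNewModelAttention(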