Commit 2596f95e (unverified)
Authored Mar 07, 2022 by Sanchit Gandhi; committed by GitHub on Mar 07, 2022

Fix Embedding Module Bug in Flax Models (#15920)
Parent: 1a62b25c
Showing 7 changed files with 13 additions and 104 deletions
src/transformers/models/bart/modeling_flax_bart.py  (+2, -16)
src/transformers/models/blenderbot/modeling_flax_blenderbot.py  (+2, -16)
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py  (+2, -16)
src/transformers/models/marian/modeling_flax_marian.py  (+2, -16)
src/transformers/models/mbart/modeling_flax_mbart.py  (+2, -16)
src/transformers/models/pegasus/modeling_flax_pegasus.py  (+2, -16)
src/transformers/models/t5/modeling_flax_t5.py  (+1, -8)
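Across all seven files the change is the same: embed_tokens stops being an optional dataclass field that setup() fills in lazily (Optional[nn.Embed] = None plus an "if self.embed_tokens is None:" branch) and becomes a required attribute that the calling module constructs and passes in; it also moves ahead of dtype so the required field precedes the field with a default. A minimal sketch of the resulting pattern, with toy names in place of the library's classes:

# Minimal sketch of the pattern this commit moves to (toy names, not library code):
# the embedding is a required attribute and setup() no longer creates it on demand.
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyEncoder(nn.Module):
    embed_tokens: nn.Embed  # required, supplied by the caller (was Optional[nn.Embed] = None)
    dtype: jnp.dtype = jnp.float32  # fields with defaults come after the required one

    def setup(self):
        # previously an "if self.embed_tokens is None:" branch built a fresh nn.Embed
        # here; after this commit the caller always provides it
        self.dropout_layer = nn.Dropout(rate=0.1)

    def __call__(self, input_ids, deterministic: bool = True):
        hidden_states = self.embed_tokens(input_ids)
        return self.dropout_layer(hidden_states, deterministic=deterministic)


# the caller now owns the embedding table
shared = nn.Embed(100, 16, embedding_init=jax.nn.initializers.normal(0.02))
encoder = ToyEncoder(embed_tokens=shared)
params = encoder.init(jax.random.PRNGKey(0), jnp.ones((1, 4), dtype="i4"))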
src/transformers/models/bart/modeling_flax_bart.py

@@ -697,8 +697,8 @@ class FlaxBartClassificationHead(nn.Module):
 class FlaxBartEncoder(nn.Module):
     config: BartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -708,13 +708,6 @@ class FlaxBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2

@@ -768,8 +761,8 @@ class FlaxBartEncoder(nn.Module):
 class FlaxBartDecoder(nn.Module):
     config: BartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -779,13 +772,6 @@ class FlaxBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
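The corresponding caller-side change is outside this diff: the parent module is expected to build one shared nn.Embed and hand it to both halves. A hedged sketch of what that wiring plausibly looks like for BART; the FlaxBartModule and self.shared names are assumptions here, only FlaxBartEncoder/FlaxBartDecoder and their dataclass fields (config, embed_tokens, dtype) come from this commit.

# Hedged sketch of the BART caller side after this change: one nn.Embed is built by
# the parent and passed to encoder and decoder so the token embedding is shared.
# FlaxBartModule / self.shared are assumed names, not shown in this diff.
import jax
import jax.numpy as jnp
import flax.linen as nn
from transformers import BartConfig
from transformers.models.bart.modeling_flax_bart import FlaxBartDecoder, FlaxBartEncoder


class FlaxBartModule(nn.Module):  # illustrative only
    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # the encoder and decoder no longer create embed_tokens themselves
        self.encoder = FlaxBartEncoder(self.config, embed_tokens=self.shared, dtype=self.dtype)
        self.decoder = FlaxBartDecoder(self.config, embed_tokens=self.shared, dtype=self.dtype)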
src/transformers/models/blenderbot/modeling_flax_blenderbot.py

@@ -661,8 +661,8 @@ class FlaxBlenderbotDecoderLayerCollection(nn.Module):
 class FlaxBlenderbotEncoder(nn.Module):
     config: BlenderbotConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -672,13 +672,6 @@ class FlaxBlenderbotEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,

@@ -730,8 +723,8 @@ class FlaxBlenderbotEncoder(nn.Module):
 class FlaxBlenderbotDecoder(nn.Module):
     config: BlenderbotConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -741,13 +734,6 @@ class FlaxBlenderbotDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py

@@ -674,8 +674,8 @@ class FlaxBlenderbotSmallDecoderLayerCollection(nn.Module):
 class FlaxBlenderbotSmallEncoder(nn.Module):
     config: BlenderbotSmallConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -685,13 +685,6 @@ class FlaxBlenderbotSmallEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,

@@ -742,8 +735,8 @@ class FlaxBlenderbotSmallEncoder(nn.Module):
 class FlaxBlenderbotSmallDecoder(nn.Module):
     config: BlenderbotSmallConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -753,13 +746,6 @@ class FlaxBlenderbotSmallDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
src/transformers/models/marian/modeling_flax_marian.py

@@ -684,8 +684,8 @@ class FlaxMarianDecoderLayerCollection(nn.Module):
 class FlaxMarianEncoder(nn.Module):
     config: MarianConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -694,13 +694,6 @@ class FlaxMarianEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
         self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)

@@ -747,8 +740,8 @@ class FlaxMarianEncoder(nn.Module):
 class FlaxMarianDecoder(nn.Module):
     config: MarianConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -757,13 +750,6 @@ class FlaxMarianDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
         self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype)
src/transformers/models/mbart/modeling_flax_mbart.py

@@ -712,8 +712,8 @@ class FlaxMBartClassificationHead(nn.Module):
 class FlaxMBartEncoder(nn.Module):
     config: MBartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -723,13 +723,6 @@ class FlaxMBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2

@@ -787,8 +780,8 @@ class FlaxMBartEncoder(nn.Module):
 class FlaxMBartDecoder(nn.Module):
     config: MBartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -798,13 +791,6 @@ class FlaxMBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
src/transformers/models/pegasus/modeling_flax_pegasus.py

@@ -677,8 +677,8 @@ class FlaxPegasusDecoderLayerCollection(nn.Module):
 class FlaxPegasusEncoder(nn.Module):
     config: PegasusConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -688,13 +688,6 @@ class FlaxPegasusEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(
             self.config.max_position_embeddings, embed_dim, dtype=self.dtype
         )

@@ -746,8 +739,8 @@ class FlaxPegasusEncoder(nn.Module):
 class FlaxPegasusDecoder(nn.Module):
     config: PegasusConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)

@@ -757,13 +750,6 @@ class FlaxPegasusDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(
             self.config.max_position_embeddings, embed_dim, dtype=self.dtype
         )
src/transformers/models/t5/modeling_flax_t5.py

@@ -709,19 +709,12 @@ class FlaxT5BlockCollection(nn.Module):
 class FlaxT5Stack(nn.Module):
     config: T5Config
-    embed_tokens: Optional[nn.Embed] = None
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
         self.causal = self.config.causal

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                self.config.d_model,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
         self.final_layer_norm = FlaxT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
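For T5 the same change lands on FlaxT5Stack, which serves as both the encoder and the decoder stack. Any code that relied on the removed default and constructed a stack without an embedding now has to supply one; below is a hedged, construction-only sketch of that calling convention. Only the field names config, embed_tokens and dtype follow from the diff; the concrete config values and stddev are illustrative.

# Construction-only sketch: FlaxT5Stack now requires the embedding table up front.
import jax
import jax.numpy as jnp
import flax.linen as nn
from transformers import T5Config
from transformers.models.t5.modeling_flax_t5 import FlaxT5Stack

config = T5Config(vocab_size=32, d_model=16, d_kv=8, d_ff=32, num_layers=2, num_heads=2)
shared = nn.Embed(config.vocab_size, config.d_model,
                  embedding_init=jax.nn.initializers.normal(1.0))

# old (removed): embed_tokens defaulted to None and was created inside setup()
# new (this commit): the caller passes the embedding explicitly
encoder_stack = FlaxT5Stack(config, embed_tokens=shared, dtype=jnp.float32)
# note: setup() reads config.causal (see the hunk above), so a config used to
# actually initialize the stack needs that flag set, e.g. config.causal = False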