Unverified commit 2596f95e, authored Mar 07, 2022 by Sanchit Gandhi; committed by GitHub, Mar 07, 2022
Fix Embedding Module Bug in Flax Models (#15920)
parent 1a62b25c
Showing 7 changed files with 13 additions and 104 deletions (+13 -104)
src/transformers/models/bart/modeling_flax_bart.py                            +2  -16
src/transformers/models/blenderbot/modeling_flax_blenderbot.py                +2  -16
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py    +2  -16
src/transformers/models/marian/modeling_flax_marian.py                        +2  -16
src/transformers/models/mbart/modeling_flax_mbart.py                          +2  -16
src/transformers/models/pegasus/modeling_flax_pegasus.py                      +2  -16
src/transformers/models/t5/modeling_flax_t5.py                                +1  -8
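The change is the same in every file: each Flax encoder/decoder previously declared `embed_tokens: Optional[nn.Embed] = None` and, inside `setup()`, created its own `nn.Embed` whenever none was passed in; after this commit `embed_tokens: nn.Embed` is a required attribute, so the enclosing model module always builds the shared token embedding and hands it to its sub-modules. A minimal sketch of that shared-embedding wiring is below (a standalone toy example, not code from this diff; the module names and sizes are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn


class Encoder(nn.Module):
    # required attribute, mirroring the updated `embed_tokens: nn.Embed` fields
    embed_tokens: nn.Embed

    def __call__(self, input_ids):
        # the sub-module only *uses* the embedding; it no longer creates one in setup()
        return self.embed_tokens(input_ids)


class Seq2Seq(nn.Module):
    vocab_size: int = 128
    d_model: int = 16

    def setup(self):
        # the parent owns the shared table and passes it to both children,
        # so encoder and decoder reuse the same embedding parameters
        self.shared = nn.Embed(self.vocab_size, self.d_model)
        self.encoder = Encoder(embed_tokens=self.shared)
        self.decoder = Encoder(embed_tokens=self.shared)

    def __call__(self, input_ids, decoder_input_ids):
        return self.encoder(input_ids), self.decoder(decoder_input_ids)


ids = jnp.ones((1, 4), dtype=jnp.int32)
params = Seq2Seq().init(jax.random.PRNGKey(0), ids, ids)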
src/transformers/models/bart/modeling_flax_bart.py

@@ -697,8 +697,8 @@ class FlaxBartClassificationHead(nn.Module):

 class FlaxBartEncoder(nn.Module):
     config: BartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -708,13 +708,6 @@ class FlaxBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
@@ -768,8 +761,8 @@ class FlaxBartEncoder(nn.Module):

 class FlaxBartDecoder(nn.Module):
     config: BartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -779,13 +772,6 @@ class FlaxBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
src/transformers/models/blenderbot/modeling_flax_blenderbot.py

@@ -661,8 +661,8 @@ class FlaxBlenderbotDecoderLayerCollection(nn.Module):

 class FlaxBlenderbotEncoder(nn.Module):
     config: BlenderbotConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -672,13 +672,6 @@ class FlaxBlenderbotEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
@@ -730,8 +723,8 @@ class FlaxBlenderbotEncoder(nn.Module):

 class FlaxBlenderbotDecoder(nn.Module):
     config: BlenderbotConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -741,13 +734,6 @@ class FlaxBlenderbotDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py

@@ -674,8 +674,8 @@ class FlaxBlenderbotSmallDecoderLayerCollection(nn.Module):

 class FlaxBlenderbotSmallEncoder(nn.Module):
     config: BlenderbotSmallConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -685,13 +685,6 @@ class FlaxBlenderbotSmallEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
@@ -742,8 +735,8 @@ class FlaxBlenderbotSmallEncoder(nn.Module):

 class FlaxBlenderbotSmallDecoder(nn.Module):
     config: BlenderbotSmallConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -753,13 +746,6 @@ class FlaxBlenderbotSmallDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
src/transformers/models/marian/modeling_flax_marian.py

@@ -684,8 +684,8 @@ class FlaxMarianDecoderLayerCollection(nn.Module):

 class FlaxMarianEncoder(nn.Module):
     config: MarianConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -694,13 +694,6 @@ class FlaxMarianEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
         self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)
@@ -747,8 +740,8 @@ class FlaxMarianEncoder(nn.Module):

 class FlaxMarianDecoder(nn.Module):
     config: MarianConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -757,13 +750,6 @@ class FlaxMarianDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
         self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype)
src/transformers/models/mbart/modeling_flax_mbart.py

@@ -712,8 +712,8 @@ class FlaxMBartClassificationHead(nn.Module):

 class FlaxMBartEncoder(nn.Module):
     config: MBartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -723,13 +723,6 @@ class FlaxMBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
@@ -787,8 +780,8 @@ class FlaxMBartEncoder(nn.Module):

 class FlaxMBartDecoder(nn.Module):
     config: MBartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -798,13 +791,6 @@ class FlaxMBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
src/transformers/models/pegasus/modeling_flax_pegasus.py

@@ -677,8 +677,8 @@ class FlaxPegasusDecoderLayerCollection(nn.Module):

 class FlaxPegasusEncoder(nn.Module):
     config: PegasusConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -688,13 +688,6 @@ class FlaxPegasusEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(
             self.config.max_position_embeddings, embed_dim, dtype=self.dtype
         )
@@ -746,8 +739,8 @@ class FlaxPegasusEncoder(nn.Module):

 class FlaxPegasusDecoder(nn.Module):
     config: PegasusConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -757,13 +750,6 @@ class FlaxPegasusDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(
             self.config.max_position_embeddings, embed_dim, dtype=self.dtype
         )
src/transformers/models/t5/modeling_flax_t5.py

@@ -709,19 +709,12 @@ class FlaxT5BlockCollection(nn.Module):

 class FlaxT5Stack(nn.Module):
     config: T5Config
-    embed_tokens: Optional[nn.Embed] = None
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
         self.causal = self.config.causal

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                self.config.d_model,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
         self.final_layer_norm = FlaxT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
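A practical consequence of making the attribute required: FlaxBartEncoder, FlaxT5Stack, and the other modules touched here can no longer be instantiated without an embedding. Presumably the enclosing model modules (not part of this diff) create the shared nn.Embed once and pass it to both encoder and decoder via `embed_tokens=...`, which is the usual Flax pattern for tying the token embedding between the two sides.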