Unverified commit 2596f95e, authored Mar 07, 2022 by Sanchit Gandhi, committed by GitHub on Mar 07, 2022.
Fix Embedding Module Bug in Flax Models (#15920)
Parent: 1a62b25c
Changes: 7 files, 13 additions and 104 deletions (+13 −104).
src/transformers/models/bart/modeling_flax_bart.py  (+2 −16)
src/transformers/models/blenderbot/modeling_flax_blenderbot.py  (+2 −16)
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py  (+2 −16)
src/transformers/models/marian/modeling_flax_marian.py  (+2 −16)
src/transformers/models/mbart/modeling_flax_mbart.py  (+2 −16)
src/transformers/models/pegasus/modeling_flax_pegasus.py  (+2 −16)
src/transformers/models/t5/modeling_flax_t5.py  (+1 −8)
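Across all seven files the change is the same: embed_tokens stops being an Optional[nn.Embed] that each encoder, decoder, or T5 stack lazily creates inside setup(), and becomes a required nn.Embed attribute that the parent module supplies. Below is a minimal sketch of the resulting pattern, with hypothetical Toy* module names and standard flax.linen semantics; it is an illustration of the pattern, not code from this commit.

# Minimal sketch (hypothetical Toy* names) of the pattern these models move to:
# the token embedding is created once in the parent module and handed to the
# sub-modules as a required attribute, so the same table is reused by both.
import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyEncoder(nn.Module):
    embed_tokens: nn.Embed  # required; supplied by the parent module

    def __call__(self, input_ids):
        return self.embed_tokens(input_ids)


class ToyDecoder(nn.Module):
    embed_tokens: nn.Embed  # the same shared instance as the encoder

    def __call__(self, input_ids):
        return self.embed_tokens(input_ids)


class ToyModel(nn.Module):
    vocab_size: int = 32
    d_model: int = 8

    def setup(self):
        # One shared embedding table; both sub-modules reuse it.
        self.shared = nn.Embed(self.vocab_size, self.d_model)
        self.encoder = ToyEncoder(embed_tokens=self.shared)
        self.decoder = ToyDecoder(embed_tokens=self.shared)

    def __call__(self, input_ids, decoder_input_ids):
        return self.encoder(input_ids), self.decoder(decoder_input_ids)


ids = jnp.ones((1, 4), dtype=jnp.int32)
params = ToyModel().init(jax.random.PRNGKey(0), ids, ids)

In this sketch the embedding parameters exist once, under the parent's "shared" scope, rather than being re-created inside each sub-module when the attribute happens to be None.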
src/transformers/models/bart/modeling_flax_bart.py

@@ -697,8 +697,8 @@ class FlaxBartClassificationHead(nn.Module):
 class FlaxBartEncoder(nn.Module):
     config: BartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -708,13 +708,6 @@ class FlaxBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
@@ -768,8 +761,8 @@ class FlaxBartEncoder(nn.Module):
 class FlaxBartDecoder(nn.Module):
     config: BartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -779,13 +772,6 @@ class FlaxBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
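Because the Optional[...] = None default and the lazy-initialization branch are gone, embed_tokens is now a required attribute of these encoder/decoder modules. A hedged construction sketch for FlaxBartEncoder follows: the caller builds the embedding itself, mirroring the nn.Embed call deleted from setup() above (inside the library, the parent Bart module is expected to pass its shared embedding instead).

# Hedged sketch: constructing FlaxBartEncoder directly after this change.
# The embedding must be created by the caller and passed in; the call below
# mirrors the nn.Embed(...) that the commit removes from setup().
import jax
import flax.linen as nn
from transformers import BartConfig
from transformers.models.bart.modeling_flax_bart import FlaxBartEncoder

config = BartConfig()
embed_tokens = nn.Embed(
    config.vocab_size,
    config.d_model,
    embedding_init=jax.nn.initializers.normal(config.init_std),
)
# embed_tokens no longer defaults to None, so it must be supplied here.
encoder = FlaxBartEncoder(config, embed_tokens=embed_tokens)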
src/transformers/models/blenderbot/modeling_flax_blenderbot.py

@@ -661,8 +661,8 @@ class FlaxBlenderbotDecoderLayerCollection(nn.Module):
 class FlaxBlenderbotEncoder(nn.Module):
     config: BlenderbotConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -672,13 +672,6 @@ class FlaxBlenderbotEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
@@ -730,8 +723,8 @@ class FlaxBlenderbotEncoder(nn.Module):
 class FlaxBlenderbotDecoder(nn.Module):
     config: BlenderbotConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -741,13 +734,6 @@ class FlaxBlenderbotDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py

@@ -674,8 +674,8 @@ class FlaxBlenderbotSmallDecoderLayerCollection(nn.Module):
 class FlaxBlenderbotSmallEncoder(nn.Module):
     config: BlenderbotSmallConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -685,13 +685,6 @@ class FlaxBlenderbotSmallEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
@@ -742,8 +735,8 @@ class FlaxBlenderbotSmallEncoder(nn.Module):
 class FlaxBlenderbotSmallDecoder(nn.Module):
     config: BlenderbotSmallConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -753,13 +746,6 @@ class FlaxBlenderbotSmallDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = nn.Embed(
             self.config.max_position_embeddings,
             embed_dim,
src/transformers/models/marian/modeling_flax_marian.py

@@ -684,8 +684,8 @@ class FlaxMarianDecoderLayerCollection(nn.Module):
 class FlaxMarianEncoder(nn.Module):
     config: MarianConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -694,13 +694,6 @@ class FlaxMarianEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
         self.layers = FlaxMarianEncoderLayerCollection(self.config, self.dtype)
@@ -747,8 +740,8 @@ class FlaxMarianEncoder(nn.Module):
 class FlaxMarianDecoder(nn.Module):
     config: MarianConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -757,13 +750,6 @@ class FlaxMarianDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim)
         self.layers = FlaxMarianDecoderLayerCollection(self.config, self.dtype)
src/transformers/models/mbart/modeling_flax_mbart.py

@@ -712,8 +712,8 @@ class FlaxMBartClassificationHead(nn.Module):
 class FlaxMBartEncoder(nn.Module):
     config: MBartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -723,13 +723,6 @@ class FlaxMBartEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
@@ -787,8 +780,8 @@ class FlaxMBartEncoder(nn.Module):
 class FlaxMBartDecoder(nn.Module):
     config: MBartConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -798,13 +791,6 @@ class FlaxMBartDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
         # and adjust num_embeddings appropriately. Other models don't have this hack
         self.offset = 2
src/transformers/models/pegasus/modeling_flax_pegasus.py

@@ -677,8 +677,8 @@ class FlaxPegasusDecoderLayerCollection(nn.Module):
 class FlaxPegasusEncoder(nn.Module):
     config: PegasusConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -688,13 +688,6 @@ class FlaxPegasusEncoder(nn.Module):
         self.max_source_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim, dtype=self.dtype)
@@ -746,8 +739,8 @@ class FlaxPegasusEncoder(nn.Module):
 class FlaxPegasusDecoder(nn.Module):
     config: PegasusConfig
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
-    embed_tokens: Optional[nn.Embed] = None

     def setup(self):
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -757,13 +750,6 @@ class FlaxPegasusDecoder(nn.Module):
         self.max_target_positions = self.config.max_position_embeddings
         self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                embed_dim,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.embed_positions = create_sinusoidal_positions(self.config.max_position_embeddings, embed_dim, dtype=self.dtype)
src/transformers/models/t5/modeling_flax_t5.py

@@ -709,19 +709,12 @@ class FlaxT5BlockCollection(nn.Module):
 class FlaxT5Stack(nn.Module):
     config: T5Config
-    embed_tokens: Optional[nn.Embed] = None
+    embed_tokens: nn.Embed
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation

     def setup(self):
         self.causal = self.config.causal

-        if self.embed_tokens is None:
-            self.embed_tokens = nn.Embed(
-                self.config.vocab_size,
-                self.config.d_model,
-                embedding_init=jax.nn.initializers.normal(self.config.init_std),
-            )
-
         self.block = FlaxT5BlockCollection(self.config, dtype=self.dtype)
         self.final_layer_norm = FlaxT5LayerNorm(
             self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype