Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cbaaa2f6
Unverified
Commit
cbaaa2f6
authored
Jan 19, 2023
by
Joao Gante
Committed by
GitHub
Jan 19, 2023
Browse files
Flax dtype-dependent numerical masking (#21197)
parent
0b86e330
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
20 additions
and
20 deletions
+20
-20
src/transformers/models/albert/modeling_flax_albert.py
src/transformers/models/albert/modeling_flax_albert.py
+1
-1
src/transformers/models/bart/modeling_flax_bart.py
src/transformers/models/bart/modeling_flax_bart.py
+1
-1
src/transformers/models/bert/modeling_flax_bert.py
src/transformers/models/bert/modeling_flax_bert.py
+1
-1
src/transformers/models/big_bird/modeling_flax_big_bird.py
src/transformers/models/big_bird/modeling_flax_big_bird.py
+1
-1
src/transformers/models/blenderbot/modeling_flax_blenderbot.py
...ransformers/models/blenderbot/modeling_flax_blenderbot.py
+1
-1
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
...models/blenderbot_small/modeling_flax_blenderbot_small.py
+1
-1
src/transformers/models/clip/modeling_flax_clip.py
src/transformers/models/clip/modeling_flax_clip.py
+1
-1
src/transformers/models/electra/modeling_flax_electra.py
src/transformers/models/electra/modeling_flax_electra.py
+1
-1
src/transformers/models/gpt2/modeling_flax_gpt2.py
src/transformers/models/gpt2/modeling_flax_gpt2.py
+1
-1
src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
+1
-1
src/transformers/models/gptj/modeling_flax_gptj.py
src/transformers/models/gptj/modeling_flax_gptj.py
+1
-1
src/transformers/models/marian/modeling_flax_marian.py
src/transformers/models/marian/modeling_flax_marian.py
+1
-1
src/transformers/models/mbart/modeling_flax_mbart.py
src/transformers/models/mbart/modeling_flax_mbart.py
+1
-1
src/transformers/models/opt/modeling_flax_opt.py
src/transformers/models/opt/modeling_flax_opt.py
+1
-1
src/transformers/models/pegasus/modeling_flax_pegasus.py
src/transformers/models/pegasus/modeling_flax_pegasus.py
+1
-1
src/transformers/models/roberta/modeling_flax_roberta.py
src/transformers/models/roberta/modeling_flax_roberta.py
+1
-1
src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
...oberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
+1
-1
src/transformers/models/roformer/modeling_flax_roformer.py
src/transformers/models/roformer/modeling_flax_roformer.py
+1
-1
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+1
-1
src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
...nsformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
+1
-1
No files found.
src/transformers/models/albert/modeling_flax_albert.py
View file @
cbaaa2f6
...
@@ -245,7 +245,7 @@ class FlaxAlbertSelfAttention(nn.Module):
...
@@ -245,7 +245,7 @@ class FlaxAlbertSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e10
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/bart/modeling_flax_bart.py
View file @
cbaaa2f6
...
@@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module):
...
@@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/bert/modeling_flax_bert.py
View file @
cbaaa2f6
...
@@ -358,7 +358,7 @@ class FlaxBertSelfAttention(nn.Module):
...
@@ -358,7 +358,7 @@ class FlaxBertSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e10
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/big_bird/modeling_flax_big_bird.py
View file @
cbaaa2f6
...
@@ -380,7 +380,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
...
@@ -380,7 +380,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e10
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/blenderbot/modeling_flax_blenderbot.py
View file @
cbaaa2f6
...
@@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module):
...
@@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
View file @
cbaaa2f6
...
@@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
...
@@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/clip/modeling_flax_clip.py
View file @
cbaaa2f6
...
@@ -315,7 +315,7 @@ class FlaxCLIPAttention(nn.Module):
...
@@ -315,7 +315,7 @@ class FlaxCLIPAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e4
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/electra/modeling_flax_electra.py
View file @
cbaaa2f6
...
@@ -326,7 +326,7 @@ class FlaxElectraSelfAttention(nn.Module):
...
@@ -326,7 +326,7 @@ class FlaxElectraSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e10
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/gpt2/modeling_flax_gpt2.py
View file @
cbaaa2f6
...
@@ -255,7 +255,7 @@ class FlaxGPT2Attention(nn.Module):
...
@@ -255,7 +255,7 @@ class FlaxGPT2Attention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e4
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
View file @
cbaaa2f6
...
@@ -223,7 +223,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
...
@@ -223,7 +223,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
# usual dot product attention
# usual dot product attention
...
...
src/transformers/models/gptj/modeling_flax_gptj.py
View file @
cbaaa2f6
...
@@ -270,7 +270,7 @@ class FlaxGPTJAttention(nn.Module):
...
@@ -270,7 +270,7 @@ class FlaxGPTJAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
# usual dot product attention
# usual dot product attention
...
...
src/transformers/models/marian/modeling_flax_marian.py
View file @
cbaaa2f6
...
@@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module):
...
@@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/mbart/modeling_flax_mbart.py
View file @
cbaaa2f6
...
@@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module):
...
@@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/opt/modeling_flax_opt.py
View file @
cbaaa2f6
...
@@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module):
...
@@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/pegasus/modeling_flax_pegasus.py
View file @
cbaaa2f6
...
@@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module):
...
@@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e9
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/roberta/modeling_flax_roberta.py
View file @
cbaaa2f6
...
@@ -319,7 +319,7 @@ class FlaxRobertaSelfAttention(nn.Module):
...
@@ -319,7 +319,7 @@ class FlaxRobertaSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e10
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
View file @
cbaaa2f6
...
@@ -321,7 +321,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
...
@@ -321,7 +321,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e10
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/roformer/modeling_flax_roformer.py
View file @
cbaaa2f6
...
@@ -240,7 +240,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
...
@@ -240,7 +240,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e10
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
View file @
cbaaa2f6
...
@@ -497,7 +497,7 @@ class FlaxWav2Vec2Attention(nn.Module):
...
@@ -497,7 +497,7 @@ class FlaxWav2Vec2Attention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
float
(
"-inf"
)
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
View file @
cbaaa2f6
...
@@ -329,7 +329,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
...
@@ -329,7 +329,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
attention_bias
=
lax
.
select
(
attention_bias
=
lax
.
select
(
attention_mask
>
0
,
attention_mask
>
0
,
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
0.0
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
-
1e10
).
astype
(
self
.
dtype
),
jnp
.
full
(
attention_mask
.
shape
,
jnp
.
finfo
(
self
.
dtype
).
min
).
astype
(
self
.
dtype
),
)
)
else
:
else
:
attention_bias
=
None
attention_bias
=
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment