chenpangpang / transformers / Commits

Commit cbaaa2f6 (unverified)
Flax dtype-dependent numerical masking (#21197)
Authored Jan 19, 2023 by Joao Gante, committed by GitHub on Jan 19, 2023
Parent: 0b86e330

Changes: 21 files in total; this page shows 20 changed files with 20 additions and 20 deletions (+20 / -20).
src/transformers/models/albert/modeling_flax_albert.py (+1, -1)
src/transformers/models/bart/modeling_flax_bart.py (+1, -1)
src/transformers/models/bert/modeling_flax_bert.py (+1, -1)
src/transformers/models/big_bird/modeling_flax_big_bird.py (+1, -1)
src/transformers/models/blenderbot/modeling_flax_blenderbot.py (+1, -1)
src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py (+1, -1)
src/transformers/models/clip/modeling_flax_clip.py (+1, -1)
src/transformers/models/electra/modeling_flax_electra.py (+1, -1)
src/transformers/models/gpt2/modeling_flax_gpt2.py (+1, -1)
src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py (+1, -1)
src/transformers/models/gptj/modeling_flax_gptj.py (+1, -1)
src/transformers/models/marian/modeling_flax_marian.py (+1, -1)
src/transformers/models/mbart/modeling_flax_mbart.py (+1, -1)
src/transformers/models/opt/modeling_flax_opt.py (+1, -1)
src/transformers/models/pegasus/modeling_flax_pegasus.py (+1, -1)
src/transformers/models/roberta/modeling_flax_roberta.py (+1, -1)
src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py (+1, -1)
src/transformers/models/roformer/modeling_flax_roformer.py (+1, -1)
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py (+1, -1)
src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py (+1, -1)
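Why the change matters, shown as a minimal runnable sketch (the helper name make_attention_bias below is hypothetical; the real pattern lives inline in the attention modules patched in the diffs that follow): the old hard-coded fill values such as -1e10 and -1e9 are not representable in float16, where they overflow to -inf, and a softmax row whose scores are all -inf evaluates to NaN. jnp.finfo(self.dtype).min is the most negative finite value of whatever dtype the module actually runs in, so masked positions still get near-zero probability with no NaN risk.

import jax.numpy as jnp
from jax import lax, nn

# float16 cannot represent -1e10; it overflows to -inf.
print(jnp.array(-1e10, dtype=jnp.float16))   # -inf
print(jnp.finfo(jnp.float16).min)            # most negative finite float16

def make_attention_bias(attention_mask, dtype):
    # Hypothetical helper mirroring the pattern patched in this commit:
    # keep positions where mask > 0, push the rest to the dtype's finite minimum.
    return lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.0).astype(dtype),
        jnp.full(attention_mask.shape, jnp.finfo(dtype).min).astype(dtype),
    )

scores = jnp.zeros((2, 4), dtype=jnp.float16)
mask = jnp.array([[1, 1, 0, 0], [0, 0, 0, 0]])  # second row fully masked
bias = make_attention_bias(mask, jnp.float16)
print(nn.softmax(scores + bias, axis=-1))  # finite everywhere, even for the all-masked row

With a -inf fill (which is what -1e10 becomes after casting to float16), the fully masked second row above would come out of the softmax as NaN.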

src/transformers/models/albert/modeling_flax_albert.py
@@ -245,7 +245,7 @@ class FlaxAlbertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/bart/modeling_flax_bart.py
@@ -371,7 +371,7 @@ class FlaxBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/bert/modeling_flax_bert.py
@@ -358,7 +358,7 @@ class FlaxBertSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -380,7 +380,7 @@ class FlaxBigBirdSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/blenderbot/modeling_flax_blenderbot.py
@@ -359,7 +359,7 @@ class FlaxBlenderbotAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/blenderbot_small/modeling_flax_blenderbot_small.py
@@ -371,7 +371,7 @@ class FlaxBlenderbotSmallAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/clip/modeling_flax_clip.py
@@ -315,7 +315,7 @@ class FlaxCLIPAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/electra/modeling_flax_electra.py
@@ -326,7 +326,7 @@ class FlaxElectraSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/gpt2/modeling_flax_gpt2.py
@@ -255,7 +255,7 @@ class FlaxGPT2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
@@ -223,7 +223,7 @@ class FlaxGPTNeoSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
 
         # usual dot product attention

src/transformers/models/gptj/modeling_flax_gptj.py
@@ -270,7 +270,7 @@ class FlaxGPTJAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
 
         # usual dot product attention

src/transformers/models/marian/modeling_flax_marian.py
@@ -381,7 +381,7 @@ class FlaxMarianAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/mbart/modeling_flax_mbart.py
@@ -383,7 +383,7 @@ class FlaxMBartAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/opt/modeling_flax_opt.py
@@ -245,7 +245,7 @@ class FlaxOPTAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -375,7 +375,7 @@ class FlaxPegasusAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/roberta/modeling_flax_roberta.py
@@ -319,7 +319,7 @@ class FlaxRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
@@ -321,7 +321,7 @@ class FlaxRobertaPreLayerNormSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/roformer/modeling_flax_roformer.py
@@ -240,7 +240,7 @@ class FlaxRoFormerSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -497,7 +497,7 @@ class FlaxWav2Vec2Attention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None

src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
@@ -329,7 +329,7 @@ class FlaxXLMRobertaSelfAttention(nn.Module):
             attention_bias = lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
             )
         else:
             attention_bias = None
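
For downstream context, a hedged sketch of how these attention_bias tensors are consumed: the Flax models touched here pass them as the bias argument of flax.linen.attention.dot_product_attention_weights. The shapes below are illustrative, not taken from the diff.

import jax.numpy as jnp
from flax.linen.attention import dot_product_attention_weights

# Illustrative shapes: (batch, length, num_heads, head_dim), as Flax expects.
batch, length, heads, head_dim = 1, 4, 2, 8
query = jnp.ones((batch, length, heads, head_dim), dtype=jnp.float16)
key = jnp.ones((batch, length, heads, head_dim), dtype=jnp.float16)

# Bias must broadcast to (batch, num_heads, q_length, kv_length);
# fully mask the last two key positions.
mask = jnp.array([[1, 1, 0, 0]])
bias = jnp.where(mask > 0, 0.0, jnp.finfo(jnp.float16).min).astype(jnp.float16)
bias = bias[:, None, None, :]  # (batch, 1, 1, kv_length)

attn_weights = dot_product_attention_weights(query, key, bias=bias, dtype=jnp.float16)
print(attn_weights)  # masked key positions get ~0 weight, with no NaN risk in float16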