Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
013462c5
Unverified
Commit
013462c5
authored
Jun 02, 2022
by
Arthur
Committed by
GitHub
Jun 02, 2022
Browse files
fix OPT-Flax CI tests (#17512)
parent
2f59ad16
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
2 deletions
+3
-2
tests/models/opt/test_modeling_flax_opt.py
tests/models/opt/test_modeling_flax_opt.py
+3
-2
No files found.
tests/models/opt/test_modeling_flax_opt.py
View file @
013462c5
...
@@ -269,13 +269,14 @@ class FlaxOPTEmbeddingsTest(unittest.TestCase):
...
@@ -269,13 +269,14 @@ class FlaxOPTEmbeddingsTest(unittest.TestCase):
[
6.4783
,
-
1.9913
,
-
10.7926
,
-
2.3336
,
1.5092
,
-
0.9974
,
-
6.8213
,
1.3477
,
1.3477
],
[
6.4783
,
-
1.9913
,
-
10.7926
,
-
2.3336
,
1.5092
,
-
0.9974
,
-
6.8213
,
1.3477
,
1.3477
],
]
]
)
)
self
.
assertTrue
(
jnp
.
allclose
(
logits
,
logits_meta
,
atol
=
1
e-
4
))
self
.
assertTrue
(
jnp
.
allclose
(
logits
,
logits_meta
,
atol
=
4
e-
2
))
model
=
jax
.
jit
(
model
)
model
=
jax
.
jit
(
model
)
logits
=
model
(
inputs
.
input_ids
,
attention_mask
=
inputs
.
attention_mask
)[
0
].
mean
(
axis
=-
1
)
logits
=
model
(
inputs
.
input_ids
,
attention_mask
=
inputs
.
attention_mask
)[
0
].
mean
(
axis
=-
1
)
self
.
assertTrue
(
jnp
.
allclose
(
logits
,
logits_meta
,
atol
=
1
e-
4
))
self
.
assertTrue
(
jnp
.
allclose
(
logits
,
logits_meta
,
atol
=
4
e-
2
))
@
require_flax
@
slow
@
slow
class
FlaxOPTGenerationTest
(
unittest
.
TestCase
):
class
FlaxOPTGenerationTest
(
unittest
.
TestCase
):
@
property
@
property
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment