Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7c6cd0ac
Unverified
Commit
7c6cd0ac
authored
Oct 18, 2021
by
Patrick von Platen
Committed by
GitHub
Oct 18, 2021
Browse files
up (#14046)
parent
82b62fa6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
5 deletions
+0
-5
tests/test_modeling_flax_clip.py
tests/test_modeling_flax_clip.py
+0
-5
No files found.
tests/test_modeling_flax_clip.py
View file @
7c6cd0ac
...
@@ -480,8 +480,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
...
@@ -480,8 +480,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pt_outputs
=
pt_model
(
**
pt_inputs
).
to_tuple
()
pt_outputs
=
pt_model
(
**
pt_inputs
).
to_tuple
()
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
pt_outputs
=
pt_outputs
[
1
:]
fx_outputs
=
fx_model
(
**
prepared_inputs_dict
).
to_tuple
()
fx_outputs
=
fx_model
(
**
prepared_inputs_dict
).
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
...
@@ -525,8 +523,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
...
@@ -525,8 +523,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pt_outputs
=
pt_model
(
**
pt_inputs
).
to_tuple
()
pt_outputs
=
pt_model
(
**
pt_inputs
).
to_tuple
()
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
pt_outputs
=
pt_outputs
[
1
:]
fx_outputs
=
fx_model
(
**
prepared_inputs_dict
).
to_tuple
()
fx_outputs
=
fx_model
(
**
prepared_inputs_dict
).
to_tuple
()
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs
),
"Output lengths differ between Flax and PyTorch"
)
...
@@ -539,7 +535,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
...
@@ -539,7 +535,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
pt_outputs_loaded
=
pt_model_loaded
(
**
pt_inputs
).
to_tuple
()
pt_outputs_loaded
=
pt_model_loaded
(
**
pt_inputs
).
to_tuple
()
pt_outputs_loaded
=
pt_outputs_loaded
[
1
:]
self
.
assertEqual
(
self
.
assertEqual
(
len
(
fx_outputs
),
len
(
pt_outputs_loaded
),
"Output lengths differ between Flax and PyTorch"
len
(
fx_outputs
),
len
(
pt_outputs_loaded
),
"Output lengths differ between Flax and PyTorch"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment