Unverified Commit 7c6cd0ac authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

up (#14046)

parent 82b62fa6
...@@ -480,8 +480,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -480,8 +480,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
pt_outputs = pt_outputs[1:]
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
...@@ -525,8 +523,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -525,8 +523,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple() pt_outputs = pt_model(**pt_inputs).to_tuple()
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
pt_outputs = pt_outputs[1:]
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
...@@ -539,7 +535,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -539,7 +535,6 @@ class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
with torch.no_grad(): with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
pt_outputs_loaded = pt_outputs_loaded[1:]
self.assertEqual( self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment