Unverified Commit b25b92ac authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

update jax version and re-enable some tests (#16254)

parent 5709a204
......@@ -112,7 +112,7 @@ _deps = [
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
"jax>=0.2.8",
"jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65",
"jieba",
"nltk",
......
......@@ -22,7 +22,7 @@ deps = {
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8",
"jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65",
"jieba": "jieba",
"nltk": "nltk",
......
......@@ -691,10 +691,6 @@ class FlaxWav2Vec2GPT2ModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
pass
@require_flax
class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
......@@ -811,7 +807,3 @@ class FlaxWav2Vec2BartModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
self.assert_almost_equals(fx_logits, pt_logits_loaded.numpy(), 4e-2)
@unittest.skip("Re-enable this test once this issue is fixed: https://github.com/google/jax/issues/9941")
def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment