Unverified Commit a7cb92aa authored by Yih-Dar, committed by GitHub

fix / skip (for now) some tests before switch to torch 2.2 (#28838)



* fix / skip some tests before we can switch to torch 2.2

* style

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 0e75aeef
@@ -19,6 +19,7 @@ import unittest
 from transformers import MegaConfig, is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
+    is_flaky,
     require_torch,
     require_torch_fp16,
     slow,
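
The `is_flaky` decorator imported above is what the hunks below use to mark tests as retryable instead of skipping them outright. As a rough, hypothetical sketch of the idea (not the actual `transformers.testing_utils` implementation, whose defaults and internals may differ), such a decorator re-runs the test a few times and only reports the final failure:

```python
import functools
import time


def flaky_retry(max_attempts=5, wait_before_retry=None, description=None):
    """Hypothetical retry decorator sketching what an `is_flaky`-style marker does."""

    def decorator(test_func):
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return test_func(*args, **kwargs)
                except Exception:
                    if attempt == max_attempts:
                        # Retries exhausted: surface the real failure to the test runner.
                        raise
                    if description is not None:
                        print(f"Retrying flaky test ({description}): attempt {attempt}/{max_attempts}")
                    if wait_before_retry is not None:
                        time.sleep(wait_before_retry)

        return wrapper

    return decorator
```
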
@@ -534,6 +535,18 @@ class MegaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         self.model_tester = MegaModelTester(self)
         self.config_tester = ConfigTester(self, config_class=MegaConfig, hidden_size=37)

+    # TODO: @ydshieh
+    @is_flaky(description="Sometimes gives `AssertionError` on expected outputs")
+    def test_pipeline_fill_mask(self):
+        super().test_pipeline_fill_mask()
+
+    # TODO: @ydshieh
+    @is_flaky(
+        description="Sometimes gives `RuntimeError: probability tensor contains either `inf`, `nan` or element < 0`"
+    )
+    def test_pipeline_text_generation(self):
+        super().test_pipeline_text_generation()
+
     def test_config(self):
         self.config_tester.run_common_tests()
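
Applied to a standalone test case, the sketch above behaves like the markers in this hunk: a test that fails intermittently still passes as long as one attempt succeeds. The class and test names below are hypothetical and purely illustrative:

```python
import random
import unittest


class ExampleFlakyTest(unittest.TestCase):
    # Uses the `flaky_retry` sketch defined above; names here are illustrative only.
    @flaky_retry(max_attempts=5, description="fails randomly to demonstrate retries")
    def test_sometimes_fails(self):
        # Any single attempt fails about half the time; the wrapper retries up to
        # five times, so the test only fails if every attempt fails.
        self.assertGreater(random.random(), 0.5)


if __name__ == "__main__":
    unittest.main()
```
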
@@ -176,6 +176,11 @@ class VitsModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_config(self):
         self.config_tester.run_common_tests()

+    # TODO: @ydshieh
+    @is_flaky(description="torch 2.2.0 gives `Timeout >120.0s`")
+    def test_pipeline_feature_extraction(self):
+        super().test_pipeline_feature_extraction()
+
     @unittest.skip("Need to fix this after #26538")
     def test_model_forward(self):
         set_seed(12345)
@@ -155,9 +155,11 @@ class ModelOutputTester(unittest.TestCase):
         if is_torch_greater_or_equal_than_2_2:
             self.assertEqual(
                 pytree.treespec_dumps(actual_tree_spec),
-                '[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
+                '[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": "[\\"a\\", \\"c\\"]", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
             )

+    # TODO: @ydshieh
+    @unittest.skip("CPU OOM")
     @require_torch
     def test_export_serialization(self):
         if not is_torch_greater_or_equal_than_2_2:
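
The last hunk tracks a serialization change in torch's pytree helpers: under torch 2.2 the dict context inside a dumped treespec is JSON-encoded as a string rather than emitted as a plain list, hence the updated expected value, while the export-serialization test is skipped for now because of the CPU OOM noted in the decorator. A hedged way to inspect the new behaviour locally (this relies on the private `torch.utils._pytree` module, so the interface may change between releases):

```python
# Hedged sketch: dump a treespec with torch's private pytree utilities (torch >= 2.2 assumed).
import torch.utils._pytree as pytree

# Flatten a small dict; the returned treespec records the keys ("context") and the child structure.
leaves, spec = pytree.tree_flatten({"a": 1, "c": 2})

# On torch 2.2 the dumped JSON carries the dict context as an encoded string such as '["a", "c"]';
# earlier versions emitted a plain list, which is the difference the updated assertion accounts for.
print(pytree.treespec_dumps(spec))
```
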