Unverified Commit 56b03c96 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Fix TF CTC tests (#21606)

parent cbecf121
...@@ -321,6 +321,20 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -321,6 +321,20 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960") model = TFHubertModel.from_pretrained("facebook/hubert-base-ls960")
self.assertIsNotNone(model) self.assertIsNotNone(model)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_dataset_conversion(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2
super().test_keras_fit()
self.model_tester.batch_size = default_batch_size
@require_tf @require_tf
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
...@@ -431,20 +445,18 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -431,20 +445,18 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
def test_model_common_attributes(self): def test_model_common_attributes(self):
pass pass
@slow
def test_model_from_pretrained(self):
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
self.assertIsNotNone(model)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC # We override here as passing a full batch of 13 samples results in OOM errors for CTC
# TODO: fix me
@unittest.skip(reason="Crashing on CI, temporarily skipped")
def test_dataset_conversion(self): def test_dataset_conversion(self):
default_batch_size = self.model_tester.batch_size default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2 self.model_tester.batch_size = 2
super().test_dataset_conversion() super().test_dataset_conversion()
self.model_tester.batch_size = default_batch_size self.model_tester.batch_size = default_batch_size
@slow
def test_model_from_pretrained(self):
model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
self.assertIsNotNone(model)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC # We override here as passing a full batch of 13 samples results in OOM errors for CTC
def test_keras_fit(self): def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size default_batch_size = self.model_tester.batch_size
......
...@@ -396,7 +396,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -396,7 +396,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
def test_keras_fit(self): def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2 self.model_tester.batch_size = 2
super().test_dataset_conversion() super().test_keras_fit()
self.model_tester.batch_size = default_batch_size self.model_tester.batch_size = default_batch_size
...@@ -527,7 +527,7 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -527,7 +527,7 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
def test_keras_fit(self): def test_keras_fit(self):
default_batch_size = self.model_tester.batch_size default_batch_size = self.model_tester.batch_size
self.model_tester.batch_size = 2 self.model_tester.batch_size = 2
super().test_dataset_conversion() super().test_keras_fit()
self.model_tester.batch_size = default_batch_size self.model_tester.batch_size = default_batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment