Unverified Commit 6ccea048 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Fix T5 model parallel tes (#9107)

k
parent 59da3f27
...@@ -484,9 +484,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -484,9 +484,7 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else () all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
all_parallelizable_model_classes = ( all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
(T5Model, T5ForConditionalGeneration, T5EncoderModel) if is_torch_available() else ()
)
test_pruning = False test_pruning = False
test_torchscript = True test_torchscript = True
test_resize_embeddings = True test_resize_embeddings = True
...@@ -689,6 +687,8 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -689,6 +687,8 @@ class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False test_pruning = False
test_torchscript = True test_torchscript = True
test_resize_embeddings = False test_resize_embeddings = False
test_model_parallel = True
all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else ()
def setUp(self): def setUp(self):
self.model_tester = T5EncoderOnlyModelTester(self) self.model_tester = T5EncoderOnlyModelTester(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment