"examples/vscode:/vscode.git/clone" did not exist on "bd54ed2ed7f578e4122f3e6d536fbe3c9bc76de1"
Unverified Commit c04619ec authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Enable more test_torchscript (#16679)



* update _create_and_check_torchscript

* Enable test_torchscript

* clear_class_registry
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 3918d6a9
...@@ -273,7 +273,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes ...@@ -273,7 +273,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Tes
is_encoder_decoder = True is_encoder_decoder = True
test_pruning = False test_pruning = False
test_missing_keys = False test_missing_keys = False
test_torchscript = True
input_name = "input_features" input_name = "input_features"
......
...@@ -229,7 +229,6 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -229,7 +229,6 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
else None else None
) )
test_pruning = False test_pruning = False
test_torchscript = True
test_resize_embeddings = True test_resize_embeddings = True
test_head_masking = False test_head_masking = False
......
...@@ -177,7 +177,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -177,7 +177,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
) )
test_pruning = False test_pruning = False
test_torchscript = False
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
......
...@@ -512,7 +512,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -512,7 +512,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
fx_compatible = True fx_compatible = True
all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else () all_parallelizable_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
test_pruning = False test_pruning = False
test_torchscript = True
test_resize_embeddings = True test_resize_embeddings = True
test_model_parallel = True test_model_parallel = True
is_encoder_decoder = True is_encoder_decoder = True
...@@ -777,7 +776,6 @@ class T5EncoderOnlyModelTester: ...@@ -777,7 +776,6 @@ class T5EncoderOnlyModelTester:
class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): class T5EncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (T5EncoderModel,) if is_torch_available() else () all_model_classes = (T5EncoderModel,) if is_torch_available() else ()
test_pruning = False test_pruning = False
test_torchscript = True
test_resize_embeddings = False test_resize_embeddings = False
test_model_parallel = True test_model_parallel = True
all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else () all_parallelizable_model_classes = (T5EncoderModel,) if is_torch_available() else ()
......
...@@ -422,7 +422,6 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -422,7 +422,6 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
else None else None
) )
test_pruning = False test_pruning = False
test_torchscript = False
test_resize_embeddings = True test_resize_embeddings = True
test_head_masking = False test_head_masking = False
......
...@@ -617,19 +617,21 @@ class ModelTesterMixin: ...@@ -617,19 +617,21 @@ class ModelTesterMixin:
model.eval() model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class) inputs = self._prepare_for_class(inputs_dict, model_class)
main_input_name = model_class.main_input_name
try: try:
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward model.config.use_cache = False # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
input_ids = inputs["input_ids"] main_input = inputs[main_input_name]
attention_mask = inputs["attention_mask"] attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"] decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"] decoder_attention_mask = inputs["decoder_attention_mask"]
traced_model = torch.jit.trace( traced_model = torch.jit.trace(
model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask) model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
) )
else: else:
input_ids = inputs["input_ids"] main_input = inputs[main_input_name]
traced_model = torch.jit.trace(model, input_ids) traced_model = torch.jit.trace(model, main_input)
except RuntimeError: except RuntimeError:
self.fail("Couldn't trace module.") self.fail("Couldn't trace module.")
......
...@@ -238,7 +238,6 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC ...@@ -238,7 +238,6 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
) )
all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else () all_generative_model_classes = (TransfoXLLMHeadModel,) if is_torch_available() else ()
test_pruning = False test_pruning = False
test_torchscript = False
test_resize_embeddings = True test_resize_embeddings = True
test_mismatched_shapes = False test_mismatched_shapes = False
......
...@@ -305,7 +305,6 @@ class UniSpeechRobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -305,7 +305,6 @@ class UniSpeechRobustModelTest(ModelTesterMixin, unittest.TestCase):
) )
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
test_torchscript = False
def setUp(self): def setUp(self):
self.model_tester = UniSpeechModelTester( self.model_tester = UniSpeechModelTester(
......
...@@ -124,7 +124,6 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -124,7 +124,6 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (VanModel, VanForImageClassification) if is_torch_available() else () all_model_classes = (VanModel, VanForImageClassification) if is_torch_available() else ()
test_pruning = False test_pruning = False
test_torchscript = False
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
has_attentions = False has_attentions = False
......
...@@ -158,7 +158,6 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -158,7 +158,6 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
) )
test_pruning = False test_pruning = False
test_torchscript = False
test_resize_embeddings = False test_resize_embeddings = False
test_head_masking = False test_head_masking = False
......
...@@ -413,7 +413,6 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -413,7 +413,6 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
) )
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
test_torchscript = False
def setUp(self): def setUp(self):
self.model_tester = Wav2Vec2ModelTester(self) self.model_tester = Wav2Vec2ModelTester(self)
...@@ -652,7 +651,6 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -652,7 +651,6 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
) )
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
test_torchscript = False
def setUp(self): def setUp(self):
self.model_tester = Wav2Vec2ModelTester( self.model_tester = Wav2Vec2ModelTester(
......
...@@ -316,7 +316,6 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -316,7 +316,6 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase):
) )
test_pruning = False test_pruning = False
test_headmasking = False test_headmasking = False
test_torchscript = False
def setUp(self): def setUp(self):
self.model_tester = WavLMModelTester(self) self.model_tester = WavLMModelTester(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment