Unverified Commit 25245ec2 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Rename test_model_common_attributes -> test_model_get_set_embeddings (#31321)

* Rename to test_model_common_attributes
The method name is misleading - it is testing being able to get and set embeddings, not common attributes to all models

* Explicitly skip
parent c1be42f6
...@@ -567,7 +567,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase ...@@ -567,7 +567,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# Wav2Vec2 has no inputs_embeds # Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn # and thus the `get_input_embeddings` fn
# is not implemented # is not implemented
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
pass pass
@is_pt_flax_cross_test @is_pt_flax_cross_test
...@@ -921,7 +921,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -921,7 +921,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
# Wav2Vec2 has no inputs_embeds # Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn # and thus the `get_input_embeddings` fn
# is not implemented # is not implemented
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
pass pass
def test_retain_grad_hidden_states_attentions(self): def test_retain_grad_hidden_states_attentions(self):
......
...@@ -557,7 +557,7 @@ class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test ...@@ -557,7 +557,7 @@ class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
# Ignore copy # Ignore copy
@unittest.skip(reason="Wav2Vec2Bert has no inputs_embeds") @unittest.skip(reason="Wav2Vec2Bert has no inputs_embeds")
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
pass pass
# Ignore copy # Ignore copy
......
...@@ -528,7 +528,7 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest ...@@ -528,7 +528,7 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
# Wav2Vec2Conformer has no inputs_embeds # Wav2Vec2Conformer has no inputs_embeds
# and thus the `get_input_embeddings` fn # and thus the `get_input_embeddings` fn
# is not implemented # is not implemented
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
pass pass
@is_pt_flax_cross_test @is_pt_flax_cross_test
......
...@@ -387,7 +387,7 @@ class WavLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -387,7 +387,7 @@ class WavLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# WavLM has no inputs_embeds # WavLM has no inputs_embeds
# and thus the `get_input_embeddings` fn # and thus the `get_input_embeddings` fn
# is not implemented # is not implemented
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
pass pass
# WavLM uses PyTorch's multi-head-attention class # WavLM uses PyTorch's multi-head-attention class
......
...@@ -3153,7 +3153,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. ...@@ -3153,7 +3153,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
self.assertTrue((outputs_embeds == outputs).all()) self.assertTrue((outputs_embeds == outputs).all())
# Needs to override as the encoder input embedding is a Conv1d # Needs to override as the encoder input embedding is a Conv1d
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
......
...@@ -161,7 +161,7 @@ class XCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -161,7 +161,7 @@ class XCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
...@@ -561,7 +561,7 @@ class XCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -561,7 +561,7 @@ class XCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
pass pass
@unittest.skip(reason="XCLIPModel does not have input/output embeddings") @unittest.skip(reason="XCLIPModel does not have input/output embeddings")
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
pass pass
@unittest.skip(reason="XCLIPModel does not support feedforward chunking") @unittest.skip(reason="XCLIPModel does not support feedforward chunking")
......
...@@ -210,7 +210,7 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): ...@@ -210,7 +210,7 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# YOLOS does not use inputs_embeds # YOLOS does not use inputs_embeds
pass pass
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
......
...@@ -1971,13 +1971,17 @@ class ModelTesterMixin: ...@@ -1971,13 +1971,17 @@ class ModelTesterMixin:
# Check that the model can still do a forward pass successfully (every parameter should be resized) # Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class)) model(**self._prepare_for_class(inputs_dict, model_class))
def test_model_common_attributes(self): def test_model_get_set_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding, AdaptiveEmbedding)) self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding, AdaptiveEmbedding))
model.set_input_embeddings(nn.Embedding(10, 10))
new_input_embedding_layer = nn.Embedding(10, 10)
model.set_input_embeddings(new_input_embedding_layer)
self.assertEqual(model.get_input_embeddings(), new_input_embedding_layer)
x = model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear)) self.assertTrue(x is None or isinstance(x, nn.Linear))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment