Unverified Commit 25245ec2 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Rename test_model_common_attributes -> test_model_get_set_embeddings (#31321)

* Rename to test_model_common_attributes
The method name is misleading - it is testing being able to get and set embeddings, not common attributes to all models

* Explicitly skip
parent c1be42f6
......@@ -567,7 +567,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
# Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
pass
@is_pt_flax_cross_test
......@@ -921,7 +921,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
# Wav2Vec2 has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
pass
def test_retain_grad_hidden_states_attentions(self):
......
......@@ -557,7 +557,7 @@ class Wav2Vec2BertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
# Ignore copy
@unittest.skip(reason="Wav2Vec2Bert has no inputs_embeds")
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
pass
# Ignore copy
......
......@@ -528,7 +528,7 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
# Wav2Vec2Conformer has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
pass
@is_pt_flax_cross_test
......
......@@ -387,7 +387,7 @@ class WavLMModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# WavLM has no inputs_embeds
# and thus the `get_input_embeddings` fn
# is not implemented
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
pass
# WavLM uses PyTorch's multi-head-attention class
......
......@@ -3153,7 +3153,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
self.assertTrue((outputs_embeds == outputs).all())
# Needs to override as the encoder input embedding is a Conv1d
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
......
......@@ -161,7 +161,7 @@ class XCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
......@@ -561,7 +561,7 @@ class XCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
pass
@unittest.skip(reason="XCLIPModel does not have input/output embeddings")
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="XCLIPModel does not support feedforward chunking")
......
......@@ -210,7 +210,7 @@ class YolosModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# YOLOS does not use inputs_embeds
pass
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
......
......@@ -1971,13 +1971,17 @@ class ModelTesterMixin:
# Check that the model can still do a forward pass successfully (every parameter should be resized)
model(**self._prepare_for_class(inputs_dict, model_class))
def test_model_common_attributes(self):
def test_model_get_set_embeddings(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Embedding, AdaptiveEmbedding))
model.set_input_embeddings(nn.Embedding(10, 10))
new_input_embedding_layer = nn.Embedding(10, 10)
model.set_input_embeddings(new_input_embedding_layer)
self.assertEqual(model.get_input_embeddings(), new_input_embedding_layer)
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment