Unverified Commit 33f36c86 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add a main_input_name attribute to all models (#14803)



* Add a main_input_name attribute to all models

* Fix tests

* Wtf Vs Code?

* Update src/transformers/models/imagegpt/modeling_imagegpt.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Style

* Fix copies
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0940e9b2
...@@ -283,6 +283,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): ...@@ -283,6 +283,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel):
""" """
config_class = VisionEncoderDecoderConfig config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder" base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
module_class = FlaxVisionEncoderDecoderModule module_class = FlaxVisionEncoderDecoderModule
def __init__( def __init__(
......
...@@ -160,6 +160,7 @@ class VisionEncoderDecoderModel(PreTrainedModel): ...@@ -160,6 +160,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
""" """
config_class = VisionEncoderDecoderConfig config_class = VisionEncoderDecoderConfig
base_model_prefix = "vision_encoder_decoder" base_model_prefix = "vision_encoder_decoder"
main_input_name = "pixel_values"
def __init__( def __init__(
self, self,
......
...@@ -406,6 +406,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel): ...@@ -406,6 +406,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel):
config_class = ViTConfig config_class = ViTConfig
base_model_prefix = "vit" base_model_prefix = "vit"
main_input_name = "pixel_values"
module_class: nn.Module = None module_class: nn.Module = None
def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs): def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
......
...@@ -555,6 +555,7 @@ class TFViTPreTrainedModel(TFPreTrainedModel): ...@@ -555,6 +555,7 @@ class TFViTPreTrainedModel(TFPreTrainedModel):
config_class = ViTConfig config_class = ViTConfig
base_model_prefix = "vit" base_model_prefix = "vit"
main_input_name = "pixel_values"
@property @property
def dummy_inputs(self) -> Dict[str, tf.Tensor]: def dummy_inputs(self) -> Dict[str, tf.Tensor]:
......
...@@ -412,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel): ...@@ -412,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel):
config_class = ViTConfig config_class = ViTConfig
base_model_prefix = "vit" base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
def _init_weights(self, module): def _init_weights(self, module):
......
...@@ -775,6 +775,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): ...@@ -775,6 +775,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
config_class = Wav2Vec2Config config_class = Wav2Vec2Config
base_model_prefix: str = "wav2vec2" base_model_prefix: str = "wav2vec2"
main_input_name = "input_values"
module_class: nn.Module = None module_class: nn.Module = None
def __init__( def __init__(
......
...@@ -1256,6 +1256,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel): ...@@ -1256,6 +1256,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel):
config_class = Wav2Vec2Config config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2" base_model_prefix = "wav2vec2"
main_input_name = "input_values"
@property @property
def dummy_inputs(self) -> Dict[str, tf.Tensor]: def dummy_inputs(self) -> Dict[str, tf.Tensor]:
......
...@@ -1044,6 +1044,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): ...@@ -1044,6 +1044,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
config_class = Wav2Vec2Config config_class = Wav2Vec2Config
base_model_prefix = "wav2vec2" base_model_prefix = "wav2vec2"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
......
...@@ -996,6 +996,7 @@ class WavLMPreTrainedModel(PreTrainedModel): ...@@ -996,6 +996,7 @@ class WavLMPreTrainedModel(PreTrainedModel):
config_class = WavLMConfig config_class = WavLMConfig
base_model_prefix = "wavlm" base_model_prefix = "wavlm"
main_input_name = "input_values"
_keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_missing = [r"position_ids"]
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
......
...@@ -1315,6 +1315,13 @@ class ModelTesterMixin: ...@@ -1315,6 +1315,13 @@ class ModelTesterMixin:
x = model.get_output_embeddings() x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear)) self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "forward"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def test_correct_missing_keys(self): def test_correct_missing_keys(self):
if not self.test_missing_keys: if not self.test_missing_keys:
return return
......
...@@ -778,6 +778,13 @@ class FlaxModelTesterMixin: ...@@ -778,6 +778,13 @@ class FlaxModelTesterMixin:
for name, type_ in types.items(): for name, type_ in types.items():
self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.") self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.")
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "__call__"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def test_headmasking(self): def test_headmasking(self):
if not self.test_head_masking: if not self.test_head_masking:
return return
......
...@@ -1183,6 +1183,13 @@ class TFModelTesterMixin: ...@@ -1183,6 +1183,13 @@ class TFModelTesterMixin:
else: else:
new_model_without_prefix(input_ids) new_model_without_prefix(input_ids)
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "call"))
# The main input is the name of the argument after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
def _generate_random_bad_tokens(self, num_bad_tokens, model): def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens # special tokens cannot be bad tokens
special_tokens = [] special_tokens = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment