Unverified commit 33f36c86, authored by Sylvain Gugger, committed by GitHub

Add a main_input_name attribute to all models (#14803)



* Add a main_input_name attribute to all models

* Fix tests

* Wtf Vs Code?

* Update src/transformers/models/imagegpt/modeling_imagegpt.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Style

* Fix copies
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0940e9b2
@@ -76,9 +76,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"

     def __init__(
         self,

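The docstring added above spells out the convention; since `main_input_name` is a plain class attribute, generic code can query it without instantiating a model. A minimal sketch (the three classes are illustrative examples, assuming they follow the convention described in the docstring):

```python
from transformers import BertModel, ViTModel, Wav2Vec2Model

# Each model class advertises which keyword argument carries its principal
# input, so generic code can route data without modality-specific branching.
for cls in (BertModel, ViTModel, Wav2Vec2Model):
    print(cls.__name__, "->", cls.main_input_name)

# Expected, per the convention above:
# BertModel -> input_ids
# ViTModel -> pixel_values
# Wav2Vec2Model -> input_values
```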
@@ -653,9 +653,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"
+
     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_missing = None

@@ -17,7 +17,6 @@
 import inspect
 import os
 import re
-import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial

@@ -376,11 +375,10 @@ class ModuleUtilsMixin:
         Returns:
             :obj:`int`: The total number of tokens.
         """
-        token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
-        if token_inputs:
-            return sum([token_input.numel() for token_input in token_inputs])
+        if self.main_input_name in input_dict:
+            return input_dict[self.main_input_name].numel()
         else:
-            warnings.warn(
+            logger.warn(
                 "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
             )
             return 0

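The rewritten branch above no longer sums every tensor whose key contains "input" (which double-counted things like `decoder_input_ids`); it counts only the model's declared main input. A standalone sketch of the new behavior (the helper name is made up for illustration):

```python
import torch

def estimate_tokens(main_input_name: str, input_dict: dict) -> int:
    # New behavior: count elements of the principal input only.
    if main_input_name in input_dict:
        return input_dict[main_input_name].numel()
    return 0  # the real method also logs a warning here

batch = {
    "input_ids": torch.ones(2, 128, dtype=torch.long),
    "decoder_input_ids": torch.ones(2, 64, dtype=torch.long),
}
# The old heuristic would have counted both keys (256 + 128 = 384 tokens);
# the new lookup counts only the main input.
print(estimate_tokens("input_ids", batch))  # 256
```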
@@ -438,9 +436,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
         - **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"
+
     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_missing = None

@@ -523,6 +523,7 @@ class BeitPreTrainedModel(PreTrainedModel):
     config_class = BeitConfig
     base_model_prefix = "beit"
+    main_input_name = "pixel_values"
     supports_gradient_checkpointing = True

     def _init_weights(self, module):

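Each of the model-specific hunks that follow is this same one-line override. One concrete payoff is that utilities can construct inputs generically; a hypothetical helper (not part of this PR) keyed on the attribute:

```python
import torch

# Hypothetical helper, not part of this PR: build a dummy batch keyed by
# whatever input name the model class declares.
def dummy_batch(model_cls, shape):
    return {model_cls.main_input_name: torch.zeros(shape)}

# e.g. dummy_batch(BeitModel, (1, 3, 224, 224)) -> {"pixel_values": tensor(...)}
```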
@@ -590,6 +590,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
     config_class = BeitConfig
     base_model_prefix = "beit"
+    main_input_name = "pixel_values"
     module_class: nn.Module = None

     def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):

@@ -789,6 +789,7 @@ class CLIPVisionTransformer(nn.Module):
 class CLIPVisionModel(CLIPPreTrainedModel):
     config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"

     def __init__(self, config: CLIPVisionConfig):
         super().__init__(config)

@@ -653,6 +653,7 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
 class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
     config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"
     module_class: nn.Module = None

     def __init__(

@@ -385,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     config_class = DeiTConfig
     base_model_prefix = "deit"
+    main_input_name = "pixel_values"
     supports_gradient_checkpointing = True

     def _init_weights(self, module):

@@ -784,6 +784,7 @@ class DetrClassificationHead(nn.Module):
 class DetrPreTrainedModel(PreTrainedModel):
     config_class = DetrConfig
     base_model_prefix = "model"
+    main_input_name = "pixel_values"

     def _init_weights(self, module):
         std = self.config.init_std

@@ -776,6 +776,7 @@ class HubertPreTrainedModel(PreTrainedModel):
     config_class = HubertConfig
     base_model_prefix = "hubert"
+    main_input_name = "input_values"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

@@ -1265,6 +1265,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
     config_class = HubertConfig
     base_model_prefix = "hubert"
+    main_input_name = "input_values"

     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:

@@ -496,6 +496,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
     config_class = ImageGPTConfig
     load_tf_weights = load_tf_weights_in_imagegpt
     base_model_prefix = "transformer"
+    main_input_name = "input_ids"
     supports_gradient_checkpointing = True

     def __init__(self, *inputs, **kwargs):

@@ -619,6 +619,7 @@ class PerceiverPreTrainedModel(PreTrainedModel):
     config_class = PerceiverConfig
     base_model_prefix = "perceiver"
+    main_input_name = "inputs"

     def _init_weights(self, module):
         """Initialize the weights"""

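Perceiver is the one outlier: its forward pass takes a modality-agnostic `inputs` argument, so its main input name is the generic `inputs` rather than a modality-specific one. For instance (assuming `PerceiverModel` is exported from the top-level package, as in this release):

```python
from transformers import PerceiverModel

# Perceiver accepts arbitrary modalities through a single argument,
# so its declared main input is simply "inputs".
print(PerceiverModel.main_input_name)  # inputs
```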
@@ -406,6 +406,7 @@ class SegformerPreTrainedModel(PreTrainedModel):
     config_class = SegformerConfig
     base_model_prefix = "segformer"
+    main_input_name = "pixel_values"

     def _init_weights(self, module):
         """Initialize the weights"""

@@ -675,6 +675,7 @@ class SEWPreTrainedModel(PreTrainedModel):
     config_class = SEWConfig
     base_model_prefix = "sew"
+    main_input_name = "input_values"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

@@ -1201,6 +1201,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
     config_class = SEWDConfig
     base_model_prefix = "sew-d"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True

@@ -180,6 +180,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     """
     config_class = SpeechEncoderDecoderConfig
     base_model_prefix = "speech_encoder_decoder"
+    main_input_name = "input_values"

     def __init__(
         self,

@@ -539,6 +539,7 @@ class Speech2TextDecoderLayer(nn.Module):
 class Speech2TextPreTrainedModel(PreTrainedModel):
     config_class = Speech2TextConfig
     base_model_prefix = "model"
+    main_input_name = "input_features"
     supports_gradient_checkpointing = True

     def _init_weights(self, module):

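Note that speech models are not uniform: Speech2Text consumes log-mel `input_features`, while Hubert, SEW, and UniSpeech take raw `input_values`. That is exactly the kind of difference this attribute lets generic code paper over; a hypothetical routing helper (not part of this PR):

```python
# Hypothetical helper, not part of this PR: key the processed audio by
# whatever name the model expects ("input_features" vs "input_values").
def build_speech_inputs(model, features):
    return {model.main_input_name: features}
```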
@@ -912,6 +912,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
     config_class = UniSpeechConfig
     base_model_prefix = "unispeech"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True

@@ -947,6 +947,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
     config_class = UniSpeechSatConfig
     base_model_prefix = "unispeech_sat"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True