"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ebdb185befaa821304d461ed6aa20a17e4dc3aa2"
Unverified Commit 33f36c86 authored by Sylvain Gugger, committed by GitHub

Add a main_input_name attribute to all models (#14803)



* Add a main_input_name attribute to all models

* Fix tests

* Wtf Vs Code?

* Update src/transformers/models/imagegpt/modeling_imagegpt.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Style

* Fix copies
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 0940e9b2
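
In short: every pretrained model class (PyTorch, TensorFlow and Flax) now carries a `main_input_name` class attribute naming its principal input. A minimal sketch of how downstream code can use it (the model choice here is just an illustrative NLP example, not part of the commit):

```python
import torch
from transformers import BertConfig, BertModel

# Randomly initialized model; enough to demonstrate the new attribute.
model = BertModel(BertConfig())
print(model.main_input_name)  # -> "input_ids" (the default for NLP models)

# Generic code can now build inputs without hard-coding a key per modality.
dummy = torch.ones(1, 8, dtype=torch.long)
outputs = model(**{model.main_input_name: dummy})
```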
@@ -76,9 +76,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"

     def __init__(
         self,
......
@@ -653,9 +653,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"
+
     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_missing = None
......
@@ -17,7 +17,6 @@
 import inspect
 import os
 import re
-import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
@@ -376,11 +375,10 @@ class ModuleUtilsMixin:
         Returns:
             :obj:`int`: The total number of tokens.
         """
-        token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
-        if token_inputs:
-            return sum([token_input.numel() for token_input in token_inputs])
+        if self.main_input_name in input_dict:
+            return input_dict[self.main_input_name].numel()
         else:
-            warnings.warn(
+            logger.warn(
                 "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
             )
             return 0
@@ -438,9 +436,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
         - **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization.
+        - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for
+          NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models).
     """
     config_class = None
     base_model_prefix = ""
+    main_input_name = "input_ids"
+
     # a list of re pattern of tensor names to ignore from the model when loading the model weights
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_missing = None
......
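
The `ModuleUtilsMixin` hunk above is the behavioral core of the commit: token counting for FLOP estimation now reads only the declared main input, instead of summing every tensor whose key contains the substring "input" (which over-counted for models fed several input-named tensors, e.g. `decoder_input_ids`). A standalone sketch of the new logic, not the library method itself:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def estimate_tokens(input_dict, main_input_name="input_ids"):
    # Sketch of the new behavior: count elements of the main input only.
    if main_input_name in input_dict:
        return input_dict[main_input_name].numel()
    logger.warning(
        "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
    )
    return 0


batch = {
    "input_ids": torch.zeros(2, 16, dtype=torch.long),
    "decoder_input_ids": torch.zeros(2, 16, dtype=torch.long),
}
# 2 sequences of length 16 -> 32 tokens; the old substring match would
# have returned 64 by also counting decoder_input_ids.
print(estimate_tokens(batch))  # -> 32
```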
@@ -523,6 +523,7 @@ class BeitPreTrainedModel(PreTrainedModel):
     config_class = BeitConfig
     base_model_prefix = "beit"
+    main_input_name = "pixel_values"
     supports_gradient_checkpointing = True

     def _init_weights(self, module):
......
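
The vision models in this commit all declare `pixel_values`. A minimal sketch with a randomly initialized BEiT (shapes follow the default `BeitConfig`, 224x224 RGB):

```python
import torch
from transformers import BeitConfig, BeitModel

model = BeitModel(BeitConfig())
assert model.main_input_name == "pixel_values"

# (batch, channels, height, width) matching the default image size.
pixels = torch.randn(1, 3, 224, 224)
outputs = model(**{model.main_input_name: pixels})
```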
@@ -590,6 +590,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
     config_class = BeitConfig
     base_model_prefix = "beit"
+    main_input_name = "pixel_values"
     module_class: nn.Module = None

     def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
......
@@ -789,6 +789,7 @@ class CLIPVisionTransformer(nn.Module):
 class CLIPVisionModel(CLIPPreTrainedModel):
     config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"

     def __init__(self, config: CLIPVisionConfig):
         super().__init__(config)
......
@@ -653,6 +653,7 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
 class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
     config_class = CLIPVisionConfig
+    main_input_name = "pixel_values"
     module_class: nn.Module = None

     def __init__(
......
@@ -385,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     config_class = DeiTConfig
     base_model_prefix = "deit"
+    main_input_name = "pixel_values"
     supports_gradient_checkpointing = True

     def _init_weights(self, module):
......
@@ -784,6 +784,7 @@ class DetrClassificationHead(nn.Module):
 class DetrPreTrainedModel(PreTrainedModel):
     config_class = DetrConfig
     base_model_prefix = "model"
+    main_input_name = "pixel_values"

     def _init_weights(self, module):
         std = self.config.init_std
......
@@ -776,6 +776,7 @@ class HubertPreTrainedModel(PreTrainedModel):
     config_class = HubertConfig
     base_model_prefix = "hubert"
+    main_input_name = "input_values"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
......
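
The speech models here (Hubert, SEW, UniSpeech, ...) declare `input_values`, i.e. the raw waveform; Speech2Text further below is the exception, taking precomputed `input_features` instead. A sketch, again with a randomly initialized model:

```python
import torch
from transformers import HubertConfig, HubertModel

model = HubertModel(HubertConfig())
assert model.main_input_name == "input_values"

waveform = torch.randn(1, 16000)  # batch of 1, one second of 16 kHz audio
outputs = model(**{model.main_input_name: waveform})
print(outputs.last_hidden_state.shape)
```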
@@ -1265,6 +1265,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel):
     config_class = HubertConfig
     base_model_prefix = "hubert"
+    main_input_name = "input_values"

     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
......
@@ -496,6 +496,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
     config_class = ImageGPTConfig
     load_tf_weights = load_tf_weights_in_imagegpt
     base_model_prefix = "transformer"
+    main_input_name = "input_ids"
     supports_gradient_checkpointing = True

     def __init__(self, *inputs, **kwargs):
......
@@ -619,6 +619,7 @@ class PerceiverPreTrainedModel(PreTrainedModel):
     config_class = PerceiverConfig
     base_model_prefix = "perceiver"
+    main_input_name = "inputs"

     def _init_weights(self, module):
         """Initialize the weights"""
......
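
Perceiver is the one model in this commit whose main input is not modality-specific: it consumes a generic `inputs` tensor. A quick check, instantiating from a default config only to inspect the attribute:

```python
from transformers import PerceiverConfig, PerceiverModel

model = PerceiverModel(PerceiverConfig())
assert model.main_input_name == "inputs"
```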
@@ -406,6 +406,7 @@ class SegformerPreTrainedModel(PreTrainedModel):
     config_class = SegformerConfig
     base_model_prefix = "segformer"
+    main_input_name = "pixel_values"

     def _init_weights(self, module):
         """Initialize the weights"""
......
@@ -675,6 +675,7 @@ class SEWPreTrainedModel(PreTrainedModel):
     config_class = SEWConfig
     base_model_prefix = "sew"
+    main_input_name = "input_values"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
......
@@ -1201,6 +1201,7 @@ class SEWDPreTrainedModel(PreTrainedModel):
     config_class = SEWDConfig
     base_model_prefix = "sew-d"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True
......
@@ -180,6 +180,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
     """
     config_class = SpeechEncoderDecoderConfig
     base_model_prefix = "speech_encoder_decoder"
+    main_input_name = "input_values"

     def __init__(
         self,
......
@@ -539,6 +539,7 @@ class Speech2TextDecoderLayer(nn.Module):
 class Speech2TextPreTrainedModel(PreTrainedModel):
     config_class = Speech2TextConfig
     base_model_prefix = "model"
+    main_input_name = "input_features"
     supports_gradient_checkpointing = True

     def _init_weights(self, module):
......
@@ -912,6 +912,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
     config_class = UniSpeechConfig
     base_model_prefix = "unispeech"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True
......
@@ -947,6 +947,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
     config_class = UniSpeechSatConfig
     base_model_prefix = "unispeech_sat"
+    main_input_name = "input_values"
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     supports_gradient_checkpointing = True
......
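
For reference, a summary of the values this commit assigns (the dict below is a reading aid, not an object that exists in the library):

```python
MAIN_INPUT_NAMES = {
    "input_ids": ["base classes (default)", "ImageGPT (pixels quantized to token ids)"],
    "pixel_values": ["BEiT", "CLIPVisionModel", "DeiT", "DETR", "SegFormer"],
    "input_values": ["Hubert", "SEW", "SEW-D", "SpeechEncoderDecoder",
                     "UniSpeech", "UniSpeechSat"],
    "input_features": ["Speech2Text (filter-bank features rather than raw audio)"],
    "inputs": ["Perceiver (modality-agnostic)"],
}
```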