"docs/source/vscode:/vscode.git/clone" did not exist on "42791a5753ea13ead6277b6728537e680c7b05d7"
Unverified commit 9eda6b52, authored by Sylvain Gugger, committed by GitHub

Add all XxxPreTrainedModel to the main init (#12314)

* Add all XxxPreTrainedModel to the main init

* Add to template

* Add to template bis

* Add FlaxT5
parent 53c60bab
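For context, the stubs added below follow the library's dummy-object pattern: when an optional backend (Flax, PyTorch, TensorFlow, or timm/vision) is not installed, the main init still exposes each model class, but as a placeholder whose __init__ and from_pretrained call requires_backends, which raises an error naming the backend to install. A minimal sketch of that behavior, assuming the flax backend is not installed (the import path shown for requires_backends and the example call are illustrative, not part of this diff):

# Sketch of the dummy-object pattern this commit extends to the XxxPreTrainedModel classes.
# Assumes transformers' requires_backends helper; the exact import path may differ by version.
from transformers.file_utils import requires_backends

class FlaxBartPreTrainedModel:
    def __init__(self, *args, **kwargs):
        # Fails with a message asking the user to install flax, instead of the
        # name simply being absent from the transformers namespace.
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

# FlaxBartPreTrainedModel.from_pretrained("facebook/bart-base")
# -> raises with an ImportError-style message: flax is required but not installed.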
@@ -244,6 +244,15 @@ class FlaxBartModel:
        requires_backends(cls, ["flax"])


class FlaxBartPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxBertForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -412,6 +421,15 @@ class FlaxCLIPTextModel:
        requires_backends(cls, ["flax"])


class FlaxCLIPTextPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxCLIPVisionModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -421,6 +439,15 @@ class FlaxCLIPVisionModel:
        requires_backends(cls, ["flax"])


class FlaxCLIPVisionPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxElectraForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -507,6 +534,15 @@ class FlaxGPT2Model:
        requires_backends(cls, ["flax"])


class FlaxGPT2PreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxRobertaForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -588,6 +624,15 @@ class FlaxT5Model:
        requires_backends(cls, ["flax"])


class FlaxT5PreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxViTForImageClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

@@ -600,3 +645,12 @@ class FlaxViTModel:
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxViTPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
@@ -692,6 +692,15 @@ class BertGenerationEncoder:
        requires_backends(self, ["torch"])


class BertGenerationPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


def load_tf_weights_in_bert_generation(*args, **kwargs):
    requires_backends(load_tf_weights_in_bert_generation, ["torch"])

@@ -833,6 +842,15 @@ class BigBirdPegasusModel:
        requires_backends(cls, ["torch"])


class BigBirdPegasusPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -863,6 +881,15 @@ class BlenderbotModel:
        requires_backends(cls, ["torch"])


class BlenderbotPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -893,6 +920,15 @@ class BlenderbotSmallModel:
        requires_backends(cls, ["torch"])


class BlenderbotSmallPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -1610,6 +1646,15 @@ class FunnelModel:
        requires_backends(cls, ["torch"])


class FunnelPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


def load_tf_weights_in_funnel(*args, **kwargs):
    requires_backends(load_tf_weights_in_funnel, ["torch"])
@@ -1840,6 +1885,15 @@ class LayoutLMModel:
        requires_backends(cls, ["torch"])


class LayoutLMPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


LED_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -1879,6 +1933,15 @@ class LEDModel:
        requires_backends(cls, ["torch"])


class LEDPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -1936,6 +1999,15 @@ class LongformerModel:
        requires_backends(cls, ["torch"])


class LongformerPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class LongformerSelfAttention:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

@@ -2045,6 +2117,15 @@ class M2M100Model:
        requires_backends(cls, ["torch"])


class M2M100PreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class MarianForCausalLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

@@ -2117,6 +2198,15 @@ class MBartModel:
        requires_backends(cls, ["torch"])


class MBartPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -2193,6 +2283,15 @@ class MegatronBertModel:
        requires_backends(cls, ["torch"])


class MegatronBertPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class MMBTForClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
@@ -2474,6 +2573,15 @@ class PegasusModel:
        requires_backends(cls, ["torch"])


class PegasusPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -2532,6 +2640,15 @@ class RagModel:
        requires_backends(cls, ["torch"])


class RagPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class RagSequenceForGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

@@ -2600,6 +2717,15 @@ class ReformerModelWithLMHead:
        requires_backends(cls, ["torch"])


class ReformerPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -2687,6 +2813,15 @@ class RobertaModel:
        requires_backends(cls, ["torch"])


class RobertaPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -2792,6 +2927,15 @@ class Speech2TextModel:
        requires_backends(cls, ["torch"])


class Speech2TextPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -2945,6 +3089,15 @@ class TapasModel:
        requires_backends(cls, ["torch"])


class TapasPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
@@ -431,6 +431,15 @@ class TFBlenderbotModel:
        requires_backends(cls, ["tf"])


class TFBlenderbotPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFBlenderbotSmallForConditionalGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -449,6 +458,15 @@ class TFBlenderbotSmallModel:
        requires_backends(cls, ["tf"])


class TFBlenderbotSmallPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -845,6 +863,15 @@ class TFFlaubertModel:
        requires_backends(cls, ["tf"])


class TFFlaubertPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFFlaubertWithLMHeadModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -925,6 +952,15 @@ class TFFunnelModel:
        requires_backends(cls, ["tf"])


class TFFunnelPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -1062,6 +1098,15 @@ class TFLongformerModel:
        requires_backends(cls, ["tf"])


class TFLongformerPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFLongformerSelfAttention:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])
@@ -1121,6 +1166,15 @@ class TFMarianMTModel:
        requires_backends(cls, ["tf"])


class TFMarianPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFMBartForConditionalGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -1139,6 +1193,15 @@ class TFMBartModel:
        requires_backends(cls, ["tf"])


class TFMBartPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -1389,6 +1452,15 @@ class TFPegasusModel:
        requires_backends(cls, ["tf"])


class TFPegasusPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFRagModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

@@ -1398,6 +1470,15 @@ class TFRagModel:
        requires_backends(cls, ["tf"])


class TFRagPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])


class TFRagSequenceForGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

...
@@ -30,3 +30,12 @@ class DetrModel:
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["timm", "vision"])


class DetrPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["timm", "vision"])
@@ -52,6 +52,7 @@
            "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
            "{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
            "{{cookiecutter.camelcase_modelname}}Model",
            "{{cookiecutter.camelcase_modelname}}PreTrainedModel",
        ]
    )
{% endif -%}
@@ -120,6 +121,7 @@
        {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
        {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
        {{cookiecutter.camelcase_modelname}}Model,
        {{cookiecutter.camelcase_modelname}}PreTrainedModel,
    )
{% endif -%}
# End.
...
@@ -31,9 +31,16 @@ PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source"

# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
    "DPRSpanPredictor",
    "T5Stack",
    "TFDPRSpanPredictor",
]

# Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule.
-IGNORE_NON_TESTED = [
+IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
    # models to ignore for not tested
    "BigBirdPegasusEncoder",  # Building part of bigger (tested) model.
    "BigBirdPegasusDecoder",  # Building part of bigger (tested) model.
@@ -63,12 +70,9 @@ IGNORE_NON_TESTED = [
    "PegasusEncoder",  # Building part of bigger (tested) model.
    "PegasusDecoderWrapper",  # Building part of bigger (tested) model.
    "DPREncoder",  # Building part of bigger (tested) model.
-    "DPRSpanPredictor",  # Building part of bigger (tested) model.
    "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
    "ReformerForMaskedLM",  # Needs to be setup as decoder.
-    "T5Stack",  # Building part of bigger (tested) model.
    "TFDPREncoder",  # Building part of bigger (tested) model.
-    "TFDPRSpanPredictor",  # Building part of bigger (tested) model.
    "TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
    "TFRobertaForMultipleChoice",  # TODO: fix
    "SeparableConv1D",  # Building part of bigger (tested) model.
@@ -92,7 +96,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule.
-IGNORE_NON_AUTO_CONFIGURED = [
+IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    # models to ignore for model xxx mapping
    "CLIPTextModel",
    "CLIPVisionModel",
@@ -100,7 +104,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
    "FlaxCLIPVisionModel",
    "DetrForSegmentation",
    "DPRReader",
-    "DPRSpanPredictor",
    "FlaubertForQuestionAnswering",
    "GPT2DoubleHeadsModel",
    "LukeForEntityClassification",
@@ -110,9 +113,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
    "RagModel",
    "RagSequenceForGeneration",
    "RagTokenForGeneration",
-    "T5Stack",
    "TFDPRReader",
-    "TFDPRSpanPredictor",
    "TFGPT2DoubleHeadsModel",
    "TFOpenAIGPTDoubleHeadsModel",
    "TFRagModel",
@@ -173,12 +174,12 @@ def get_model_modules():
    return modules


-def get_models(module):
+def get_models(module, include_pretrained=False):
    """Get the objects in module that are models."""
    models = []
    model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
    for attr_name in dir(module):
-        if "Pretrained" in attr_name or "PreTrained" in attr_name:
+        if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
            continue
        attr = getattr(module, attr_name)
        if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
@@ -186,6 +187,36 @@ def get_models(module):
    return models


def is_a_private_model(model):
    """Returns True if the model should not be in the main init."""
    if model in PRIVATE_MODELS:
        return True

    # Wrapper, Encoder and Decoder are all privates
    if model.endswith("Wrapper"):
        return True
    if model.endswith("Encoder"):
        return True
    if model.endswith("Decoder"):
        return True
    return False


def check_models_are_in_init():
    """Checks all models defined in the library are in the main init."""
    models_not_in_init = []
    dir_transformers = dir(transformers)
    for module in get_model_modules():
        models_not_in_init += [
            model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
        ]

    # Remove private models
    models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
    if len(models_not_in_init) > 0:
        raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")


# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
@@ -229,6 +260,7 @@ def find_tested_models(test_file):
def check_models_are_tested(module, test_file):
    """Check models defined in module are tested in test_file."""
    # XxxPreTrainedModel are not tested
    defined_models = get_models(module)
    tested_models = find_tested_models(test_file)
    if tested_models is None:
@@ -515,6 +547,8 @@ def check_all_objects_are_documented():
def check_repo_quality():
    """Check all models are properly tested and documented."""
    print("Checking all models are public.")
    check_models_are_in_init()
    print("Checking all models are properly tested.")
    check_all_decorator_order()
    check_all_models_are_tested()
...