Unverified Commit 9eda6b52 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add all XxxPreTrainedModel to the main init (#12314)

* Add all XxxPreTrainedModel to the main init

* Add to template

* Add to template bis

* Add FlaxT5
parent 53c60bab
......@@ -427,6 +427,7 @@ if is_timm_available() and is_vision_available():
"DetrForObjectDetection",
"DetrForSegmentation",
"DetrModel",
"DetrPreTrainedModel",
]
)
else:
......@@ -570,6 +571,7 @@ if is_torch_available():
[
"BertGenerationDecoder",
"BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation",
]
)
......@@ -597,6 +599,7 @@ if is_torch_available():
"BigBirdPegasusForQuestionAnswering",
"BigBirdPegasusForSequenceClassification",
"BigBirdPegasusModel",
"BigBirdPegasusPreTrainedModel",
]
)
_import_structure["models.blenderbot"].extend(
......@@ -605,6 +608,7 @@ if is_torch_available():
"BlenderbotForCausalLM",
"BlenderbotForConditionalGeneration",
"BlenderbotModel",
"BlenderbotPreTrainedModel",
]
)
_import_structure["models.blenderbot_small"].extend(
......@@ -613,6 +617,7 @@ if is_torch_available():
"BlenderbotSmallForCausalLM",
"BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel",
"BlenderbotSmallPreTrainedModel",
]
)
_import_structure["models.camembert"].extend(
......@@ -754,6 +759,7 @@ if is_torch_available():
"FunnelForSequenceClassification",
"FunnelForTokenClassification",
"FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel",
]
)
......@@ -805,6 +811,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]
)
_import_structure["models.led"].extend(
......@@ -814,6 +821,7 @@ if is_torch_available():
"LEDForQuestionAnswering",
"LEDForSequenceClassification",
"LEDModel",
"LEDPreTrainedModel",
]
)
_import_structure["models.longformer"].extend(
......@@ -825,6 +833,7 @@ if is_torch_available():
"LongformerForSequenceClassification",
"LongformerForTokenClassification",
"LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention",
]
)
......@@ -854,6 +863,7 @@ if is_torch_available():
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration",
"M2M100Model",
"M2M100PreTrainedModel",
]
)
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
......@@ -864,6 +874,7 @@ if is_torch_available():
"MBartForQuestionAnswering",
"MBartForSequenceClassification",
"MBartModel",
"MBartPreTrainedModel",
]
)
_import_structure["models.megatron_bert"].extend(
......@@ -878,6 +889,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification",
"MegatronBertModel",
"MegatronBertPreTrainedModel",
]
)
_import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])
......@@ -923,7 +935,7 @@ if is_torch_available():
]
)
_import_structure["models.pegasus"].extend(
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"]
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
)
_import_structure["models.prophetnet"].extend(
[
......@@ -936,7 +948,9 @@ if is_torch_available():
"ProphetNetPreTrainedModel",
]
)
_import_structure["models.rag"].extend(["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"])
_import_structure["models.rag"].extend(
["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
)
_import_structure["models.reformer"].extend(
[
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -947,6 +961,7 @@ if is_torch_available():
"ReformerLayer",
"ReformerModel",
"ReformerModelWithLMHead",
"ReformerPreTrainedModel",
]
)
_import_structure["models.retribert"].extend(
......@@ -962,6 +977,7 @@ if is_torch_available():
"RobertaForSequenceClassification",
"RobertaForTokenClassification",
"RobertaModel",
"RobertaPreTrainedModel",
]
)
_import_structure["models.roformer"].extend(
......@@ -984,6 +1000,7 @@ if is_torch_available():
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
"Speech2TextModel",
"Speech2TextPreTrainedModel",
]
)
_import_structure["models.squeezebert"].extend(
......@@ -1016,6 +1033,7 @@ if is_torch_available():
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
]
)
_import_structure["models.transfo_xl"].extend(
......@@ -1197,9 +1215,11 @@ if is_tf_available():
"TFBertPreTrainedModel",
]
)
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"])
_import_structure["models.blenderbot"].extend(
["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"]
)
_import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"]
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
)
_import_structure["models.camembert"].extend(
[
......@@ -1281,6 +1301,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification",
"TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel",
]
)
......@@ -1295,6 +1316,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification",
"TFFunnelModel",
"TFFunnelPreTrainedModel",
]
)
_import_structure["models.gpt2"].extend(
......@@ -1329,6 +1351,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification",
"TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention",
]
)
......@@ -1342,8 +1365,10 @@ if is_tf_available():
"TFLxmertVisualFeatureEncoder",
]
)
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"])
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"])
_import_structure["models.mbart"].extend(
["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"]
)
_import_structure["models.mobilebert"].extend(
[
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -1384,10 +1409,13 @@ if is_tf_available():
"TFOpenAIGPTPreTrainedModel",
]
)
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
_import_structure["models.pegasus"].extend(
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
)
_import_structure["models.rag"].extend(
[
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
......@@ -1538,6 +1566,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
"FlaxBartPreTrainedModel",
]
)
_import_structure["models.bert"].extend(
......@@ -1570,7 +1599,9 @@ if is_flax_available():
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
)
_import_structure["models.electra"].extend(
......@@ -1585,7 +1616,7 @@ if is_flax_available():
"FlaxElectraPreTrainedModel",
]
)
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"])
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForMaskedLM",
......@@ -1597,8 +1628,8 @@ if is_flax_available():
"FlaxRobertaPreTrainedModel",
]
)
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
else:
from .utils import dummy_flax_objects
......@@ -1949,6 +1980,7 @@ if TYPE_CHECKING:
DetrForObjectDetection,
DetrForSegmentation,
DetrModel,
DetrPreTrainedModel,
)
else:
from .utils.dummy_timm_objects import *
......@@ -2074,6 +2106,7 @@ if TYPE_CHECKING:
from .models.bert_generation import (
BertGenerationDecoder,
BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation,
)
from .models.big_bird import (
......@@ -2097,18 +2130,21 @@ if TYPE_CHECKING:
BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel,
BigBirdPegasusPreTrainedModel,
)
from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForCausalLM,
BlenderbotForConditionalGeneration,
BlenderbotModel,
BlenderbotPreTrainedModel,
)
from .models.blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
)
from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -2226,6 +2262,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification,
FunnelForTokenClassification,
FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel,
)
from .models.gpt2 import (
......@@ -2267,6 +2304,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
LayoutLMPreTrainedModel,
)
from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -2274,6 +2312,7 @@ if TYPE_CHECKING:
LEDForQuestionAnswering,
LEDForSequenceClassification,
LEDModel,
LEDPreTrainedModel,
)
from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -2283,6 +2322,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention,
)
from .models.luke import (
......@@ -2302,7 +2342,12 @@ if TYPE_CHECKING:
LxmertVisualFeatureEncoder,
LxmertXLayer,
)
from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model
from .models.m2m_100 import (
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
M2M100ForConditionalGeneration,
M2M100Model,
M2M100PreTrainedModel,
)
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.mbart import (
MBartForCausalLM,
......@@ -2310,6 +2355,7 @@ if TYPE_CHECKING:
MBartForQuestionAnswering,
MBartForSequenceClassification,
MBartModel,
MBartPreTrainedModel,
)
from .models.megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -2322,6 +2368,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
MegatronBertPreTrainedModel,
)
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
from .models.mobilebert import (
......@@ -2359,7 +2406,12 @@ if TYPE_CHECKING:
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
from .models.pegasus import (
PegasusForCausalLM,
PegasusForConditionalGeneration,
PegasusModel,
PegasusPreTrainedModel,
)
from .models.prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder,
......@@ -2369,7 +2421,7 @@ if TYPE_CHECKING:
ProphetNetModel,
ProphetNetPreTrainedModel,
)
from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
from .models.reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention,
......@@ -2379,6 +2431,7 @@ if TYPE_CHECKING:
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
ReformerPreTrainedModel,
)
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
from .models.roberta import (
......@@ -2390,6 +2443,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
RobertaPreTrainedModel,
)
from .models.roformer import (
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -2408,6 +2462,7 @@ if TYPE_CHECKING:
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration,
Speech2TextModel,
Speech2TextPreTrainedModel,
)
from .models.squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -2434,6 +2489,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
)
from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -2600,8 +2656,16 @@ if TYPE_CHECKING:
TFBertModel,
TFBertPreTrainedModel,
)
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .models.blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
from .models.blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM,
......@@ -2669,6 +2733,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel,
)
from .models.funnel import (
......@@ -2681,6 +2746,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
TFFunnelPreTrainedModel,
)
from .models.gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
......@@ -2700,6 +2766,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention,
)
from .models.lxmert import (
......@@ -2710,8 +2777,8 @@ if TYPE_CHECKING:
TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder,
)
from .models.marian import TFMarianModel, TFMarianMTModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
from .models.marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
......@@ -2746,8 +2813,8 @@ if TYPE_CHECKING:
TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel,
)
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM,
......@@ -2878,6 +2945,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
FlaxBartPreTrainedModel,
)
from .models.bert import (
FlaxBertForMaskedLM,
......@@ -2900,7 +2968,14 @@ if TYPE_CHECKING:
FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel,
)
from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
......@@ -2911,7 +2986,7 @@ if TYPE_CHECKING:
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
......@@ -2921,8 +2996,8 @@ if TYPE_CHECKING:
FlaxRobertaModel,
FlaxRobertaPreTrainedModel,
)
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from .models.vit import FlaxViTForImageClassification, FlaxViTModel
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
......
......@@ -55,6 +55,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
"FlaxBartPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -85,6 +86,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
FlaxBartPreTrainedModel,
)
else:
......
......@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_bert_generation"] = [
"BertGenerationDecoder",
"BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation",
]
......@@ -46,6 +47,7 @@ if TYPE_CHECKING:
from .modeling_bert_generation import (
BertGenerationDecoder,
BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation,
)
......
......@@ -37,7 +37,11 @@ if is_torch_available():
if is_tf_available():
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]
_import_structure["modeling_tf_blenderbot"] = [
"TFBlenderbotForConditionalGeneration",
"TFBlenderbotModel",
"TFBlenderbotPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -54,7 +58,11 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .modeling_tf_blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
else:
import importlib
......
......@@ -38,6 +38,7 @@ if is_tf_available():
_import_structure["modeling_tf_blenderbot_small"] = [
"TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel",
"TFBlenderbotSmallPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -54,7 +55,11 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
else:
import importlib
......
......@@ -52,7 +52,9 @@ if is_flax_available():
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
......@@ -77,7 +79,14 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
else:
......
......@@ -46,6 +46,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification",
"TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel",
]
......@@ -74,6 +75,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel,
)
......
......@@ -41,6 +41,7 @@ if is_torch_available():
"FunnelForSequenceClassification",
"FunnelForTokenClassification",
"FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel",
]
......@@ -55,6 +56,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification",
"TFFunnelModel",
"TFFunnelPreTrainedModel",
]
......@@ -76,6 +78,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification,
FunnelForTokenClassification,
FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel,
)
......@@ -90,6 +93,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
TFFunnelPreTrainedModel,
)
else:
......
......@@ -58,7 +58,7 @@ if is_tf_available():
]
if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
......@@ -90,7 +90,7 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
else:
import importlib
......
......@@ -38,6 +38,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]
if is_tf_available():
......@@ -66,6 +67,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
LayoutLMPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_layoutlm import (
......
......@@ -38,6 +38,7 @@ if is_torch_available():
"LongformerForSequenceClassification",
"LongformerForTokenClassification",
"LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention",
]
......@@ -50,6 +51,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification",
"TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention",
]
......@@ -70,6 +72,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention,
)
......@@ -82,6 +85,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention,
)
......
......@@ -43,7 +43,7 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
if TYPE_CHECKING:
......@@ -62,7 +62,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
else:
import importlib
......
......@@ -50,7 +50,11 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"]
_import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -76,7 +80,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
else:
import importlib
......
......@@ -36,6 +36,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification",
"MegatronBertModel",
"MegatronBertPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -53,6 +54,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
MegatronBertPreTrainedModel,
)
else:
......
......@@ -46,7 +46,11 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"]
_import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -68,7 +72,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
else:
import importlib
......
......@@ -28,10 +28,20 @@ _import_structure = {
}
if is_torch_available():
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
_import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]
if is_tf_available():
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]
_import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
if TYPE_CHECKING:
......@@ -40,10 +50,15 @@ if TYPE_CHECKING:
from .tokenization_rag import RagTokenizer
if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
if is_tf_available():
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)
else:
import importlib
......
......@@ -41,6 +41,7 @@ if is_torch_available():
"ReformerLayer",
"ReformerModel",
"ReformerModelWithLMHead",
"ReformerPreTrainedModel",
]
......@@ -63,6 +64,7 @@ if TYPE_CHECKING:
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
ReformerPreTrainedModel,
)
else:
......
......@@ -45,6 +45,7 @@ if is_torch_available():
"RobertaForSequenceClassification",
"RobertaForTokenClassification",
"RobertaModel",
"RobertaPreTrainedModel",
]
if is_tf_available():
......@@ -89,6 +90,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
RobertaPreTrainedModel,
)
if is_tf_available():
......
......@@ -33,6 +33,7 @@ if is_torch_available():
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
]
......@@ -47,6 +48,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
)
else:
......
......@@ -37,7 +37,11 @@ if is_torch_available():
if is_flax_available():
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"]
_import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
"FlaxViTPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
......@@ -54,7 +58,7 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment