Unverified Commit 9eda6b52 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add all XxxPreTrainedModel to the main init (#12314)

* Add all XxxPreTrainedModel to the main init

* Add to template

* Add to template bis

* Add FlaxT5
parent 53c60bab
...@@ -427,6 +427,7 @@ if is_timm_available() and is_vision_available(): ...@@ -427,6 +427,7 @@ if is_timm_available() and is_vision_available():
"DetrForObjectDetection", "DetrForObjectDetection",
"DetrForSegmentation", "DetrForSegmentation",
"DetrModel", "DetrModel",
"DetrPreTrainedModel",
] ]
) )
else: else:
...@@ -570,6 +571,7 @@ if is_torch_available(): ...@@ -570,6 +571,7 @@ if is_torch_available():
[ [
"BertGenerationDecoder", "BertGenerationDecoder",
"BertGenerationEncoder", "BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation", "load_tf_weights_in_bert_generation",
] ]
) )
...@@ -597,6 +599,7 @@ if is_torch_available(): ...@@ -597,6 +599,7 @@ if is_torch_available():
"BigBirdPegasusForQuestionAnswering", "BigBirdPegasusForQuestionAnswering",
"BigBirdPegasusForSequenceClassification", "BigBirdPegasusForSequenceClassification",
"BigBirdPegasusModel", "BigBirdPegasusModel",
"BigBirdPegasusPreTrainedModel",
] ]
) )
_import_structure["models.blenderbot"].extend( _import_structure["models.blenderbot"].extend(
...@@ -605,6 +608,7 @@ if is_torch_available(): ...@@ -605,6 +608,7 @@ if is_torch_available():
"BlenderbotForCausalLM", "BlenderbotForCausalLM",
"BlenderbotForConditionalGeneration", "BlenderbotForConditionalGeneration",
"BlenderbotModel", "BlenderbotModel",
"BlenderbotPreTrainedModel",
] ]
) )
_import_structure["models.blenderbot_small"].extend( _import_structure["models.blenderbot_small"].extend(
...@@ -613,6 +617,7 @@ if is_torch_available(): ...@@ -613,6 +617,7 @@ if is_torch_available():
"BlenderbotSmallForCausalLM", "BlenderbotSmallForCausalLM",
"BlenderbotSmallForConditionalGeneration", "BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel", "BlenderbotSmallModel",
"BlenderbotSmallPreTrainedModel",
] ]
) )
_import_structure["models.camembert"].extend( _import_structure["models.camembert"].extend(
...@@ -754,6 +759,7 @@ if is_torch_available(): ...@@ -754,6 +759,7 @@ if is_torch_available():
"FunnelForSequenceClassification", "FunnelForSequenceClassification",
"FunnelForTokenClassification", "FunnelForTokenClassification",
"FunnelModel", "FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel", "load_tf_weights_in_funnel",
] ]
) )
...@@ -805,6 +811,7 @@ if is_torch_available(): ...@@ -805,6 +811,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification", "LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification", "LayoutLMForTokenClassification",
"LayoutLMModel", "LayoutLMModel",
"LayoutLMPreTrainedModel",
] ]
) )
_import_structure["models.led"].extend( _import_structure["models.led"].extend(
...@@ -814,6 +821,7 @@ if is_torch_available(): ...@@ -814,6 +821,7 @@ if is_torch_available():
"LEDForQuestionAnswering", "LEDForQuestionAnswering",
"LEDForSequenceClassification", "LEDForSequenceClassification",
"LEDModel", "LEDModel",
"LEDPreTrainedModel",
] ]
) )
_import_structure["models.longformer"].extend( _import_structure["models.longformer"].extend(
...@@ -825,6 +833,7 @@ if is_torch_available(): ...@@ -825,6 +833,7 @@ if is_torch_available():
"LongformerForSequenceClassification", "LongformerForSequenceClassification",
"LongformerForTokenClassification", "LongformerForTokenClassification",
"LongformerModel", "LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention", "LongformerSelfAttention",
] ]
) )
...@@ -854,6 +863,7 @@ if is_torch_available(): ...@@ -854,6 +863,7 @@ if is_torch_available():
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST", "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration", "M2M100ForConditionalGeneration",
"M2M100Model", "M2M100Model",
"M2M100PreTrainedModel",
] ]
) )
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"]) _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
...@@ -864,6 +874,7 @@ if is_torch_available(): ...@@ -864,6 +874,7 @@ if is_torch_available():
"MBartForQuestionAnswering", "MBartForQuestionAnswering",
"MBartForSequenceClassification", "MBartForSequenceClassification",
"MBartModel", "MBartModel",
"MBartPreTrainedModel",
] ]
) )
_import_structure["models.megatron_bert"].extend( _import_structure["models.megatron_bert"].extend(
...@@ -878,6 +889,7 @@ if is_torch_available(): ...@@ -878,6 +889,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification", "MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification", "MegatronBertForTokenClassification",
"MegatronBertModel", "MegatronBertModel",
"MegatronBertPreTrainedModel",
] ]
) )
_import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]) _import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])
...@@ -923,7 +935,7 @@ if is_torch_available(): ...@@ -923,7 +935,7 @@ if is_torch_available():
] ]
) )
_import_structure["models.pegasus"].extend( _import_structure["models.pegasus"].extend(
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"] ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
) )
_import_structure["models.prophetnet"].extend( _import_structure["models.prophetnet"].extend(
[ [
...@@ -936,7 +948,9 @@ if is_torch_available(): ...@@ -936,7 +948,9 @@ if is_torch_available():
"ProphetNetPreTrainedModel", "ProphetNetPreTrainedModel",
] ]
) )
_import_structure["models.rag"].extend(["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]) _import_structure["models.rag"].extend(
["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
)
_import_structure["models.reformer"].extend( _import_structure["models.reformer"].extend(
[ [
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -947,6 +961,7 @@ if is_torch_available(): ...@@ -947,6 +961,7 @@ if is_torch_available():
"ReformerLayer", "ReformerLayer",
"ReformerModel", "ReformerModel",
"ReformerModelWithLMHead", "ReformerModelWithLMHead",
"ReformerPreTrainedModel",
] ]
) )
_import_structure["models.retribert"].extend( _import_structure["models.retribert"].extend(
...@@ -962,6 +977,7 @@ if is_torch_available(): ...@@ -962,6 +977,7 @@ if is_torch_available():
"RobertaForSequenceClassification", "RobertaForSequenceClassification",
"RobertaForTokenClassification", "RobertaForTokenClassification",
"RobertaModel", "RobertaModel",
"RobertaPreTrainedModel",
] ]
) )
_import_structure["models.roformer"].extend( _import_structure["models.roformer"].extend(
...@@ -984,6 +1000,7 @@ if is_torch_available(): ...@@ -984,6 +1000,7 @@ if is_torch_available():
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration", "Speech2TextForConditionalGeneration",
"Speech2TextModel", "Speech2TextModel",
"Speech2TextPreTrainedModel",
] ]
) )
_import_structure["models.squeezebert"].extend( _import_structure["models.squeezebert"].extend(
...@@ -1016,6 +1033,7 @@ if is_torch_available(): ...@@ -1016,6 +1033,7 @@ if is_torch_available():
"TapasForQuestionAnswering", "TapasForQuestionAnswering",
"TapasForSequenceClassification", "TapasForSequenceClassification",
"TapasModel", "TapasModel",
"TapasPreTrainedModel",
] ]
) )
_import_structure["models.transfo_xl"].extend( _import_structure["models.transfo_xl"].extend(
...@@ -1197,9 +1215,11 @@ if is_tf_available(): ...@@ -1197,9 +1215,11 @@ if is_tf_available():
"TFBertPreTrainedModel", "TFBertPreTrainedModel",
] ]
) )
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]) _import_structure["models.blenderbot"].extend(
["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"]
)
_import_structure["models.blenderbot_small"].extend( _import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"] ["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
) )
_import_structure["models.camembert"].extend( _import_structure["models.camembert"].extend(
[ [
...@@ -1281,6 +1301,7 @@ if is_tf_available(): ...@@ -1281,6 +1301,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification", "TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification", "TFFlaubertForTokenClassification",
"TFFlaubertModel", "TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel", "TFFlaubertWithLMHeadModel",
] ]
) )
...@@ -1295,6 +1316,7 @@ if is_tf_available(): ...@@ -1295,6 +1316,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification", "TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification", "TFFunnelForTokenClassification",
"TFFunnelModel", "TFFunnelModel",
"TFFunnelPreTrainedModel",
] ]
) )
_import_structure["models.gpt2"].extend( _import_structure["models.gpt2"].extend(
...@@ -1329,6 +1351,7 @@ if is_tf_available(): ...@@ -1329,6 +1351,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification", "TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification", "TFLongformerForTokenClassification",
"TFLongformerModel", "TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention", "TFLongformerSelfAttention",
] ]
) )
...@@ -1342,8 +1365,10 @@ if is_tf_available(): ...@@ -1342,8 +1365,10 @@ if is_tf_available():
"TFLxmertVisualFeatureEncoder", "TFLxmertVisualFeatureEncoder",
] ]
) )
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"]) _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"])
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"]) _import_structure["models.mbart"].extend(
["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"]
)
_import_structure["models.mobilebert"].extend( _import_structure["models.mobilebert"].extend(
[ [
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -1384,10 +1409,13 @@ if is_tf_available(): ...@@ -1384,10 +1409,13 @@ if is_tf_available():
"TFOpenAIGPTPreTrainedModel", "TFOpenAIGPTPreTrainedModel",
] ]
) )
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"]) _import_structure["models.pegasus"].extend(
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
)
_import_structure["models.rag"].extend( _import_structure["models.rag"].extend(
[ [
"TFRagModel", "TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration", "TFRagSequenceForGeneration",
"TFRagTokenForGeneration", "TFRagTokenForGeneration",
] ]
...@@ -1538,6 +1566,7 @@ if is_flax_available(): ...@@ -1538,6 +1566,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering", "FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification", "FlaxBartForSequenceClassification",
"FlaxBartModel", "FlaxBartModel",
"FlaxBartPreTrainedModel",
] ]
) )
_import_structure["models.bert"].extend( _import_structure["models.bert"].extend(
...@@ -1570,7 +1599,9 @@ if is_flax_available(): ...@@ -1570,7 +1599,9 @@ if is_flax_available():
"FlaxCLIPModel", "FlaxCLIPModel",
"FlaxCLIPPreTrainedModel", "FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel", "FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel", "FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
] ]
) )
_import_structure["models.electra"].extend( _import_structure["models.electra"].extend(
...@@ -1585,7 +1616,7 @@ if is_flax_available(): ...@@ -1585,7 +1616,7 @@ if is_flax_available():
"FlaxElectraPreTrainedModel", "FlaxElectraPreTrainedModel",
] ]
) )
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]) _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.roberta"].extend( _import_structure["models.roberta"].extend(
[ [
"FlaxRobertaForMaskedLM", "FlaxRobertaForMaskedLM",
...@@ -1597,8 +1628,8 @@ if is_flax_available(): ...@@ -1597,8 +1628,8 @@ if is_flax_available():
"FlaxRobertaPreTrainedModel", "FlaxRobertaPreTrainedModel",
] ]
) )
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"]) _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
else: else:
from .utils import dummy_flax_objects from .utils import dummy_flax_objects
...@@ -1949,6 +1980,7 @@ if TYPE_CHECKING: ...@@ -1949,6 +1980,7 @@ if TYPE_CHECKING:
DetrForObjectDetection, DetrForObjectDetection,
DetrForSegmentation, DetrForSegmentation,
DetrModel, DetrModel,
DetrPreTrainedModel,
) )
else: else:
from .utils.dummy_timm_objects import * from .utils.dummy_timm_objects import *
...@@ -2074,6 +2106,7 @@ if TYPE_CHECKING: ...@@ -2074,6 +2106,7 @@ if TYPE_CHECKING:
from .models.bert_generation import ( from .models.bert_generation import (
BertGenerationDecoder, BertGenerationDecoder,
BertGenerationEncoder, BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation, load_tf_weights_in_bert_generation,
) )
from .models.big_bird import ( from .models.big_bird import (
...@@ -2097,18 +2130,21 @@ if TYPE_CHECKING: ...@@ -2097,18 +2130,21 @@ if TYPE_CHECKING:
BigBirdPegasusForQuestionAnswering, BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification, BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel, BigBirdPegasusModel,
BigBirdPegasusPreTrainedModel,
) )
from .models.blenderbot import ( from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForCausalLM, BlenderbotForCausalLM,
BlenderbotForConditionalGeneration, BlenderbotForConditionalGeneration,
BlenderbotModel, BlenderbotModel,
BlenderbotPreTrainedModel,
) )
from .models.blenderbot_small import ( from .models.blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST, BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotSmallForCausalLM, BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration, BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel, BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
) )
from .models.camembert import ( from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
...@@ -2226,6 +2262,7 @@ if TYPE_CHECKING: ...@@ -2226,6 +2262,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification, FunnelForSequenceClassification,
FunnelForTokenClassification, FunnelForTokenClassification,
FunnelModel, FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel, load_tf_weights_in_funnel,
) )
from .models.gpt2 import ( from .models.gpt2 import (
...@@ -2267,6 +2304,7 @@ if TYPE_CHECKING: ...@@ -2267,6 +2304,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification, LayoutLMForSequenceClassification,
LayoutLMForTokenClassification, LayoutLMForTokenClassification,
LayoutLMModel, LayoutLMModel,
LayoutLMPreTrainedModel,
) )
from .models.led import ( from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST, LED_PRETRAINED_MODEL_ARCHIVE_LIST,
...@@ -2274,6 +2312,7 @@ if TYPE_CHECKING: ...@@ -2274,6 +2312,7 @@ if TYPE_CHECKING:
LEDForQuestionAnswering, LEDForQuestionAnswering,
LEDForSequenceClassification, LEDForSequenceClassification,
LEDModel, LEDModel,
LEDPreTrainedModel,
) )
from .models.longformer import ( from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
...@@ -2283,6 +2322,7 @@ if TYPE_CHECKING: ...@@ -2283,6 +2322,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification, LongformerForSequenceClassification,
LongformerForTokenClassification, LongformerForTokenClassification,
LongformerModel, LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention, LongformerSelfAttention,
) )
from .models.luke import ( from .models.luke import (
...@@ -2302,7 +2342,12 @@ if TYPE_CHECKING: ...@@ -2302,7 +2342,12 @@ if TYPE_CHECKING:
LxmertVisualFeatureEncoder, LxmertVisualFeatureEncoder,
LxmertXLayer, LxmertXLayer,
) )
from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model from .models.m2m_100 import (
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
M2M100ForConditionalGeneration,
M2M100Model,
M2M100PreTrainedModel,
)
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.mbart import ( from .models.mbart import (
MBartForCausalLM, MBartForCausalLM,
...@@ -2310,6 +2355,7 @@ if TYPE_CHECKING: ...@@ -2310,6 +2355,7 @@ if TYPE_CHECKING:
MBartForQuestionAnswering, MBartForQuestionAnswering,
MBartForSequenceClassification, MBartForSequenceClassification,
MBartModel, MBartModel,
MBartPreTrainedModel,
) )
from .models.megatron_bert import ( from .models.megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST, MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
...@@ -2322,6 +2368,7 @@ if TYPE_CHECKING: ...@@ -2322,6 +2368,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification, MegatronBertForSequenceClassification,
MegatronBertForTokenClassification, MegatronBertForTokenClassification,
MegatronBertModel, MegatronBertModel,
MegatronBertPreTrainedModel,
) )
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
from .models.mobilebert import ( from .models.mobilebert import (
...@@ -2359,7 +2406,12 @@ if TYPE_CHECKING: ...@@ -2359,7 +2406,12 @@ if TYPE_CHECKING:
OpenAIGPTPreTrainedModel, OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt, load_tf_weights_in_openai_gpt,
) )
from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel from .models.pegasus import (
PegasusForCausalLM,
PegasusForConditionalGeneration,
PegasusModel,
PegasusPreTrainedModel,
)
from .models.prophetnet import ( from .models.prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder, ProphetNetDecoder,
...@@ -2369,7 +2421,7 @@ if TYPE_CHECKING: ...@@ -2369,7 +2421,7 @@ if TYPE_CHECKING:
ProphetNetModel, ProphetNetModel,
ProphetNetPreTrainedModel, ProphetNetPreTrainedModel,
) )
from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
from .models.reformer import ( from .models.reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention, ReformerAttention,
...@@ -2379,6 +2431,7 @@ if TYPE_CHECKING: ...@@ -2379,6 +2431,7 @@ if TYPE_CHECKING:
ReformerLayer, ReformerLayer,
ReformerModel, ReformerModel,
ReformerModelWithLMHead, ReformerModelWithLMHead,
ReformerPreTrainedModel,
) )
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
from .models.roberta import ( from .models.roberta import (
...@@ -2390,6 +2443,7 @@ if TYPE_CHECKING: ...@@ -2390,6 +2443,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification, RobertaForSequenceClassification,
RobertaForTokenClassification, RobertaForTokenClassification,
RobertaModel, RobertaModel,
RobertaPreTrainedModel,
) )
from .models.roformer import ( from .models.roformer import (
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
...@@ -2408,6 +2462,7 @@ if TYPE_CHECKING: ...@@ -2408,6 +2462,7 @@ if TYPE_CHECKING:
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST, SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration, Speech2TextForConditionalGeneration,
Speech2TextModel, Speech2TextModel,
Speech2TextPreTrainedModel,
) )
from .models.squeezebert import ( from .models.squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
...@@ -2434,6 +2489,7 @@ if TYPE_CHECKING: ...@@ -2434,6 +2489,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering, TapasForQuestionAnswering,
TapasForSequenceClassification, TapasForSequenceClassification,
TapasModel, TapasModel,
TapasPreTrainedModel,
) )
from .models.transfo_xl import ( from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
...@@ -2600,8 +2656,16 @@ if TYPE_CHECKING: ...@@ -2600,8 +2656,16 @@ if TYPE_CHECKING:
TFBertModel, TFBertModel,
TFBertPreTrainedModel, TFBertPreTrainedModel,
) )
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel from .models.blenderbot import (
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
from .models.blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
from .models.camembert import ( from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM, TFCamembertForMaskedLM,
...@@ -2669,6 +2733,7 @@ if TYPE_CHECKING: ...@@ -2669,6 +2733,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification, TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification, TFFlaubertForTokenClassification,
TFFlaubertModel, TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel, TFFlaubertWithLMHeadModel,
) )
from .models.funnel import ( from .models.funnel import (
...@@ -2681,6 +2746,7 @@ if TYPE_CHECKING: ...@@ -2681,6 +2746,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification, TFFunnelForSequenceClassification,
TFFunnelForTokenClassification, TFFunnelForTokenClassification,
TFFunnelModel, TFFunnelModel,
TFFunnelPreTrainedModel,
) )
from .models.gpt2 import ( from .models.gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
...@@ -2700,6 +2766,7 @@ if TYPE_CHECKING: ...@@ -2700,6 +2766,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification, TFLongformerForSequenceClassification,
TFLongformerForTokenClassification, TFLongformerForTokenClassification,
TFLongformerModel, TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention, TFLongformerSelfAttention,
) )
from .models.lxmert import ( from .models.lxmert import (
...@@ -2710,8 +2777,8 @@ if TYPE_CHECKING: ...@@ -2710,8 +2777,8 @@ if TYPE_CHECKING:
TFLxmertPreTrainedModel, TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder, TFLxmertVisualFeatureEncoder,
) )
from .models.marian import TFMarianModel, TFMarianMTModel from .models.marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
from .models.mobilebert import ( from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM, TFMobileBertForMaskedLM,
...@@ -2746,8 +2813,8 @@ if TYPE_CHECKING: ...@@ -2746,8 +2813,8 @@ if TYPE_CHECKING:
TFOpenAIGPTModel, TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel, TFOpenAIGPTPreTrainedModel,
) )
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.roberta import ( from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM, TFRobertaForMaskedLM,
...@@ -2878,6 +2945,7 @@ if TYPE_CHECKING: ...@@ -2878,6 +2945,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering, FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification, FlaxBartForSequenceClassification,
FlaxBartModel, FlaxBartModel,
FlaxBartPreTrainedModel,
) )
from .models.bert import ( from .models.bert import (
FlaxBertForMaskedLM, FlaxBertForMaskedLM,
...@@ -2900,7 +2968,14 @@ if TYPE_CHECKING: ...@@ -2900,7 +2968,14 @@ if TYPE_CHECKING:
FlaxBigBirdModel, FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel, FlaxBigBirdPreTrainedModel,
) )
from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.electra import ( from .models.electra import (
FlaxElectraForMaskedLM, FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice, FlaxElectraForMultipleChoice,
...@@ -2911,7 +2986,7 @@ if TYPE_CHECKING: ...@@ -2911,7 +2986,7 @@ if TYPE_CHECKING:
FlaxElectraModel, FlaxElectraModel,
FlaxElectraPreTrainedModel, FlaxElectraPreTrainedModel,
) )
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.roberta import ( from .models.roberta import (
FlaxRobertaForMaskedLM, FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice, FlaxRobertaForMultipleChoice,
...@@ -2921,8 +2996,8 @@ if TYPE_CHECKING: ...@@ -2921,8 +2996,8 @@ if TYPE_CHECKING:
FlaxRobertaModel, FlaxRobertaModel,
FlaxRobertaPreTrainedModel, FlaxRobertaPreTrainedModel,
) )
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else: else:
# Import the same objects as dummies to get them in the namespace. # Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them. # They will raise an import error if the user tries to instantiate / use them.
......
...@@ -55,6 +55,7 @@ if is_flax_available(): ...@@ -55,6 +55,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering", "FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification", "FlaxBartForSequenceClassification",
"FlaxBartModel", "FlaxBartModel",
"FlaxBartPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -85,6 +86,7 @@ if TYPE_CHECKING: ...@@ -85,6 +86,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering, FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification, FlaxBartForSequenceClassification,
FlaxBartModel, FlaxBartModel,
FlaxBartPreTrainedModel,
) )
else: else:
......
...@@ -32,6 +32,7 @@ if is_torch_available(): ...@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_bert_generation"] = [ _import_structure["modeling_bert_generation"] = [
"BertGenerationDecoder", "BertGenerationDecoder",
"BertGenerationEncoder", "BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation", "load_tf_weights_in_bert_generation",
] ]
...@@ -46,6 +47,7 @@ if TYPE_CHECKING: ...@@ -46,6 +47,7 @@ if TYPE_CHECKING:
from .modeling_bert_generation import ( from .modeling_bert_generation import (
BertGenerationDecoder, BertGenerationDecoder,
BertGenerationEncoder, BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation, load_tf_weights_in_bert_generation,
) )
......
...@@ -37,7 +37,11 @@ if is_torch_available(): ...@@ -37,7 +37,11 @@ if is_torch_available():
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"] _import_structure["modeling_tf_blenderbot"] = [
"TFBlenderbotForConditionalGeneration",
"TFBlenderbotModel",
"TFBlenderbotPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -54,7 +58,11 @@ if TYPE_CHECKING: ...@@ -54,7 +58,11 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel from .modeling_tf_blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
else: else:
import importlib import importlib
......
...@@ -38,6 +38,7 @@ if is_tf_available(): ...@@ -38,6 +38,7 @@ if is_tf_available():
_import_structure["modeling_tf_blenderbot_small"] = [ _import_structure["modeling_tf_blenderbot_small"] = [
"TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel", "TFBlenderbotSmallModel",
"TFBlenderbotSmallPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -54,7 +55,11 @@ if TYPE_CHECKING: ...@@ -54,7 +55,11 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel from .modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
else: else:
import importlib import importlib
......
...@@ -52,7 +52,9 @@ if is_flax_available(): ...@@ -52,7 +52,9 @@ if is_flax_available():
"FlaxCLIPModel", "FlaxCLIPModel",
"FlaxCLIPPreTrainedModel", "FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel", "FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel", "FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
] ]
...@@ -77,7 +79,14 @@ if TYPE_CHECKING: ...@@ -77,7 +79,14 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel from .modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
else: else:
......
...@@ -46,6 +46,7 @@ if is_tf_available(): ...@@ -46,6 +46,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification", "TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification", "TFFlaubertForTokenClassification",
"TFFlaubertModel", "TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel", "TFFlaubertWithLMHeadModel",
] ]
...@@ -74,6 +75,7 @@ if TYPE_CHECKING: ...@@ -74,6 +75,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification, TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification, TFFlaubertForTokenClassification,
TFFlaubertModel, TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel, TFFlaubertWithLMHeadModel,
) )
......
...@@ -41,6 +41,7 @@ if is_torch_available(): ...@@ -41,6 +41,7 @@ if is_torch_available():
"FunnelForSequenceClassification", "FunnelForSequenceClassification",
"FunnelForTokenClassification", "FunnelForTokenClassification",
"FunnelModel", "FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel", "load_tf_weights_in_funnel",
] ]
...@@ -55,6 +56,7 @@ if is_tf_available(): ...@@ -55,6 +56,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification", "TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification", "TFFunnelForTokenClassification",
"TFFunnelModel", "TFFunnelModel",
"TFFunnelPreTrainedModel",
] ]
...@@ -76,6 +78,7 @@ if TYPE_CHECKING: ...@@ -76,6 +78,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification, FunnelForSequenceClassification,
FunnelForTokenClassification, FunnelForTokenClassification,
FunnelModel, FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel, load_tf_weights_in_funnel,
) )
...@@ -90,6 +93,7 @@ if TYPE_CHECKING: ...@@ -90,6 +93,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification, TFFunnelForSequenceClassification,
TFFunnelForTokenClassification, TFFunnelForTokenClassification,
TFFunnelModel, TFFunnelModel,
TFFunnelPreTrainedModel,
) )
else: else:
......
...@@ -58,7 +58,7 @@ if is_tf_available(): ...@@ -58,7 +58,7 @@ if is_tf_available():
] ]
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"] _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
...@@ -90,7 +90,7 @@ if TYPE_CHECKING: ...@@ -90,7 +90,7 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
else: else:
import importlib import importlib
......
...@@ -38,6 +38,7 @@ if is_torch_available(): ...@@ -38,6 +38,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification", "LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification", "LayoutLMForTokenClassification",
"LayoutLMModel", "LayoutLMModel",
"LayoutLMPreTrainedModel",
] ]
if is_tf_available(): if is_tf_available():
...@@ -66,6 +67,7 @@ if TYPE_CHECKING: ...@@ -66,6 +67,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification, LayoutLMForSequenceClassification,
LayoutLMForTokenClassification, LayoutLMForTokenClassification,
LayoutLMModel, LayoutLMModel,
LayoutLMPreTrainedModel,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_layoutlm import ( from .modeling_tf_layoutlm import (
......
...@@ -38,6 +38,7 @@ if is_torch_available(): ...@@ -38,6 +38,7 @@ if is_torch_available():
"LongformerForSequenceClassification", "LongformerForSequenceClassification",
"LongformerForTokenClassification", "LongformerForTokenClassification",
"LongformerModel", "LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention", "LongformerSelfAttention",
] ]
...@@ -50,6 +51,7 @@ if is_tf_available(): ...@@ -50,6 +51,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification", "TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification", "TFLongformerForTokenClassification",
"TFLongformerModel", "TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention", "TFLongformerSelfAttention",
] ]
...@@ -70,6 +72,7 @@ if TYPE_CHECKING: ...@@ -70,6 +72,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification, LongformerForSequenceClassification,
LongformerForTokenClassification, LongformerForTokenClassification,
LongformerModel, LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention, LongformerSelfAttention,
) )
...@@ -82,6 +85,7 @@ if TYPE_CHECKING: ...@@ -82,6 +85,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification, TFLongformerForSequenceClassification,
TFLongformerForTokenClassification, TFLongformerForTokenClassification,
TFLongformerModel, TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention, TFLongformerSelfAttention,
) )
......
...@@ -43,7 +43,7 @@ if is_torch_available(): ...@@ -43,7 +43,7 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"] _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -62,7 +62,7 @@ if TYPE_CHECKING: ...@@ -62,7 +62,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
else: else:
import importlib import importlib
......
...@@ -50,7 +50,11 @@ if is_torch_available(): ...@@ -50,7 +50,11 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"] _import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -76,7 +80,7 @@ if TYPE_CHECKING: ...@@ -76,7 +80,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
else: else:
import importlib import importlib
......
...@@ -36,6 +36,7 @@ if is_torch_available(): ...@@ -36,6 +36,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification", "MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification", "MegatronBertForTokenClassification",
"MegatronBertModel", "MegatronBertModel",
"MegatronBertPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -53,6 +54,7 @@ if TYPE_CHECKING: ...@@ -53,6 +54,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification, MegatronBertForSequenceClassification,
MegatronBertForTokenClassification, MegatronBertForTokenClassification,
MegatronBertModel, MegatronBertModel,
MegatronBertPreTrainedModel,
) )
else: else:
......
...@@ -46,7 +46,11 @@ if is_torch_available(): ...@@ -46,7 +46,11 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"] _import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -68,7 +72,7 @@ if TYPE_CHECKING: ...@@ -68,7 +72,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
else: else:
import importlib import importlib
......
...@@ -28,10 +28,20 @@ _import_structure = { ...@@ -28,10 +28,20 @@ _import_structure = {
} }
if is_torch_available(): if is_torch_available():
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"] _import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"] _import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -40,10 +50,15 @@ if TYPE_CHECKING: ...@@ -40,10 +50,15 @@ if TYPE_CHECKING:
from .tokenization_rag import RagTokenizer from .tokenization_rag import RagTokenizer
if is_torch_available(): if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
if is_tf_available(): if is_tf_available():
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)
else: else:
import importlib import importlib
......
...@@ -41,6 +41,7 @@ if is_torch_available(): ...@@ -41,6 +41,7 @@ if is_torch_available():
"ReformerLayer", "ReformerLayer",
"ReformerModel", "ReformerModel",
"ReformerModelWithLMHead", "ReformerModelWithLMHead",
"ReformerPreTrainedModel",
] ]
...@@ -63,6 +64,7 @@ if TYPE_CHECKING: ...@@ -63,6 +64,7 @@ if TYPE_CHECKING:
ReformerLayer, ReformerLayer,
ReformerModel, ReformerModel,
ReformerModelWithLMHead, ReformerModelWithLMHead,
ReformerPreTrainedModel,
) )
else: else:
......
...@@ -45,6 +45,7 @@ if is_torch_available(): ...@@ -45,6 +45,7 @@ if is_torch_available():
"RobertaForSequenceClassification", "RobertaForSequenceClassification",
"RobertaForTokenClassification", "RobertaForTokenClassification",
"RobertaModel", "RobertaModel",
"RobertaPreTrainedModel",
] ]
if is_tf_available(): if is_tf_available():
...@@ -89,6 +90,7 @@ if TYPE_CHECKING: ...@@ -89,6 +90,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification, RobertaForSequenceClassification,
RobertaForTokenClassification, RobertaForTokenClassification,
RobertaModel, RobertaModel,
RobertaPreTrainedModel,
) )
if is_tf_available(): if is_tf_available():
......
...@@ -33,6 +33,7 @@ if is_torch_available(): ...@@ -33,6 +33,7 @@ if is_torch_available():
"TapasForQuestionAnswering", "TapasForQuestionAnswering",
"TapasForSequenceClassification", "TapasForSequenceClassification",
"TapasModel", "TapasModel",
"TapasPreTrainedModel",
] ]
...@@ -47,6 +48,7 @@ if TYPE_CHECKING: ...@@ -47,6 +48,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering, TapasForQuestionAnswering,
TapasForSequenceClassification, TapasForSequenceClassification,
TapasModel, TapasModel,
TapasPreTrainedModel,
) )
else: else:
......
...@@ -37,7 +37,11 @@ if is_torch_available(): ...@@ -37,7 +37,11 @@ if is_torch_available():
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"] _import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
"FlaxViTPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
...@@ -54,7 +58,7 @@ if TYPE_CHECKING: ...@@ -54,7 +58,7 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment