Unverified Commit 9eda6b52 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add all XxxPreTrainedModel to the main init (#12314)

* Add all XxxPreTrainedModel to the main init

* Add to template

* Add to template bis

* Add FlaxT5
parent 53c60bab
This diff is collapsed.
...@@ -55,6 +55,7 @@ if is_flax_available(): ...@@ -55,6 +55,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering", "FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification", "FlaxBartForSequenceClassification",
"FlaxBartModel", "FlaxBartModel",
"FlaxBartPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -85,6 +86,7 @@ if TYPE_CHECKING: ...@@ -85,6 +86,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering, FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification, FlaxBartForSequenceClassification,
FlaxBartModel, FlaxBartModel,
FlaxBartPreTrainedModel,
) )
else: else:
......
...@@ -32,6 +32,7 @@ if is_torch_available(): ...@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_bert_generation"] = [ _import_structure["modeling_bert_generation"] = [
"BertGenerationDecoder", "BertGenerationDecoder",
"BertGenerationEncoder", "BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation", "load_tf_weights_in_bert_generation",
] ]
...@@ -46,6 +47,7 @@ if TYPE_CHECKING: ...@@ -46,6 +47,7 @@ if TYPE_CHECKING:
from .modeling_bert_generation import ( from .modeling_bert_generation import (
BertGenerationDecoder, BertGenerationDecoder,
BertGenerationEncoder, BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation, load_tf_weights_in_bert_generation,
) )
......
...@@ -37,7 +37,11 @@ if is_torch_available(): ...@@ -37,7 +37,11 @@ if is_torch_available():
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"] _import_structure["modeling_tf_blenderbot"] = [
"TFBlenderbotForConditionalGeneration",
"TFBlenderbotModel",
"TFBlenderbotPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -54,7 +58,11 @@ if TYPE_CHECKING: ...@@ -54,7 +58,11 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel from .modeling_tf_blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
else: else:
import importlib import importlib
......
...@@ -38,6 +38,7 @@ if is_tf_available(): ...@@ -38,6 +38,7 @@ if is_tf_available():
_import_structure["modeling_tf_blenderbot_small"] = [ _import_structure["modeling_tf_blenderbot_small"] = [
"TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel", "TFBlenderbotSmallModel",
"TFBlenderbotSmallPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -54,7 +55,11 @@ if TYPE_CHECKING: ...@@ -54,7 +55,11 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel from .modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
else: else:
import importlib import importlib
......
...@@ -52,7 +52,9 @@ if is_flax_available(): ...@@ -52,7 +52,9 @@ if is_flax_available():
"FlaxCLIPModel", "FlaxCLIPModel",
"FlaxCLIPPreTrainedModel", "FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel", "FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel", "FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
] ]
...@@ -77,7 +79,14 @@ if TYPE_CHECKING: ...@@ -77,7 +79,14 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel from .modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
else: else:
......
...@@ -46,6 +46,7 @@ if is_tf_available(): ...@@ -46,6 +46,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification", "TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification", "TFFlaubertForTokenClassification",
"TFFlaubertModel", "TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel", "TFFlaubertWithLMHeadModel",
] ]
...@@ -74,6 +75,7 @@ if TYPE_CHECKING: ...@@ -74,6 +75,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification, TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification, TFFlaubertForTokenClassification,
TFFlaubertModel, TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel, TFFlaubertWithLMHeadModel,
) )
......
...@@ -41,6 +41,7 @@ if is_torch_available(): ...@@ -41,6 +41,7 @@ if is_torch_available():
"FunnelForSequenceClassification", "FunnelForSequenceClassification",
"FunnelForTokenClassification", "FunnelForTokenClassification",
"FunnelModel", "FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel", "load_tf_weights_in_funnel",
] ]
...@@ -55,6 +56,7 @@ if is_tf_available(): ...@@ -55,6 +56,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification", "TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification", "TFFunnelForTokenClassification",
"TFFunnelModel", "TFFunnelModel",
"TFFunnelPreTrainedModel",
] ]
...@@ -76,6 +78,7 @@ if TYPE_CHECKING: ...@@ -76,6 +78,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification, FunnelForSequenceClassification,
FunnelForTokenClassification, FunnelForTokenClassification,
FunnelModel, FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel, load_tf_weights_in_funnel,
) )
...@@ -90,6 +93,7 @@ if TYPE_CHECKING: ...@@ -90,6 +93,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification, TFFunnelForSequenceClassification,
TFFunnelForTokenClassification, TFFunnelForTokenClassification,
TFFunnelModel, TFFunnelModel,
TFFunnelPreTrainedModel,
) )
else: else:
......
...@@ -58,7 +58,7 @@ if is_tf_available(): ...@@ -58,7 +58,7 @@ if is_tf_available():
] ]
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"] _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
...@@ -90,7 +90,7 @@ if TYPE_CHECKING: ...@@ -90,7 +90,7 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
else: else:
import importlib import importlib
......
...@@ -38,6 +38,7 @@ if is_torch_available(): ...@@ -38,6 +38,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification", "LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification", "LayoutLMForTokenClassification",
"LayoutLMModel", "LayoutLMModel",
"LayoutLMPreTrainedModel",
] ]
if is_tf_available(): if is_tf_available():
...@@ -66,6 +67,7 @@ if TYPE_CHECKING: ...@@ -66,6 +67,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification, LayoutLMForSequenceClassification,
LayoutLMForTokenClassification, LayoutLMForTokenClassification,
LayoutLMModel, LayoutLMModel,
LayoutLMPreTrainedModel,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_layoutlm import ( from .modeling_tf_layoutlm import (
......
...@@ -38,6 +38,7 @@ if is_torch_available(): ...@@ -38,6 +38,7 @@ if is_torch_available():
"LongformerForSequenceClassification", "LongformerForSequenceClassification",
"LongformerForTokenClassification", "LongformerForTokenClassification",
"LongformerModel", "LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention", "LongformerSelfAttention",
] ]
...@@ -50,6 +51,7 @@ if is_tf_available(): ...@@ -50,6 +51,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification", "TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification", "TFLongformerForTokenClassification",
"TFLongformerModel", "TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention", "TFLongformerSelfAttention",
] ]
...@@ -70,6 +72,7 @@ if TYPE_CHECKING: ...@@ -70,6 +72,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification, LongformerForSequenceClassification,
LongformerForTokenClassification, LongformerForTokenClassification,
LongformerModel, LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention, LongformerSelfAttention,
) )
...@@ -82,6 +85,7 @@ if TYPE_CHECKING: ...@@ -82,6 +85,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification, TFLongformerForSequenceClassification,
TFLongformerForTokenClassification, TFLongformerForTokenClassification,
TFLongformerModel, TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention, TFLongformerSelfAttention,
) )
......
...@@ -43,7 +43,7 @@ if is_torch_available(): ...@@ -43,7 +43,7 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"] _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -62,7 +62,7 @@ if TYPE_CHECKING: ...@@ -62,7 +62,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
else: else:
import importlib import importlib
......
...@@ -50,7 +50,11 @@ if is_torch_available(): ...@@ -50,7 +50,11 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"] _import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -76,7 +80,7 @@ if TYPE_CHECKING: ...@@ -76,7 +80,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
else: else:
import importlib import importlib
......
...@@ -36,6 +36,7 @@ if is_torch_available(): ...@@ -36,6 +36,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification", "MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification", "MegatronBertForTokenClassification",
"MegatronBertModel", "MegatronBertModel",
"MegatronBertPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -53,6 +54,7 @@ if TYPE_CHECKING: ...@@ -53,6 +54,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification, MegatronBertForSequenceClassification,
MegatronBertForTokenClassification, MegatronBertForTokenClassification,
MegatronBertModel, MegatronBertModel,
MegatronBertPreTrainedModel,
) )
else: else:
......
...@@ -46,7 +46,11 @@ if is_torch_available(): ...@@ -46,7 +46,11 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"] _import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -68,7 +72,7 @@ if TYPE_CHECKING: ...@@ -68,7 +72,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
else: else:
import importlib import importlib
......
...@@ -28,10 +28,20 @@ _import_structure = { ...@@ -28,10 +28,20 @@ _import_structure = {
} }
if is_torch_available(): if is_torch_available():
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"] _import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"] _import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -40,10 +50,15 @@ if TYPE_CHECKING: ...@@ -40,10 +50,15 @@ if TYPE_CHECKING:
from .tokenization_rag import RagTokenizer from .tokenization_rag import RagTokenizer
if is_torch_available(): if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
if is_tf_available(): if is_tf_available():
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)
else: else:
import importlib import importlib
......
...@@ -41,6 +41,7 @@ if is_torch_available(): ...@@ -41,6 +41,7 @@ if is_torch_available():
"ReformerLayer", "ReformerLayer",
"ReformerModel", "ReformerModel",
"ReformerModelWithLMHead", "ReformerModelWithLMHead",
"ReformerPreTrainedModel",
] ]
...@@ -63,6 +64,7 @@ if TYPE_CHECKING: ...@@ -63,6 +64,7 @@ if TYPE_CHECKING:
ReformerLayer, ReformerLayer,
ReformerModel, ReformerModel,
ReformerModelWithLMHead, ReformerModelWithLMHead,
ReformerPreTrainedModel,
) )
else: else:
......
...@@ -45,6 +45,7 @@ if is_torch_available(): ...@@ -45,6 +45,7 @@ if is_torch_available():
"RobertaForSequenceClassification", "RobertaForSequenceClassification",
"RobertaForTokenClassification", "RobertaForTokenClassification",
"RobertaModel", "RobertaModel",
"RobertaPreTrainedModel",
] ]
if is_tf_available(): if is_tf_available():
...@@ -89,6 +90,7 @@ if TYPE_CHECKING: ...@@ -89,6 +90,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification, RobertaForSequenceClassification,
RobertaForTokenClassification, RobertaForTokenClassification,
RobertaModel, RobertaModel,
RobertaPreTrainedModel,
) )
if is_tf_available(): if is_tf_available():
......
...@@ -33,6 +33,7 @@ if is_torch_available(): ...@@ -33,6 +33,7 @@ if is_torch_available():
"TapasForQuestionAnswering", "TapasForQuestionAnswering",
"TapasForSequenceClassification", "TapasForSequenceClassification",
"TapasModel", "TapasModel",
"TapasPreTrainedModel",
] ]
...@@ -47,6 +48,7 @@ if TYPE_CHECKING: ...@@ -47,6 +48,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering, TapasForQuestionAnswering,
TapasForSequenceClassification, TapasForSequenceClassification,
TapasModel, TapasModel,
TapasPreTrainedModel,
) )
else: else:
......
...@@ -37,7 +37,11 @@ if is_torch_available(): ...@@ -37,7 +37,11 @@ if is_torch_available():
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"] _import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
"FlaxViTPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
...@@ -54,7 +58,7 @@ if TYPE_CHECKING: ...@@ -54,7 +58,7 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment