"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "df5d9c3551a6405feb697a1cad903dddffa04bfe"
Unverified Commit 9eda6b52 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add all XxxPreTrainedModel to the main init (#12314)

* Add all XxxPreTrainedModel to the main init

* Add to template

* Add to template bis

* Add FlaxT5
parent 53c60bab
This diff is collapsed.
......@@ -55,6 +55,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
"FlaxBartPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -85,6 +86,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
FlaxBartPreTrainedModel,
)
else:
......
......@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_bert_generation"] = [
"BertGenerationDecoder",
"BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation",
]
......@@ -46,6 +47,7 @@ if TYPE_CHECKING:
from .modeling_bert_generation import (
BertGenerationDecoder,
BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation,
)
......
......@@ -37,7 +37,11 @@ if is_torch_available():
if is_tf_available():
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]
_import_structure["modeling_tf_blenderbot"] = [
"TFBlenderbotForConditionalGeneration",
"TFBlenderbotModel",
"TFBlenderbotPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -54,7 +58,11 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from .modeling_tf_blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
else:
import importlib
......
......@@ -38,6 +38,7 @@ if is_tf_available():
_import_structure["modeling_tf_blenderbot_small"] = [
"TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel",
"TFBlenderbotSmallPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -54,7 +55,11 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
from .modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
else:
import importlib
......
......@@ -52,7 +52,9 @@ if is_flax_available():
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
]
......@@ -77,7 +79,14 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
else:
......
......@@ -46,6 +46,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification",
"TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel",
]
......@@ -74,6 +75,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification,
TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel,
)
......
......@@ -41,6 +41,7 @@ if is_torch_available():
"FunnelForSequenceClassification",
"FunnelForTokenClassification",
"FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel",
]
......@@ -55,6 +56,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification",
"TFFunnelModel",
"TFFunnelPreTrainedModel",
]
......@@ -76,6 +78,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification,
FunnelForTokenClassification,
FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel,
)
......@@ -90,6 +93,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification,
TFFunnelForTokenClassification,
TFFunnelModel,
TFFunnelPreTrainedModel,
)
else:
......
......@@ -58,7 +58,7 @@ if is_tf_available():
]
if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
......@@ -90,7 +90,7 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
else:
import importlib
......
......@@ -38,6 +38,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification",
"LayoutLMModel",
"LayoutLMPreTrainedModel",
]
if is_tf_available():
......@@ -66,6 +67,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification,
LayoutLMForTokenClassification,
LayoutLMModel,
LayoutLMPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_layoutlm import (
......
......@@ -38,6 +38,7 @@ if is_torch_available():
"LongformerForSequenceClassification",
"LongformerForTokenClassification",
"LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention",
]
......@@ -50,6 +51,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification",
"TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention",
]
......@@ -70,6 +72,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention,
)
......@@ -82,6 +85,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention,
)
......
......@@ -43,7 +43,7 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
if TYPE_CHECKING:
......@@ -62,7 +62,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
else:
import importlib
......
......@@ -50,7 +50,11 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"]
_import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -76,7 +80,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
else:
import importlib
......
......@@ -36,6 +36,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification",
"MegatronBertModel",
"MegatronBertPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -53,6 +54,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification,
MegatronBertForTokenClassification,
MegatronBertModel,
MegatronBertPreTrainedModel,
)
else:
......
......@@ -46,7 +46,11 @@ if is_torch_available():
]
if is_tf_available():
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"]
_import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
if TYPE_CHECKING:
......@@ -68,7 +72,7 @@ if TYPE_CHECKING:
)
if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
else:
import importlib
......
......@@ -28,10 +28,20 @@ _import_structure = {
}
if is_torch_available():
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
_import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]
if is_tf_available():
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]
_import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
if TYPE_CHECKING:
......@@ -40,10 +50,15 @@ if TYPE_CHECKING:
from .tokenization_rag import RagTokenizer
if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
if is_tf_available():
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)
else:
import importlib
......
......@@ -41,6 +41,7 @@ if is_torch_available():
"ReformerLayer",
"ReformerModel",
"ReformerModelWithLMHead",
"ReformerPreTrainedModel",
]
......@@ -63,6 +64,7 @@ if TYPE_CHECKING:
ReformerLayer,
ReformerModel,
ReformerModelWithLMHead,
ReformerPreTrainedModel,
)
else:
......
......@@ -45,6 +45,7 @@ if is_torch_available():
"RobertaForSequenceClassification",
"RobertaForTokenClassification",
"RobertaModel",
"RobertaPreTrainedModel",
]
if is_tf_available():
......@@ -89,6 +90,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification,
RobertaForTokenClassification,
RobertaModel,
RobertaPreTrainedModel,
)
if is_tf_available():
......
......@@ -33,6 +33,7 @@ if is_torch_available():
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
"TapasPreTrainedModel",
]
......@@ -47,6 +48,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering,
TapasForSequenceClassification,
TapasModel,
TapasPreTrainedModel,
)
else:
......
......@@ -37,7 +37,11 @@ if is_torch_available():
if is_flax_available():
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"]
_import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
"FlaxViTPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
......@@ -54,7 +58,7 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment