Unverified Commit df735d13 authored by Dom Miketa's avatar Dom Miketa Committed by GitHub
Browse files

[WIP] Fix Pyright static type checking by replacing if-else imports with try-except (#16578)



* rebase and isort

* modify cookiecutter init

* fix cookiecutter auto imports

* fix clean_frameworks_in_init

* fix add_model_to_main_init

* blackify

* replace unnecessary f-strings

* update yolos imports

* fix roberta import bug

* fix yolos missing dependency

* fix add_model_like and cookiecutter bug

* fix repository consistency error

* modify cookiecutter, fix add_new_model_like

* remove stale line
Co-authored-by: default avatarDom Miketa <dmiketa@exscientia.co.uk>
parent 7783fa6b
...@@ -17,14 +17,25 @@ ...@@ -17,14 +17,25 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = { _import_structure = {
"configuration_vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"], "configuration_vit_mae": ["VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTMAEConfig"],
} }
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_vit_mae"] = [ _import_structure["modeling_vit_mae"] = [
"VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST", "VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST",
"ViTMAEForPreTraining", "ViTMAEForPreTraining",
...@@ -33,7 +44,12 @@ if is_torch_available(): ...@@ -33,7 +44,12 @@ if is_torch_available():
"ViTMAEPreTrainedModel", "ViTMAEPreTrainedModel",
] ]
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_vit_mae"] = [ _import_structure["modeling_tf_vit_mae"] = [
"TFViTMAEForPreTraining", "TFViTMAEForPreTraining",
"TFViTMAEModel", "TFViTMAEModel",
...@@ -43,7 +59,12 @@ if is_tf_available(): ...@@ -43,7 +59,12 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_vit_mae import ( from .modeling_vit_mae import (
VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST, VIT_MAE_PRETRAINED_MODEL_ARCHIVE_LIST,
ViTMAEForPreTraining, ViTMAEForPreTraining,
...@@ -52,7 +73,12 @@ if TYPE_CHECKING: ...@@ -52,7 +73,12 @@ if TYPE_CHECKING:
ViTMAEPreTrainedModel, ViTMAEPreTrainedModel,
) )
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
......
...@@ -17,7 +17,13 @@ ...@@ -17,7 +17,13 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = { _import_structure = {
...@@ -28,7 +34,12 @@ _import_structure = { ...@@ -28,7 +34,12 @@ _import_structure = {
} }
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_wav2vec2"] = [ _import_structure["modeling_wav2vec2"] = [
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"Wav2Vec2ForAudioFrameClassification", "Wav2Vec2ForAudioFrameClassification",
...@@ -41,7 +52,12 @@ if is_torch_available(): ...@@ -41,7 +52,12 @@ if is_torch_available():
"Wav2Vec2PreTrainedModel", "Wav2Vec2PreTrainedModel",
] ]
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_wav2vec2"] = [ _import_structure["modeling_tf_wav2vec2"] = [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFWav2Vec2ForCTC", "TFWav2Vec2ForCTC",
...@@ -49,7 +65,12 @@ if is_tf_available(): ...@@ -49,7 +65,12 @@ if is_tf_available():
"TFWav2Vec2PreTrainedModel", "TFWav2Vec2PreTrainedModel",
] ]
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_wav2vec2"] = [ _import_structure["modeling_flax_wav2vec2"] = [
"FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForCTC",
"FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2ForPreTraining",
...@@ -64,7 +85,12 @@ if TYPE_CHECKING: ...@@ -64,7 +85,12 @@ if TYPE_CHECKING:
from .processing_wav2vec2 import Wav2Vec2Processor from .processing_wav2vec2 import Wav2Vec2Processor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_wav2vec2 import ( from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Wav2Vec2ForAudioFrameClassification, Wav2Vec2ForAudioFrameClassification,
...@@ -77,7 +103,12 @@ if TYPE_CHECKING: ...@@ -77,7 +103,12 @@ if TYPE_CHECKING:
Wav2Vec2PreTrainedModel, Wav2Vec2PreTrainedModel,
) )
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_wav2vec2 import ( from .modeling_tf_wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC, TFWav2Vec2ForCTC,
...@@ -85,7 +116,12 @@ if TYPE_CHECKING: ...@@ -85,7 +116,12 @@ if TYPE_CHECKING:
TFWav2Vec2PreTrainedModel, TFWav2Vec2PreTrainedModel,
) )
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_wav2vec2 import ( from .modeling_tf_wav2vec2 import (
FlaxWav2Vec2ForCTC, FlaxWav2Vec2ForCTC,
FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2ForPreTraining,
......
...@@ -17,14 +17,19 @@ ...@@ -17,14 +17,19 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = { _import_structure = {
"configuration_wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig"], "configuration_wavlm": ["WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "WavLMConfig"],
} }
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_wavlm"] = [ _import_structure["modeling_wavlm"] = [
"WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST", "WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"WavLMForAudioFrameClassification", "WavLMForAudioFrameClassification",
...@@ -38,7 +43,12 @@ if is_torch_available(): ...@@ -38,7 +43,12 @@ if is_torch_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig from .configuration_wavlm import WAVLM_PRETRAINED_CONFIG_ARCHIVE_MAP, WavLMConfig
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_wavlm import ( from .modeling_wavlm import (
WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST, WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST,
WavLMForAudioFrameClassification, WavLMForAudioFrameClassification,
......
...@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING ...@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING
# rely on isort to merge the imports # rely on isort to merge the imports
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
is_flax_available, is_flax_available,
is_sentencepiece_available, is_sentencepiece_available,
...@@ -31,13 +32,28 @@ _import_structure = { ...@@ -31,13 +32,28 @@ _import_structure = {
"configuration_xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"], "configuration_xglm": ["XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XGLMConfig"],
} }
if is_sentencepiece_available(): try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_xglm"] = ["XGLMTokenizer"] _import_structure["tokenization_xglm"] = ["XGLMTokenizer"]
if is_tokenizers_available(): try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_xglm_fast"] = ["XGLMTokenizerFast"] _import_structure["tokenization_xglm_fast"] = ["XGLMTokenizerFast"]
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_xglm"] = [ _import_structure["modeling_xglm"] = [
"XGLM_PRETRAINED_MODEL_ARCHIVE_LIST", "XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XGLMForCausalLM", "XGLMForCausalLM",
...@@ -46,7 +62,12 @@ if is_torch_available(): ...@@ -46,7 +62,12 @@ if is_torch_available():
] ]
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_xglm"] = [ _import_structure["modeling_flax_xglm"] = [
"FlaxXGLMForCausalLM", "FlaxXGLMForCausalLM",
"FlaxXGLMModel", "FlaxXGLMModel",
...@@ -57,16 +78,36 @@ if is_flax_available(): ...@@ -57,16 +78,36 @@ if is_flax_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig from .configuration_xglm import XGLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XGLMConfig
if is_sentencepiece_available(): try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_xglm import XGLMTokenizer from .tokenization_xglm import XGLMTokenizer
if is_tokenizers_available(): try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_xglm_fast import XGLMTokenizerFast from .tokenization_xglm_fast import XGLMTokenizerFast
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel from .modeling_xglm import XGLM_PRETRAINED_MODEL_ARCHIVE_LIST, XGLMForCausalLM, XGLMModel, XGLMPreTrainedModel
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel from .modeling_flax_xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_tf_available, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = { _import_structure = {
...@@ -26,7 +26,12 @@ _import_structure = { ...@@ -26,7 +26,12 @@ _import_structure = {
"tokenization_xlm": ["XLMTokenizer"], "tokenization_xlm": ["XLMTokenizer"],
} }
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_xlm"] = [ _import_structure["modeling_xlm"] = [
"XLM_PRETRAINED_MODEL_ARCHIVE_LIST", "XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMForMultipleChoice", "XLMForMultipleChoice",
...@@ -39,7 +44,12 @@ if is_torch_available(): ...@@ -39,7 +44,12 @@ if is_torch_available():
"XLMWithLMHeadModel", "XLMWithLMHeadModel",
] ]
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_xlm"] = [ _import_structure["modeling_tf_xlm"] = [
"TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLMForMultipleChoice", "TFXLMForMultipleChoice",
...@@ -57,7 +67,12 @@ if TYPE_CHECKING: ...@@ -57,7 +67,12 @@ if TYPE_CHECKING:
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .tokenization_xlm import XLMTokenizer from .tokenization_xlm import XLMTokenizer
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_xlm import ( from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_LIST, XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMForMultipleChoice, XLMForMultipleChoice,
...@@ -70,7 +85,12 @@ if TYPE_CHECKING: ...@@ -70,7 +85,12 @@ if TYPE_CHECKING:
XLMWithLMHeadModel, XLMWithLMHeadModel,
) )
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_xlm import ( from .modeling_tf_xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST, TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_sentencepiece_available, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_sentencepiece_available, is_torch_available
_import_structure = { _import_structure = {
...@@ -27,10 +27,20 @@ _import_structure = { ...@@ -27,10 +27,20 @@ _import_structure = {
], ],
} }
if is_sentencepiece_available(): try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_xlm_prophetnet"] = ["XLMProphetNetTokenizer"] _import_structure["tokenization_xlm_prophetnet"] = ["XLMProphetNetTokenizer"]
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_xlm_prophetnet"] = [ _import_structure["modeling_xlm_prophetnet"] = [
"XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST", "XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMProphetNetDecoder", "XLMProphetNetDecoder",
...@@ -44,10 +54,20 @@ if is_torch_available(): ...@@ -44,10 +54,20 @@ if is_torch_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig from .configuration_xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig
if is_sentencepiece_available(): try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer from .tokenization_xlm_prophetnet import XLMProphetNetTokenizer
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_xlm_prophetnet import ( from .modeling_xlm_prophetnet import (
XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMProphetNetDecoder, XLMProphetNetDecoder,
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
is_flax_available, is_flax_available,
is_sentencepiece_available, is_sentencepiece_available,
...@@ -36,13 +37,28 @@ _import_structure = { ...@@ -36,13 +37,28 @@ _import_structure = {
], ],
} }
if is_sentencepiece_available(): try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_xlm_roberta"] = ["XLMRobertaTokenizer"] _import_structure["tokenization_xlm_roberta"] = ["XLMRobertaTokenizer"]
if is_tokenizers_available(): try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_xlm_roberta_fast"] = ["XLMRobertaTokenizerFast"] _import_structure["tokenization_xlm_roberta_fast"] = ["XLMRobertaTokenizerFast"]
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_xlm_roberta"] = [ _import_structure["modeling_xlm_roberta"] = [
"XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", "XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMRobertaForCausalLM", "XLMRobertaForCausalLM",
...@@ -54,7 +70,12 @@ if is_torch_available(): ...@@ -54,7 +70,12 @@ if is_torch_available():
"XLMRobertaModel", "XLMRobertaModel",
] ]
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_xlm_roberta"] = [ _import_structure["modeling_tf_xlm_roberta"] = [
"TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLMRobertaForMaskedLM", "TFXLMRobertaForMaskedLM",
...@@ -65,7 +86,12 @@ if is_tf_available(): ...@@ -65,7 +86,12 @@ if is_tf_available():
"TFXLMRobertaModel", "TFXLMRobertaModel",
] ]
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_xlm_roberta"] = [ _import_structure["modeling_flax_xlm_roberta"] = [
"FlaxXLMRobertaForMaskedLM", "FlaxXLMRobertaForMaskedLM",
"FlaxXLMRobertaForMultipleChoice", "FlaxXLMRobertaForMultipleChoice",
...@@ -82,13 +108,28 @@ if TYPE_CHECKING: ...@@ -82,13 +108,28 @@ if TYPE_CHECKING:
XLMRobertaOnnxConfig, XLMRobertaOnnxConfig,
) )
if is_sentencepiece_available(): try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_xlm_roberta import XLMRobertaTokenizer from .tokenization_xlm_roberta import XLMRobertaTokenizer
if is_tokenizers_available(): try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_xlm_roberta import ( from .modeling_xlm_roberta import (
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMRobertaForCausalLM, XLMRobertaForCausalLM,
...@@ -100,7 +141,12 @@ if TYPE_CHECKING: ...@@ -100,7 +141,12 @@ if TYPE_CHECKING:
XLMRobertaModel, XLMRobertaModel,
) )
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_xlm_roberta import ( from .modeling_tf_xlm_roberta import (
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMRobertaForMaskedLM, TFXLMRobertaForMaskedLM,
...@@ -111,7 +157,12 @@ if TYPE_CHECKING: ...@@ -111,7 +157,12 @@ if TYPE_CHECKING:
TFXLMRobertaModel, TFXLMRobertaModel,
) )
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_xlm_roberta import ( from .modeling_flax_xlm_roberta import (
FlaxXLMRobertaForMaskedLM, FlaxXLMRobertaForMaskedLM,
FlaxXLMRobertaForMultipleChoice, FlaxXLMRobertaForMultipleChoice,
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = { _import_structure = {
...@@ -29,7 +29,12 @@ _import_structure = { ...@@ -29,7 +29,12 @@ _import_structure = {
], ],
} }
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_xlm_roberta_xl"] = [ _import_structure["modeling_xlm_roberta_xl"] = [
"XLM_ROBERTA_XL_PRETRAINED_MODEL_ARCHIVE_LIST", "XLM_ROBERTA_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMRobertaXLForCausalLM", "XLMRobertaXLForCausalLM",
...@@ -49,7 +54,12 @@ if TYPE_CHECKING: ...@@ -49,7 +54,12 @@ if TYPE_CHECKING:
XLMRobertaXLOnnxConfig, XLMRobertaXLOnnxConfig,
) )
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_xlm_roberta_xl import ( from .modeling_xlm_roberta_xl import (
XLM_ROBERTA_XL_PRETRAINED_MODEL_ARCHIVE_LIST, XLM_ROBERTA_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMRobertaXLForCausalLM, XLMRobertaXLForCausalLM,
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
is_sentencepiece_available, is_sentencepiece_available,
is_tf_available, is_tf_available,
...@@ -31,13 +32,28 @@ _import_structure = { ...@@ -31,13 +32,28 @@ _import_structure = {
"configuration_xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"], "configuration_xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
} }
if is_sentencepiece_available(): try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_xlnet"] = ["XLNetTokenizer"] _import_structure["tokenization_xlnet"] = ["XLNetTokenizer"]
if is_tokenizers_available(): try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_xlnet_fast"] = ["XLNetTokenizerFast"] _import_structure["tokenization_xlnet_fast"] = ["XLNetTokenizerFast"]
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_xlnet"] = [ _import_structure["modeling_xlnet"] = [
"XLNET_PRETRAINED_MODEL_ARCHIVE_LIST", "XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLNetForMultipleChoice", "XLNetForMultipleChoice",
...@@ -51,7 +67,12 @@ if is_torch_available(): ...@@ -51,7 +67,12 @@ if is_torch_available():
"load_tf_weights_in_xlnet", "load_tf_weights_in_xlnet",
] ]
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_xlnet"] = [ _import_structure["modeling_tf_xlnet"] = [
"TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLNetForMultipleChoice", "TFXLNetForMultipleChoice",
...@@ -68,13 +89,28 @@ if is_tf_available(): ...@@ -68,13 +89,28 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
if is_sentencepiece_available(): try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_xlnet import XLNetTokenizer from .tokenization_xlnet import XLNetTokenizer
if is_tokenizers_available(): try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_xlnet_fast import XLNetTokenizerFast from .tokenization_xlnet_fast import XLNetTokenizerFast
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_xlnet import ( from .modeling_xlnet import (
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLNetForMultipleChoice, XLNetForMultipleChoice,
...@@ -88,7 +124,12 @@ if TYPE_CHECKING: ...@@ -88,7 +124,12 @@ if TYPE_CHECKING:
load_tf_weights_in_xlnet, load_tf_weights_in_xlnet,
) )
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_xlnet import ( from .modeling_tf_xlnet import (
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLNetForMultipleChoice, TFXLNetForMultipleChoice,
......
...@@ -17,17 +17,27 @@ ...@@ -17,17 +17,27 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import _LazyModule, is_torch_available, is_vision_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = { _import_structure = {
"configuration_yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"], "configuration_yolos": ["YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP", "YolosConfig"],
} }
if is_vision_available(): try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_yolos"] = ["YolosFeatureExtractor"] _import_structure["feature_extraction_yolos"] = ["YolosFeatureExtractor"]
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_yolos"] = [ _import_structure["modeling_yolos"] = [
"YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST", "YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST",
"YolosForObjectDetection", "YolosForObjectDetection",
...@@ -39,10 +49,20 @@ if is_torch_available(): ...@@ -39,10 +49,20 @@ if is_torch_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig from .configuration_yolos import YOLOS_PRETRAINED_CONFIG_ARCHIVE_MAP, YolosConfig
if is_vision_available(): try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_yolos import YolosFeatureExtractor from .feature_extraction_yolos import YolosFeatureExtractor
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_yolos import ( from .modeling_yolos import (
YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST, YOLOS_PRETRAINED_MODEL_ARCHIVE_LIST,
YolosForObjectDetection, YolosForObjectDetection,
......
...@@ -18,14 +18,19 @@ ...@@ -18,14 +18,19 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
# rely on isort to merge the imports # rely on isort to merge the imports
from ...utils import _LazyModule, is_tokenizers_available, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = { _import_structure = {
"configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"], "configuration_yoso": ["YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP", "YosoConfig"],
} }
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_yoso"] = [ _import_structure["modeling_yoso"] = [
"YOSO_PRETRAINED_MODEL_ARCHIVE_LIST", "YOSO_PRETRAINED_MODEL_ARCHIVE_LIST",
"YosoForMaskedLM", "YosoForMaskedLM",
...@@ -42,7 +47,12 @@ if is_torch_available(): ...@@ -42,7 +47,12 @@ if is_torch_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig from .configuration_yoso import YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP, YosoConfig
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_yoso import ( from .modeling_yoso import (
YOSO_PRETRAINED_MODEL_ARCHIVE_LIST, YOSO_PRETRAINED_MODEL_ARCHIVE_LIST,
YosoForMaskedLM, YosoForMaskedLM,
......
...@@ -83,6 +83,7 @@ from .import_utils import ( ...@@ -83,6 +83,7 @@ from .import_utils import (
USE_TF, USE_TF,
USE_TORCH, USE_TORCH,
DummyObject, DummyObject,
OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
is_apex_available, is_apex_available,
is_bitsandbytes_available, is_bitsandbytes_available,
......
...@@ -866,3 +866,7 @@ class _LazyModule(ModuleType): ...@@ -866,3 +866,7 @@ class _LazyModule(ModuleType):
def __reduce__(self): def __reduce__(self):
return (self.__class__, (self._name, self.__file__, self._import_structure)) return (self.__class__, (self._name, self.__file__, self._import_structure))
class OptionalDependencyNotAvailable(BaseException):
"""Internally used error class for signalling an optional dependency was not found."""
...@@ -18,15 +18,23 @@ ...@@ -18,15 +18,23 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
# rely on isort to merge the imports # rely on isort to merge the imports
from ...utils import _LazyModule, is_tokenizers_available from ...utils import _LazyModule, OptionalDependencyNotAvailable, is_tokenizers_available
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_tf_available from ...utils import is_tf_available
{% endif %} {% endif %}
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_torch_available from ...utils import is_torch_available
{% endif %} {% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
from ...utils import is_flax_available from ...utils import is_flax_available
{% endif %} {% endif %}
_import_structure = { _import_structure = {
...@@ -34,12 +42,22 @@ _import_structure = { ...@@ -34,12 +42,22 @@ _import_structure = {
"tokenization_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.camelcase_modelname}}Tokenizer"], "tokenization_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.camelcase_modelname}}Tokenizer"],
} }
if is_tokenizers_available(): try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_{{cookiecutter.lowercase_modelname}}_fast"] = ["{{cookiecutter.camelcase_modelname}}TokenizerFast"] _import_structure["tokenization_{{cookiecutter.lowercase_modelname}}_fast"] = ["{{cookiecutter.camelcase_modelname}}TokenizerFast"]
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [ _import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST", "{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForMaskedLM", "{{cookiecutter.camelcase_modelname}}ForMaskedLM",
...@@ -54,7 +72,12 @@ if is_torch_available(): ...@@ -54,7 +72,12 @@ if is_torch_available():
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}", "load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
] ]
{% else %} {% else %}
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [ _import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST", "{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration", "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
...@@ -70,7 +93,12 @@ if is_torch_available(): ...@@ -70,7 +93,12 @@ if is_torch_available():
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [ _import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM", "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
...@@ -84,7 +112,12 @@ if is_tf_available(): ...@@ -84,7 +112,12 @@ if is_tf_available():
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel", "TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
] ]
{% else %} {% else %}
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [ _import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration", "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"TF{{cookiecutter.camelcase_modelname}}Model", "TF{{cookiecutter.camelcase_modelname}}Model",
...@@ -96,7 +129,12 @@ if is_tf_available(): ...@@ -96,7 +129,12 @@ if is_tf_available():
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [ _import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
"Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM", "Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"Flax{{cookiecutter.camelcase_modelname}}ForCausalLM", "Flax{{cookiecutter.camelcase_modelname}}ForCausalLM",
...@@ -109,7 +147,12 @@ if is_flax_available(): ...@@ -109,7 +147,12 @@ if is_flax_available():
"Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel", "Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel",
] ]
{% else %} {% else %}
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [ _import_structure["modeling_flax_{{cookiecutter.lowercase_modelname}}"] = [
"Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration", "Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering", "Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
...@@ -125,12 +168,22 @@ if TYPE_CHECKING: ...@@ -125,12 +168,22 @@ if TYPE_CHECKING:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
if is_tokenizers_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast
{%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "PyTorch" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM, {{cookiecutter.camelcase_modelname}}ForMaskedLM,
...@@ -145,7 +198,12 @@ if TYPE_CHECKING: ...@@ -145,7 +198,12 @@ if TYPE_CHECKING:
load_tf_weights_in_{{cookiecutter.lowercase_modelname}}, load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
) )
{% else %} {% else %}
if is_torch_available(): try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
...@@ -159,7 +217,12 @@ if TYPE_CHECKING: ...@@ -159,7 +217,12 @@ if TYPE_CHECKING:
{% endif %} {% endif %}
{%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "TensorFlow" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
...@@ -173,7 +236,12 @@ if TYPE_CHECKING: ...@@ -173,7 +236,12 @@ if TYPE_CHECKING:
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
) )
{% else %} {% else %}
if is_tf_available(): try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
TF{{cookiecutter.camelcase_modelname}}Model, TF{{cookiecutter.camelcase_modelname}}Model,
...@@ -183,7 +251,12 @@ if TYPE_CHECKING: ...@@ -183,7 +251,12 @@ if TYPE_CHECKING:
{% endif %} {% endif %}
{%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %} {%- if "Flax" in cookiecutter.generate_tensorflow_pytorch_and_flax %}
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from .modeling_{{cookiecutter.lowercase_modelname}} import (
Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM, Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM,
Flax{{cookiecutter.camelcase_modelname}}ForCausalLM, Flax{{cookiecutter.camelcase_modelname}}ForCausalLM,
...@@ -196,7 +269,12 @@ if TYPE_CHECKING: ...@@ -196,7 +269,12 @@ if TYPE_CHECKING:
Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel, Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel,
) )
{% else %} {% else %}
if is_flax_available(): try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from .modeling_{{cookiecutter.lowercase_modelname}} import (
Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, Flax{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
......
...@@ -115,7 +115,7 @@ ...@@ -115,7 +115,7 @@
{% endif -%} {% endif -%}
# End. # End.
# Below: " # Fast tokenizers" # Below: " # Fast tokenizers structure"
# Replace with: # Replace with:
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast") _import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast")
# End. # End.
...@@ -126,7 +126,7 @@ ...@@ -126,7 +126,7 @@
# End. # End.
# To replace in: "src/transformers/__init__.py" # To replace in: "src/transformers/__init__.py"
# Below: " if is_torch_available():" if generating PyTorch # Below: " # PyTorch model imports" if generating PyTorch
# Replace with: # Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
...@@ -155,7 +155,7 @@ ...@@ -155,7 +155,7 @@
{% endif -%} {% endif -%}
# End. # End.
# Below: " if is_tf_available():" if generating TensorFlow # Below: " # TensorFlow model imports" if generating TensorFlow
# Replace with: # Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
...@@ -179,7 +179,7 @@ ...@@ -179,7 +179,7 @@
{% endif -%} {% endif -%}
# End. # End.
# Below: " if is_flax_available():" if generating Flax # Below: " # Flax model imports" if generating Flax
# Replace with: # Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
...@@ -204,7 +204,7 @@ ...@@ -204,7 +204,7 @@
{% endif -%} {% endif -%}
# End. # End.
# Below: " if is_tokenizers_available():" # Below: " # Fast tokenizers imports"
# Replace with: # Replace with:
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
# End. # End.
......
...@@ -27,8 +27,8 @@ PATH_TO_TRANSFORMERS = "src/transformers" ...@@ -27,8 +27,8 @@ PATH_TO_TRANSFORMERS = "src/transformers"
_re_backend = re.compile(r"is\_([a-z_]*)_available()") _re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Catches a line with a key-values pattern: "bla": ["foo", "bar"] # Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]') _re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available # Catches a line if not is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)") _re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_struct["bla"].append("foo") # Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)') _re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"] # Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
...@@ -39,6 +39,10 @@ _re_quote_object = re.compile('^\s+"([^"]+)",') ...@@ -39,6 +39,10 @@ _re_quote_object = re.compile('^\s+"([^"]+)",')
_re_between_brackets = re.compile("^\s+\[([^\]]+)\]") _re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo # Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") _re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Catches a line with try:
_re_try = re.compile(r"^\s*try:")
# Catches a line with else:
_re_else = re.compile(r"^\s*else:")
def find_backend(line): def find_backend(line):
...@@ -81,11 +85,21 @@ def parse_init(init_file): ...@@ -81,11 +85,21 @@ def parse_init(init_file):
import_dict_objects = {"none": objects} import_dict_objects = {"none": objects}
# Let's continue with backend-specific objects in _import_structure # Let's continue with backend-specific objects in _import_structure
while not lines[line_index].startswith("if TYPE_CHECKING"): while not lines[line_index].startswith("if TYPE_CHECKING"):
# If the line is an if is_backend_available, we grab all objects associated. # If the line is an if not is_backend_available, we grab all objects associated.
backend = find_backend(lines[line_index]) backend = find_backend(lines[line_index])
# Check if the backend declaration is inside a try block:
if _re_try.search(lines[line_index - 1]) is None:
backend = None
if backend is not None: if backend is not None:
line_index += 1 line_index += 1
# Scroll until we hit the else block of try-except-else
while _re_else.search(lines[line_index]) is None:
line_index += 1
line_index += 1
objects = [] objects = []
# Until we unindent, add backend objects to the list # Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4): while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
...@@ -132,9 +146,19 @@ def parse_init(init_file): ...@@ -132,9 +146,19 @@ def parse_init(init_file):
while line_index < len(lines): while line_index < len(lines):
# If the line is an if is_backemd_available, we grab all objects associated. # If the line is an if is_backemd_available, we grab all objects associated.
backend = find_backend(lines[line_index]) backend = find_backend(lines[line_index])
# Check if the backend declaration is inside a try block:
if _re_try.search(lines[line_index - 1]) is None:
backend = None
if backend is not None: if backend is not None:
line_index += 1 line_index += 1
# Scroll until we hit the else block of try-except-else
while _re_else.search(lines[line_index]) is None:
line_index += 1
line_index += 1
objects = [] objects = []
# Until we unindent, add backend objects to the list # Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8): while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment