"docs/vscode:/vscode.git/clone" did not exist on "2d02f7b29b8c18a68260378246f9e26621021c30"
Unverified Commit 1bdf4240 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fast imports part 3 (#9474)

* New intermediate inits

* Update template

* Avoid importing torch/tf/flax in tokenization unless necessary

* Styling

* Shutup flake8

* Better python version check
parent 79bbcc52
...@@ -16,15 +16,62 @@ ...@@ -16,15 +16,62 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...file_utils import is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available from typing import TYPE_CHECKING
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .tokenization_roberta import RobertaTokenizer
from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig"],
"tokenization_roberta": ["RobertaTokenizer"],
}
if is_tokenizers_available(): if is_tokenizers_available():
from .tokenization_roberta_fast import RobertaTokenizerFast _import_structure["tokenization_roberta_fast"] = ["RobertaTokenizerFast"]
if is_torch_available(): if is_torch_available():
_import_structure["modeling_roberta"] = [
"ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"RobertaForCausalLM",
"RobertaForMaskedLM",
"RobertaForMultipleChoice",
"RobertaForQuestionAnswering",
"RobertaForSequenceClassification",
"RobertaForTokenClassification",
"RobertaModel",
]
if is_tf_available():
_import_structure["modeling_tf_roberta"] = [
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRobertaForMaskedLM",
"TFRobertaForMultipleChoice",
"TFRobertaForQuestionAnswering",
"TFRobertaForSequenceClassification",
"TFRobertaForTokenClassification",
"TFRobertaMainLayer",
"TFRobertaModel",
"TFRobertaPreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_roberta"] = ["FlaxRobertaModel"]
if TYPE_CHECKING:
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .tokenization_roberta import RobertaTokenizer
if is_tokenizers_available():
from .tokenization_roberta_fast import RobertaTokenizerFast
if is_torch_available():
from .modeling_roberta import ( from .modeling_roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
RobertaForCausalLM, RobertaForCausalLM,
...@@ -36,7 +83,7 @@ if is_torch_available(): ...@@ -36,7 +83,7 @@ if is_torch_available():
RobertaModel, RobertaModel,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_roberta import ( from .modeling_tf_roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM, TFRobertaForMaskedLM,
...@@ -49,5 +96,23 @@ if is_tf_available(): ...@@ -49,5 +96,23 @@ if is_tf_available():
TFRobertaPreTrainedModel, TFRobertaPreTrainedModel,
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_roberta import FlaxRobertaModel from .modeling_flax_roberta import FlaxRobertaModel
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
...@@ -16,15 +16,41 @@ ...@@ -16,15 +16,41 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...file_utils import is_tokenizers_available, is_torch_available from typing import TYPE_CHECKING
from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from .tokenization_squeezebert import SqueezeBertTokenizer
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
"configuration_squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig"],
"tokenization_squeezebert": ["SqueezeBertTokenizer"],
}
if is_tokenizers_available(): if is_tokenizers_available():
from .tokenization_squeezebert_fast import SqueezeBertTokenizerFast _import_structure["tokenization_squeezebert_fast"] = ["SqueezeBertTokenizerFast"]
if is_torch_available(): if is_torch_available():
_import_structure["modeling_squeezebert"] = [
"SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"SqueezeBertForMaskedLM",
"SqueezeBertForMultipleChoice",
"SqueezeBertForQuestionAnswering",
"SqueezeBertForSequenceClassification",
"SqueezeBertForTokenClassification",
"SqueezeBertModel",
"SqueezeBertModule",
"SqueezeBertPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from .tokenization_squeezebert import SqueezeBertTokenizer
if is_tokenizers_available():
from .tokenization_squeezebert_fast import SqueezeBertTokenizerFast
if is_torch_available():
from .modeling_squeezebert import ( from .modeling_squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
SqueezeBertForMaskedLM, SqueezeBertForMaskedLM,
...@@ -36,3 +62,21 @@ if is_torch_available(): ...@@ -36,3 +62,21 @@ if is_torch_available():
SqueezeBertModule, SqueezeBertModule,
SqueezeBertPreTrainedModel, SqueezeBertPreTrainedModel,
) )
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
...@@ -16,17 +16,57 @@ ...@@ -16,17 +16,57 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available from typing import TYPE_CHECKING
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from ...file_utils import (
_BaseLazyModule,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
}
if is_sentencepiece_available(): if is_sentencepiece_available():
from .tokenization_t5 import T5Tokenizer _import_structure["tokenization_t5"] = ["T5Tokenizer"]
if is_tokenizers_available(): if is_tokenizers_available():
from .tokenization_t5_fast import T5TokenizerFast _import_structure["tokenization_t5_fast"] = ["T5TokenizerFast"]
if is_torch_available(): if is_torch_available():
_import_structure["modeling_t5"] = [
"T5_PRETRAINED_MODEL_ARCHIVE_LIST",
"T5EncoderModel",
"T5ForConditionalGeneration",
"T5Model",
"T5PreTrainedModel",
"load_tf_weights_in_t5",
]
if is_tf_available():
_import_structure["modeling_tf_t5"] = [
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFT5EncoderModel",
"TFT5ForConditionalGeneration",
"TFT5Model",
"TFT5PreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
if is_sentencepiece_available():
from .tokenization_t5 import T5Tokenizer
if is_tokenizers_available():
from .tokenization_t5_fast import T5TokenizerFast
if is_torch_available():
from .modeling_t5 import ( from .modeling_t5 import (
T5_PRETRAINED_MODEL_ARCHIVE_LIST, T5_PRETRAINED_MODEL_ARCHIVE_LIST,
T5EncoderModel, T5EncoderModel,
...@@ -36,7 +76,7 @@ if is_torch_available(): ...@@ -36,7 +76,7 @@ if is_torch_available():
load_tf_weights_in_t5, load_tf_weights_in_t5,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_t5 import ( from .modeling_tf_t5 import (
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST, TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
TFT5EncoderModel, TFT5EncoderModel,
...@@ -44,3 +84,21 @@ if is_tf_available(): ...@@ -44,3 +84,21 @@ if is_tf_available():
TFT5Model, TFT5Model,
TFT5PreTrainedModel, TFT5PreTrainedModel,
) )
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
...@@ -16,12 +16,31 @@ ...@@ -16,12 +16,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...file_utils import is_torch_available from typing import TYPE_CHECKING
from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from .tokenization_tapas import TapasTokenizer
from ...file_utils import _BaseLazyModule, is_torch_available
_import_structure = {
"configuration_tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig"],
"tokenization_tapas": ["TapasTokenizer"],
}
if is_torch_available(): if is_torch_available():
_import_structure["modeling_tapas"] = [
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
"TapasForMaskedLM",
"TapasForQuestionAnswering",
"TapasForSequenceClassification",
"TapasModel",
]
if TYPE_CHECKING:
from .configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
from .tokenization_tapas import TapasTokenizer
if is_torch_available():
from .modeling_tapas import ( from .modeling_tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST, TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM, TapasForMaskedLM,
...@@ -29,3 +48,21 @@ if is_torch_available(): ...@@ -29,3 +48,21 @@ if is_torch_available():
TapasForSequenceClassification, TapasForSequenceClassification,
TapasModel, TapasModel,
) )
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
...@@ -16,12 +16,44 @@ ...@@ -16,12 +16,44 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...file_utils import is_tf_available, is_torch_available from typing import TYPE_CHECKING
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
_import_structure = {
"configuration_transfo_xl": ["TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP", "TransfoXLConfig"],
"tokenization_transfo_xl": ["TransfoXLCorpus", "TransfoXLTokenizer"],
}
if is_torch_available(): if is_torch_available():
_import_structure["modeling_transfo_xl"] = [
"TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"AdaptiveEmbedding",
"TransfoXLForSequenceClassification",
"TransfoXLLMHeadModel",
"TransfoXLModel",
"TransfoXLPreTrainedModel",
"load_tf_weights_in_transfo_xl",
]
if is_tf_available():
_import_structure["modeling_tf_transfo_xl"] = [
"TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFAdaptiveEmbedding",
"TFTransfoXLForSequenceClassification",
"TFTransfoXLLMHeadModel",
"TFTransfoXLMainLayer",
"TFTransfoXLModel",
"TFTransfoXLPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer
if is_torch_available():
from .modeling_transfo_xl import ( from .modeling_transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
AdaptiveEmbedding, AdaptiveEmbedding,
...@@ -32,7 +64,7 @@ if is_torch_available(): ...@@ -32,7 +64,7 @@ if is_torch_available():
load_tf_weights_in_transfo_xl, load_tf_weights_in_transfo_xl,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_transfo_xl import ( from .modeling_tf_transfo_xl import (
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFAdaptiveEmbedding, TFAdaptiveEmbedding,
...@@ -42,3 +74,21 @@ if is_tf_available(): ...@@ -42,3 +74,21 @@ if is_tf_available():
TFTransfoXLModel, TFTransfoXLModel,
TFTransfoXLPreTrainedModel, TFTransfoXLPreTrainedModel,
) )
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
...@@ -16,12 +16,48 @@ ...@@ -16,12 +16,48 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...file_utils import is_tf_available, is_torch_available from typing import TYPE_CHECKING
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .tokenization_xlm import XLMTokenizer
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
_import_structure = {
"configuration_xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig"],
"tokenization_xlm": ["XLMTokenizer"],
}
if is_torch_available(): if is_torch_available():
_import_structure["modeling_xlm"] = [
"XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMForMultipleChoice",
"XLMForQuestionAnswering",
"XLMForQuestionAnsweringSimple",
"XLMForSequenceClassification",
"XLMForTokenClassification",
"XLMModel",
"XLMPreTrainedModel",
"XLMWithLMHeadModel",
]
if is_tf_available():
_import_structure["modeling_tf_xlm"] = [
"TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLMForMultipleChoice",
"TFXLMForQuestionAnsweringSimple",
"TFXLMForSequenceClassification",
"TFXLMForTokenClassification",
"TFXLMMainLayer",
"TFXLMModel",
"TFXLMPreTrainedModel",
"TFXLMWithLMHeadModel",
]
if TYPE_CHECKING:
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .tokenization_xlm import XLMTokenizer
if is_torch_available():
from .modeling_xlm import ( from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_LIST, XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMForMultipleChoice, XLMForMultipleChoice,
...@@ -34,7 +70,7 @@ if is_torch_available(): ...@@ -34,7 +70,7 @@ if is_torch_available():
XLMWithLMHeadModel, XLMWithLMHeadModel,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_xlm import ( from .modeling_tf_xlm import (
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST, TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMForMultipleChoice, TFXLMForMultipleChoice,
...@@ -46,3 +82,21 @@ if is_tf_available(): ...@@ -46,3 +82,21 @@ if is_tf_available():
TFXLMPreTrainedModel, TFXLMPreTrainedModel,
TFXLMWithLMHeadModel, TFXLMWithLMHeadModel,
) )
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
...@@ -16,17 +16,61 @@ ...@@ -16,17 +16,61 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available from typing import TYPE_CHECKING
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
from ...file_utils import (
_BaseLazyModule,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
}
if is_sentencepiece_available(): if is_sentencepiece_available():
from .tokenization_xlm_roberta import XLMRobertaTokenizer _import_structure["tokenization_xlm_roberta"] = ["XLMRobertaTokenizer"]
if is_tokenizers_available(): if is_tokenizers_available():
from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast _import_structure["tokenization_xlm_roberta_fast"] = ["XLMRobertaTokenizerFast"]
if is_torch_available(): if is_torch_available():
_import_structure["modeling_xlm_roberta"] = [
"XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLMRobertaForCausalLM",
"XLMRobertaForMaskedLM",
"XLMRobertaForMultipleChoice",
"XLMRobertaForQuestionAnswering",
"XLMRobertaForSequenceClassification",
"XLMRobertaForTokenClassification",
"XLMRobertaModel",
]
if is_tf_available():
_import_structure["modeling_tf_xlm_roberta"] = [
"TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLMRobertaForMaskedLM",
"TFXLMRobertaForMultipleChoice",
"TFXLMRobertaForQuestionAnswering",
"TFXLMRobertaForSequenceClassification",
"TFXLMRobertaForTokenClassification",
"TFXLMRobertaModel",
]
if TYPE_CHECKING:
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
if is_sentencepiece_available():
from .tokenization_xlm_roberta import XLMRobertaTokenizer
if is_tokenizers_available():
from .tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
if is_torch_available():
from .modeling_xlm_roberta import ( from .modeling_xlm_roberta import (
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
XLMRobertaForCausalLM, XLMRobertaForCausalLM,
...@@ -38,7 +82,7 @@ if is_torch_available(): ...@@ -38,7 +82,7 @@ if is_torch_available():
XLMRobertaModel, XLMRobertaModel,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_xlm_roberta import ( from .modeling_tf_xlm_roberta import (
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLMRobertaForMaskedLM, TFXLMRobertaForMaskedLM,
...@@ -48,3 +92,21 @@ if is_tf_available(): ...@@ -48,3 +92,21 @@ if is_tf_available():
TFXLMRobertaForTokenClassification, TFXLMRobertaForTokenClassification,
TFXLMRobertaModel, TFXLMRobertaModel,
) )
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
...@@ -16,17 +16,65 @@ ...@@ -16,17 +16,65 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ...file_utils import is_sentencepiece_available, is_tf_available, is_tokenizers_available, is_torch_available from typing import TYPE_CHECKING
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
from ...file_utils import (
_BaseLazyModule,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
"configuration_xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"],
}
if is_sentencepiece_available(): if is_sentencepiece_available():
from .tokenization_xlnet import XLNetTokenizer _import_structure["tokenization_xlnet"] = ["XLNetTokenizer"]
if is_tokenizers_available(): if is_tokenizers_available():
from .tokenization_xlnet_fast import XLNetTokenizerFast _import_structure["tokenization_xlnet_fast"] = ["XLNetTokenizerFast"]
if is_torch_available(): if is_torch_available():
_import_structure["modeling_xlnet"] = [
"XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"XLNetForMultipleChoice",
"XLNetForQuestionAnswering",
"XLNetForQuestionAnsweringSimple",
"XLNetForSequenceClassification",
"XLNetForTokenClassification",
"XLNetLMHeadModel",
"XLNetModel",
"XLNetPreTrainedModel",
"load_tf_weights_in_xlnet",
]
if is_tf_available():
_import_structure["modeling_tf_xlnet"] = [
"TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFXLNetForMultipleChoice",
"TFXLNetForQuestionAnsweringSimple",
"TFXLNetForSequenceClassification",
"TFXLNetForTokenClassification",
"TFXLNetLMHeadModel",
"TFXLNetMainLayer",
"TFXLNetModel",
"TFXLNetPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig
if is_sentencepiece_available():
from .tokenization_xlnet import XLNetTokenizer
if is_tokenizers_available():
from .tokenization_xlnet_fast import XLNetTokenizerFast
if is_torch_available():
from .modeling_xlnet import ( from .modeling_xlnet import (
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
XLNetForMultipleChoice, XLNetForMultipleChoice,
...@@ -40,7 +88,7 @@ if is_torch_available(): ...@@ -40,7 +88,7 @@ if is_torch_available():
load_tf_weights_in_xlnet, load_tf_weights_in_xlnet,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_xlnet import ( from .modeling_tf_xlnet import (
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFXLNetForMultipleChoice, TFXLNetForMultipleChoice,
...@@ -52,3 +100,21 @@ if is_tf_available(): ...@@ -52,3 +100,21 @@ if is_tf_available():
TFXLNetModel, TFXLNetModel,
TFXLNetPreTrainedModel, TFXLNetPreTrainedModel,
) )
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
...@@ -25,7 +25,7 @@ import warnings ...@@ -25,7 +25,7 @@ import warnings
from collections import OrderedDict, UserDict from collections import OrderedDict, UserDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
...@@ -45,21 +45,34 @@ from .file_utils import ( ...@@ -45,21 +45,34 @@ from .file_utils import (
from .utils import logging from .utils import logging
if is_tf_available(): if TYPE_CHECKING:
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp # noqa: F401
def _is_numpy(x):
return isinstance(x, np.ndarray)
if is_torch_available(): def _is_torch(x):
import torch import torch
if is_flax_available(): return isinstance(x, torch.Tensor)
import jax.numpy as jnp
def _is_numpy(x): def _is_tensorflow(x):
return isinstance(x, np.ndarray) import tensorflow as tf
return isinstance(x, tf.Tensor)
def _is_jax(x): def _is_jax(x):
import jax.numpy as jnp # noqa: F811
return isinstance(x, jnp.ndarray) return isinstance(x, jnp.ndarray)
...@@ -196,9 +209,9 @@ def to_py_obj(obj): ...@@ -196,9 +209,9 @@ def to_py_obj(obj):
return {k: to_py_obj(v) for k, v in obj.items()} return {k: to_py_obj(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)): elif isinstance(obj, (list, tuple)):
return [to_py_obj(o) for o in obj] return [to_py_obj(o) for o in obj]
elif is_tf_available() and isinstance(obj, tf.Tensor): elif is_tf_available() and _is_tensorflow(obj):
return obj.numpy().tolist() return obj.numpy().tolist()
elif is_torch_available() and isinstance(obj, torch.Tensor): elif is_torch_available() and _is_torch(obj):
return obj.detach().cpu().tolist() return obj.detach().cpu().tolist()
elif isinstance(obj, np.ndarray): elif isinstance(obj, np.ndarray):
return obj.tolist() return obj.tolist()
...@@ -714,16 +727,22 @@ class BatchEncoding(UserDict): ...@@ -714,16 +727,22 @@ class BatchEncoding(UserDict):
raise ImportError( raise ImportError(
"Unable to convert output to TensorFlow tensors format, TensorFlow is not installed." "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
) )
import tensorflow as tf
as_tensor = tf.constant as_tensor = tf.constant
is_tensor = tf.is_tensor is_tensor = tf.is_tensor
elif tensor_type == TensorType.PYTORCH: elif tensor_type == TensorType.PYTORCH:
if not is_torch_available(): if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.") raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
import torch
as_tensor = torch.tensor as_tensor = torch.tensor
is_tensor = torch.is_tensor is_tensor = torch.is_tensor
elif tensor_type == TensorType.JAX: elif tensor_type == TensorType.JAX:
if not is_flax_available(): if not is_flax_available():
raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.") raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
import jax.numpy as jnp # noqa: F811
as_tensor = jnp.array as_tensor = jnp.array
is_tensor = _is_jax is_tensor = _is_jax
else: else:
...@@ -2684,9 +2703,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -2684,9 +2703,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
first_element = encoded_inputs["input_ids"][index][0] first_element = encoded_inputs["input_ids"][index][0]
# At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do. # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
if not isinstance(first_element, (int, list, tuple)): if not isinstance(first_element, (int, list, tuple)):
if is_tf_available() and isinstance(first_element, tf.Tensor): if is_tf_available() and _is_tensorflow(first_element):
return_tensors = "tf" if return_tensors is None else return_tensors return_tensors = "tf" if return_tensors is None else return_tensors
elif is_torch_available() and isinstance(first_element, torch.Tensor): elif is_torch_available() and _is_torch(first_element):
return_tensors = "pt" if return_tensors is None else return_tensors return_tensors = "pt" if return_tensors is None else return_tensors
elif isinstance(first_element, np.ndarray): elif isinstance(first_element, np.ndarray):
return_tensors = "np" if return_tensors is None else return_tensors return_tensors = "np" if return_tensors is None else return_tensors
......
...@@ -15,23 +15,87 @@ ...@@ -15,23 +15,87 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
{%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %} {%- if cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" %}
from ...file_utils import is_tf_available, is_torch_available, is_tokenizers_available from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available, is_tokenizers_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %} {%- elif cookiecutter.generate_tensorflow_and_pytorch == "PyTorch" %}
from ...file_utils import is_torch_available, is_tokenizers_available from ...file_utils import _BaseLazyModule, is_torch_available, is_tokenizers_available
{%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %} {%- elif cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow" %}
from ...file_utils import is_tf_available, is_tokenizers_available from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available
{% endif %} {% endif %}
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config _import_structure = {
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer "configuration_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP", "{{cookiecutter.camelcase_modelname}}Config"],
"tokenization_{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.camelcase_modelname}}Tokenizer"],
}
if is_tokenizers_available(): if is_tokenizers_available():
from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast _import_structure["tokenization_{{cookiecutter.lowercase_modelname}}_fast"] = ["{{cookiecutter.camelcase_modelname}}TokenizerFast"]
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %} {%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %}
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_torch_available(): if is_torch_available():
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}ForTokenClassification",
"{{cookiecutter.camelcase_modelname}}Layer",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
]
{% else %}
if is_torch_available():
_import_structure["modeling_{{cookiecutter.lowercase_modelname}}"] = [
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% endif %}
{% endif %}
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_tf_available():
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"TF{{cookiecutter.camelcase_modelname}}ForCausalLM",
"TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
"TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"TF{{cookiecutter.camelcase_modelname}}ForTokenClassification",
"TF{{cookiecutter.camelcase_modelname}}Layer",
"TF{{cookiecutter.camelcase_modelname}}Model",
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% else %}
if is_tf_available():
_import_structure["modeling_tf_{{cookiecutter.lowercase_modelname}}"] = [
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"TF{{cookiecutter.camelcase_modelname}}Model",
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
{% endif %}
{% endif %}
if TYPE_CHECKING:
from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config
from .tokenization_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Tokenizer
if is_tokenizers_available():
from .tokenization_{{cookiecutter.lowercase_modelname}}_fast import {{cookiecutter.camelcase_modelname}}TokenizerFast
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "PyTorch") %}
{% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_torch_available():
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM, {{cookiecutter.camelcase_modelname}}ForMaskedLM,
...@@ -46,7 +110,7 @@ if is_torch_available(): ...@@ -46,7 +110,7 @@ if is_torch_available():
load_tf_weights_in_{{cookiecutter.lowercase_modelname}}, load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
) )
{% else %} {% else %}
if is_torch_available(): if is_torch_available():
from .modeling_{{cookiecutter.lowercase_modelname}} import ( from .modeling_{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
...@@ -59,7 +123,7 @@ if is_torch_available(): ...@@ -59,7 +123,7 @@ if is_torch_available():
{% endif %} {% endif %}
{%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %} {%- if (cookiecutter.generate_tensorflow_and_pytorch == "PyTorch & TensorFlow" or cookiecutter.generate_tensorflow_and_pytorch == "TensorFlow") %}
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
if is_tf_available(): if is_tf_available():
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
...@@ -73,7 +137,7 @@ if is_tf_available(): ...@@ -73,7 +137,7 @@ if is_tf_available():
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
) )
{% else %} {% else %}
if is_tf_available(): if is_tf_available():
from .modeling_tf_{{cookiecutter.lowercase_modelname}} import ( from .modeling_tf_{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
TF{{cookiecutter.camelcase_modelname}}Model, TF{{cookiecutter.camelcase_modelname}}Model,
...@@ -81,3 +145,20 @@ if is_tf_available(): ...@@ -81,3 +145,20 @@ if is_tf_available():
) )
{% endif %} {% endif %}
{% endif %} {% endif %}
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment