"tests/vscode:/vscode.git/clone" did not exist on "89b00eef9428111cd7c48426905fe013d7047a3b"
Unverified Commit 758ed333 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Transformers fast import part 2 (#9446)



* Main init work

* Add version

* Change from absolute to relative imports

* Fix imports

* One more typo

* More typos

* Styling

* Make quality script pass

* Add necessary replace in template

* Fix typos

* Spaces are ignored in replace for some reason

* Forgot one model.

* Fixes for import
Co-authored-by: default avatarLysandreJik <lysandre.debut@reseau.eseo.fr>

* Add documentation

* Styling
Co-authored-by: default avatarLysandreJik <lysandre.debut@reseau.eseo.fr>
parent a400fe89
This diff is collapsed.
......@@ -30,9 +30,8 @@ from multiprocessing import Pipe, Process, Queue
from multiprocessing.connection import Connection
from typing import Callable, Iterable, List, NamedTuple, Optional, Union
from transformers import AutoConfig, PretrainedConfig
from transformers import __version__ as version
from .. import AutoConfig, PretrainedConfig
from .. import __version__ as version
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments
......
......@@ -19,9 +19,8 @@ from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
from . import BaseTransformersCLICommand
try:
......
......@@ -14,9 +14,8 @@
from argparse import ArgumentParser, Namespace
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
from . import BaseTransformersCLICommand
def convert_command_factory(args: Namespace):
......@@ -87,7 +86,7 @@ class ConvertCommand(BaseTransformersCLICommand):
def run(self):
if self._model_type == "albert":
try:
from transformers.models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
......@@ -96,7 +95,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "bert":
try:
from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
......@@ -105,7 +104,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "funnel":
try:
from transformers.models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
......@@ -113,14 +112,14 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "gpt":
from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
convert_openai_checkpoint_to_pytorch,
)
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "transfo_xl":
try:
from transformers.models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
convert_transfo_xl_checkpoint_to_pytorch,
)
except ImportError:
......@@ -137,7 +136,7 @@ class ConvertCommand(BaseTransformersCLICommand):
)
elif self._model_type == "gpt2":
try:
from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
convert_gpt2_checkpoint_to_pytorch,
)
except ImportError:
......@@ -146,7 +145,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "xlnet":
try:
from transformers.models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
convert_xlnet_checkpoint_to_pytorch,
)
except ImportError:
......@@ -156,13 +155,13 @@ class ConvertCommand(BaseTransformersCLICommand):
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
)
elif self._model_type == "xlm":
from transformers.models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
convert_xlm_checkpoint_to_pytorch,
)
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
elif self._model_type == "lxmert":
from transformers.models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
from ..models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
convert_lxmert_checkpoint_to_pytorch,
)
......
......@@ -14,7 +14,7 @@
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from . import BaseTransformersCLICommand
def download_command_factory(args):
......@@ -40,7 +40,7 @@ class DownloadCommand(BaseTransformersCLICommand):
self._force = force
def run(self):
from transformers import AutoModel, AutoTokenizer
from ..models.auto import AutoModel, AutoTokenizer
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
......@@ -15,9 +15,9 @@
import platform
from argparse import ArgumentParser
from transformers import __version__ as version
from transformers import is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand
from .. import __version__ as version
from ..file_utils import is_tf_available, is_torch_available
from . import BaseTransformersCLICommand
def info_command_factory(_):
......
......@@ -25,9 +25,9 @@ from contextlib import AbstractContextManager
from typing import Dict, List, Optional
import requests
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
from . import BaseTransformersCLICommand
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -14,10 +14,9 @@
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
......@@ -15,11 +15,9 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
try:
......
......@@ -15,11 +15,11 @@
import os
from argparse import ArgumentParser, Namespace
from transformers import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand
from ..data import SingleSentenceClassificationProcessor as Processor
from ..file_utils import is_tf_available, is_torch_available
from ..pipelines import TextClassificationPipeline
from ..utils import logging
from . import BaseTransformersCLICommand
if not is_tf_available() and not is_torch_available():
......
......@@ -15,14 +15,14 @@
from argparse import ArgumentParser
from transformers.commands.add_new_model import AddNewModelCommand
from transformers.commands.convert import ConvertCommand
from transformers.commands.download import DownloadCommand
from transformers.commands.env import EnvironmentCommand
from transformers.commands.lfs import LfsCommands
from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands
from .add_new_model import AddNewModelCommand
from .convert import ConvertCommand
from .download import DownloadCommand
from .env import EnvironmentCommand
from .lfs import LfsCommands
from .run import RunCommand
from .serving import ServeCommand
from .user import UserCommands
def main():
......
......@@ -20,8 +20,9 @@ from getpass import getpass
from typing import List, Union
from requests.exceptions import HTTPError
from transformers.commands import BaseTransformersCLICommand
from transformers.hf_api import HfApi, HfFolder
from ..hf_api import HfApi, HfFolder
from . import BaseTransformersCLICommand
UPLOAD_MAX_FILES = 15
......
......@@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple
from packaging.version import Version, parse
from transformers import is_tf_available, is_torch_available
from transformers.file_utils import ModelOutput
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
from .file_utils import ModelOutput, is_tf_available, is_torch_available
from .pipelines import Pipeline, pipeline
from .tokenization_utils import BatchEncoding
# This is the minimal required version to
......
......@@ -18,7 +18,7 @@
import argparse
import os
from transformers import (
from . import (
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -87,15 +87,15 @@ from transformers import (
is_torch_available,
load_pytorch_checkpoint_in_tf2_model,
)
from transformers.file_utils import hf_bucket_url
from transformers.utils import logging
from .file_utils import hf_bucket_url
from .utils import logging
if is_torch_available():
import numpy as np
import torch
from transformers import (
from . import (
AlbertForPreTraining,
BartForConditionalGeneration,
BertForPreTraining,
......
......@@ -18,8 +18,9 @@ import argparse
import os
import transformers
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from transformers.utils import logging
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from .utils import logging
logging.set_verbosity_info()
......
......@@ -17,7 +17,7 @@
import argparse
from transformers import (
from . import (
BertConfig,
BertGenerationConfig,
BertGenerationDecoder,
......
......@@ -27,8 +27,7 @@ import math
import re
import string
from transformers import BasicTokenizer
from ...models.bert import BasicTokenizer
from ...utils import logging
......
......@@ -17,15 +17,14 @@ import unittest
import timeout_decorator
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
from ..file_utils import cached_property, is_torch_available
from ..testing_utils import require_torch
if is_torch_available():
import torch
from transformers import MarianConfig, MarianMTModel
from ..models.marian import MarianConfig, MarianMTModel
@require_torch
......
......@@ -33,6 +33,7 @@ from dataclasses import fields
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from types import ModuleType
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
......@@ -41,7 +42,6 @@ import numpy as np
from packaging import version
from tqdm.auto import tqdm
import importlib_metadata
import requests
from filelock import FileLock
......@@ -50,6 +50,13 @@ from .hf_api import HfFolder
from .utils import logging
# The package importlib_metadata is in a different place, depending on the python version.
if version.parse(sys.version) < version.parse("3.8"):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
......@@ -130,7 +137,7 @@ except importlib_metadata.PackageNotFoundError:
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
_scatter_version = importlib_metadata.version("torch_scatterr")
_scatter_version = importlib_metadata.version("torch_scatter")
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
except importlib_metadata.PackageNotFoundError:
_scatter_available = False
......@@ -1415,3 +1422,40 @@ class ModelOutput(OrderedDict):
Convert self to a tuple containing all the attributes/keys that are not ``None``.
"""
return tuple(self[k] for k in self.keys())
class _BaseLazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, import_structure):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
# Needed for autocompletion in an IDE
def __dir__(self):
return super().__dir__() + self.__all__
def __getattr__(self, name: str) -> Any:
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
def _get_module(self, module_name: str) -> ModuleType:
raise NotImplementedError
......@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
# comet_ml requires to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED"
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
try:
import comet_ml # noqa: F401
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment