Unverified Commit 0a6b9048 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Init pickle (#12567)

* Try to pickle transformers

* Deal with special objs better

* Make picklable
parent b29c3945
......@@ -51,4 +51,4 @@ Special Properties
Other Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.file_utils._BaseLazyModule
.. autoclass:: transformers.file_utils._LazyModule
......@@ -42,7 +42,7 @@ from typing import TYPE_CHECKING
# Check the dependencies satisfy the minimal versions required.
from . import dependency_versions_check
from .file_utils import (
_BaseLazyModule,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_speech_available,
......@@ -3058,28 +3058,11 @@ if TYPE_CHECKING:
from .utils.dummy_flax_objects import *
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
def __getattr__(self, name: str):
# Special handling for the version, which is a constant from this module and not imported in a submodule.
if name == "__version__":
return __version__
return super().__getattr__(name)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(
__name__, globals()["__file__"], _import_structure, extra_objects={"__version__": __version__}
)
if not is_tf_available() and not is_torch_available() and not is_flax_available():
......
......@@ -1865,14 +1865,14 @@ class TensorType(ExplicitEnum):
JAX = "jax"
class _BaseLazyModule(ModuleType):
class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, import_structure):
def __init__(self, name, module_file, import_structure, extra_objects=None):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
......@@ -1881,12 +1881,19 @@ class _BaseLazyModule(ModuleType):
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
self.__file__ = module_file
self.__path__ = [os.path.dirname(module_file)]
self._objects = {} if extra_objects is None else extra_objects
self._name = name
self._import_structure = import_structure
# Needed for autocompletion in an IDE
def __dir__(self):
return super().__dir__() + self.__all__
def __getattr__(self, name: str) -> Any:
if name in self._objects:
return self._objects[name]
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
......@@ -1898,8 +1905,11 @@ class _BaseLazyModule(ModuleType):
setattr(self, name, value)
return value
def _get_module(self, module_name: str) -> ModuleType:
raise NotImplementedError
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
def __reduce__(self):
return (self.__class__, (self._name, self._import_structure))
def copy_func(f):
......
......@@ -19,7 +19,7 @@
from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
_LazyModule,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
......@@ -104,19 +104,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
......@@ -200,19 +200,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -17,13 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
_import_structure = {
......@@ -90,19 +84,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_tokenizers_available
from ...file_utils import _LazyModule, is_sentencepiece_available, is_tokenizers_available
_import_structure = {}
......@@ -39,19 +39,6 @@ if TYPE_CHECKING:
from .tokenization_barthez_fast import BarthezTokenizerFast
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,13 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
_import_structure = {
......@@ -137,19 +131,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available
from ...file_utils import _LazyModule, is_sentencepiece_available, is_torch_available
_import_structure = {
......@@ -52,19 +52,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule
from ...file_utils import _LazyModule
_import_structure = {
......@@ -30,19 +30,6 @@ if TYPE_CHECKING:
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule
from ...file_utils import _LazyModule
_import_structure = {
......@@ -30,19 +30,6 @@ if TYPE_CHECKING:
from .tokenization_bertweet import BertweetTokenizer
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
......@@ -103,19 +103,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_torch_available
from ...file_utils import _LazyModule, is_torch_available
_import_structure = {
......@@ -52,19 +52,6 @@ if TYPE_CHECKING:
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
from ...file_utils import _LazyModule, is_tf_available, is_torch_available
_import_structure = {
......@@ -65,19 +65,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
from ...file_utils import _LazyModule, is_tf_available, is_torch_available
_import_structure = {
......@@ -62,19 +62,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule
from ...file_utils import _LazyModule
_import_structure = {
......@@ -29,19 +29,6 @@ _import_structure = {
if TYPE_CHECKING:
from .tokenization_byt5 import ByT5Tokenizer
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -19,7 +19,7 @@
from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
_LazyModule,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
......@@ -94,19 +94,6 @@ if TYPE_CHECKING:
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
......@@ -58,19 +58,6 @@ if TYPE_CHECKING:
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
_LazyModule,
is_flax_available,
is_tokenizers_available,
is_torch_available,
......@@ -90,19 +90,6 @@ if TYPE_CHECKING:
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
......@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
_import_structure = {
......@@ -93,19 +93,6 @@ if TYPE_CHECKING:
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment