Unverified Commit 75e63dbf authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix MT5 init (#12591)

parent 4da568c1
...@@ -29,24 +29,22 @@ from ...file_utils import ( ...@@ -29,24 +29,22 @@ from ...file_utils import (
if is_sentencepiece_available(): if is_sentencepiece_available():
from ..t5.tokenization_t5 import T5Tokenizer from ..t5.tokenization_t5 import T5Tokenizer
else:
from ...utils.dummy_sentencepiece_objects import T5Tokenizer
MT5Tokenizer = T5Tokenizer MT5Tokenizer = T5Tokenizer
if is_tokenizers_available(): if is_tokenizers_available():
from ..t5.tokenization_t5_fast import T5TokenizerFast from ..t5.tokenization_t5_fast import T5TokenizerFast
else:
from ...utils.dummy_tokenizers_objects import T5TokenizerFast
MT5TokenizerFast = T5TokenizerFast MT5TokenizerFast = T5TokenizerFast
_import_structure = { _import_structure = {
"configuration_mt5": ["MT5Config"], "configuration_mt5": ["MT5Config"],
} }
if is_sentencepiece_available():
_import_structure["."] = ["T5Tokenizer"] # Fake to get the same objects in both side.
if is_tokenizers_available():
_import_structure["."] = ["T5TokenizerFast"] # Fake to get the same objects in both side.
if is_torch_available(): if is_torch_available():
_import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"] _import_structure["modeling_mt5"] = ["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]
...@@ -57,16 +55,6 @@ if is_tf_available(): ...@@ -57,16 +55,6 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_mt5 import MT5Config from .configuration_mt5 import MT5Config
if is_sentencepiece_available():
from ..t5.tokenization_t5 import T5Tokenizer
MT5Tokenizer = T5Tokenizer
if is_tokenizers_available():
from ..t5.tokenization_t5_fast import T5TokenizerFast
MT5TokenizerFast = T5TokenizerFast
if is_torch_available(): if is_torch_available():
from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model from .modeling_mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model
...@@ -76,20 +64,7 @@ if TYPE_CHECKING: ...@@ -76,20 +64,7 @@ if TYPE_CHECKING:
else: else:
import sys import sys
class _MT5LazyModule(_LazyModule): sys.modules[__name__] = _LazyModule(
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
def __getattr__(self, name):
if name == "MT5Tokenizer":
return MT5Tokenizer
elif name == "MT5TokenizerFast":
return MT5TokenizerFast
else:
return super().__getattr__(name)
sys.modules[__name__] = _MT5LazyModule(
__name__, __name__,
globals()["__file__"], globals()["__file__"],
_import_structure, _import_structure,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment