Unverified Commit 6d4f8bd0 authored by Sylvain Gugger, committed by GitHub

Add Flax dummy objects (#7918)

parent 3e31e7f9
@@ -841,6 +841,11 @@ else:
if is_flax_available():
    from .modeling_flax_bert import FlaxBertModel
    from .modeling_flax_roberta import FlaxRobertaModel
else:
    # Import the same objects as dummies to get them in the namespace.
    # They will raise an import error if the user tries to instantiate / use them.
    from .utils.dummy_flax_objects import *

if not is_tf_available() and not is_torch_available():
    logger.warning(
...
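With this change the Flax classes are always importable from the top-level namespace, whether or not Flax is installed. A minimal sketch of the resulting behavior, assuming flax is not installed in the environment:

from transformers import FlaxBertModel  # import succeeds; the dummy class stands in

model = FlaxBertModel()  # raises ImportError pointing at the Flax installation page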
@@ -356,6 +356,12 @@ installation page: https://www.tensorflow.org/install and follow the ones that m
"""

FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Check out the instructions on the
installation page: https://github.com/google/flax and follow the ones that match your environment.
"""
def requires_datasets(obj):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_datasets_available():
@@ -386,6 +392,12 @@ def requires_tf(obj):
        raise ImportError(TENSORFLOW_IMPORT_ERROR.format(name))


def requires_flax(obj):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_flax_available():
        raise ImportError(FLAX_IMPORT_ERROR.format(name))


def requires_tokenizers(obj):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_tokenizers_available():
...
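The guard above relies on is_flax_available() from file_utils, whose implementation is not part of this diff. A minimal sketch of how such an availability check is commonly written (an assumption, not the library's actual code):

import importlib.util

def is_flax_available():
    # True when a `flax` package can be located in the current environment.
    return importlib.util.find_spec("flax") is not None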
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..file_utils import requires_flax


class FlaxBertModel:
    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_flax(cls)


class FlaxRobertaModel:
    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_flax(cls)
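A quick way to see the dummies fail fast (hypothetical snippet, assuming Flax is absent): requires_flax formats the offending name into FLAX_IMPORT_ERROR, so the message names the class the user tried to use.

from transformers import FlaxRobertaModel

try:
    FlaxRobertaModel.from_pretrained("roberta-base")
except ImportError as err:
    assert "FlaxRobertaModel" in str(err)  # the class name is interpolated into the error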
@@ -72,6 +72,28 @@ def {0}(*args, **kwargs):
"""
DUMMY_FLAX_PRETRAINED_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_flax(self)

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_flax(cls)
"""


DUMMY_FLAX_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
        requires_flax(self)
"""


DUMMY_FLAX_FUNCTION = """
def {0}(*args, **kwargs):
    requires_flax({0})
"""
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """
class {0}:
    def __init__(self, *args, **kwargs):
@@ -120,6 +142,7 @@ def {0}(*args, **kwargs):
DUMMY_PRETRAINED_CLASS = {
    "pt": DUMMY_PT_PRETRAINED_CLASS,
    "tf": DUMMY_TF_PRETRAINED_CLASS,
    "flax": DUMMY_FLAX_PRETRAINED_CLASS,
    "sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS,
    "tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS,
}
@@ -127,6 +150,7 @@ DUMMY_PRETRAINED_CLASS = {
DUMMY_CLASS = {
    "pt": DUMMY_PT_CLASS,
    "tf": DUMMY_TF_CLASS,
    "flax": DUMMY_FLAX_CLASS,
    "sentencepiece": DUMMY_SENTENCEPIECE_CLASS,
    "tokenizers": DUMMY_TOKENIZERS_CLASS,
}
@@ -134,6 +158,7 @@ DUMMY_CLASS = {
DUMMY_FUNCTION = {
    "pt": DUMMY_PT_FUNCTION,
    "tf": DUMMY_TF_FUNCTION,
    "flax": DUMMY_FLAX_FUNCTION,
    "sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION,
    "tokenizers": DUMMY_TOKENIZERS_FUNCTION,
}
@@ -208,7 +233,24 @@ def read_init():
        elif line.startswith("        "):
            tf_objects.append(line[8:-2])
        line_index += 1

    # Find where the Flax imports begin
    flax_objects = []
    while not lines[line_index].startswith("if is_flax_available():"):
        line_index += 1
    line_index += 1

    # Until we unindent, add Flax objects to the list
    while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
        line = lines[line_index]
        search = _re_single_line_import.search(line)
        if search is not None:
            flax_objects += search.groups()[0].split(", ")
        elif line.startswith("        "):
            flax_objects.append(line[8:-2])
        line_index += 1

    return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects
def create_dummy_object(name, type="pt"):
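read_init leans on _re_single_line_import, defined elsewhere in this script. An illustrative approximation (the real pattern may differ) that captures the imported names from an indented single-line import:

import re

# Assumed shape: indented `from ... import a, b` lines; group 1 holds "a, b".
_re_single_line_import = re.compile(r"\s+from\s+\S+\s+import\s+(.+)")

match = _re_single_line_import.search("    from .modeling_flax_bert import FlaxBertModel")
assert match.groups()[0].split(", ") == ["FlaxBertModel"]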
@@ -224,7 +266,7 @@ def create_dummy_object(name, type="pt"):
        "Model",
        "Tokenizer",
    ]
    assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"]
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    elif name.islower():
@@ -244,7 +286,7 @@ def create_dummy_object(name, type="pt"):
def create_dummy_files():
    """ Create the content of the dummy files. """
    sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init()

    sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
    sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n"
@@ -262,17 +304,22 @@ def create_dummy_files():
    tf_dummies += "from ..file_utils import requires_tf\n\n"
    tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects])

    flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
    flax_dummies += "from ..file_utils import requires_flax\n\n"
    flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects])

    return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies
def check_dummies(overwrite=False):
    """ Check if the dummy files are up to date and maybe `overwrite` with the right content. """
    sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files()
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
    sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py")
    tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py")
    pt_file = os.path.join(path, "dummy_pt_objects.py")
    tf_file = os.path.join(path, "dummy_tf_objects.py")
    flax_file = os.path.join(path, "dummy_flax_objects.py")

    with open(sentencepiece_file, "r", encoding="utf-8") as f:
        actual_sentencepiece_dummies = f.read()
@@ -282,6 +329,8 @@ def check_dummies(overwrite=False):
        actual_pt_dummies = f.read()
    with open(tf_file, "r", encoding="utf-8") as f:
        actual_tf_dummies = f.read()
    with open(flax_file, "r", encoding="utf-8") as f:
        actual_flax_dummies = f.read()
    if sentencepiece_dummies != actual_sentencepiece_dummies:
        if overwrite:
@@ -327,6 +376,17 @@ def check_dummies(overwrite=False):
                "Run `make fix-copies` to fix this.",
            )
    if flax_dummies != actual_flax_dummies:
        if overwrite:
            print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.")
            with open(flax_file, "w", encoding="utf-8") as f:
                f.write(flax_dummies)
        else:
            raise ValueError(
                "The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.",
                "Run `make fix-copies` to fix this.",
            )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
...
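End to end, the checker can be exercised directly; a hedged usage sketch, assuming utils/check_dummies.py is importable from the repository root (as `make fix-copies` would arrange):

from check_dummies import check_dummies

check_dummies(overwrite=False)  # raises ValueError when any dummy file is stale
check_dummies(overwrite=True)   # regenerates dummy_flax_objects.py and the others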