Unverified Commit 758ed333 authored by Sylvain Gugger, committed by GitHub

Transformers fast import part 2 (#9446)



* Main init work

* Add version

* Change from absolute to relative imports

* Fix imports

* One more typo

* More typos

* Styling

* Make quality script pass

* Add necessary replace in template

* Fix typos

* Spaces are ignored in replace for some reason

* Forgot one model.

* Fixes for import
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>

* Add documentation

* Styling
Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
parent a400fe89
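For context before the diff: this PR series replaces the eager imports in the main `__init__.py` with an `_import_structure` table mapping each submodule to its public names, resolved lazily on attribute access (with a `TYPE_CHECKING` branch keeping static tooling happy). Below is a minimal sketch of that pattern using a simplified `LazyModule`; the helper transformers actually uses at this commit may differ in name and details:

import importlib
from types import ModuleType


class LazyModule(ModuleType):
    """Resolve public names to their defining submodules only on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # e.g. {"models.mobilebert": ["MobileBertConfig", "MobileBertForPreTraining"]}
        self._object_to_module = {
            obj: module for module, objs in import_structure.items() for obj in objs
        }
        self.__all__ = list(self._object_to_module)

    def __getattr__(self, name):
        if name not in self._object_to_module:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")
        module = importlib.import_module("." + self._object_to_module[name], self.__name__)
        value = getattr(module, name)
        setattr(self, name, value)  # cache so the submodule import cost is paid once
        return value

An init built this way stays cheap: `import transformers` only reads the table, and `from transformers import MobileBertConfig` pulls in just the MobileBERT modules.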
@@ -16,8 +16,8 @@ import argparse

 import torch

-from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
-from transformers.utils import logging
+from ...utils import logging
+from . import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert


 logging.set_verbosity_info()
...
@@ -19,8 +19,9 @@ import argparse

 import torch

-from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
-from transformers.utils import logging
+from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
+from ...utils import logging
+from . import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt


 logging.set_verbosity_info()
...
@@ -22,8 +22,8 @@ import tensorflow as tf
 import torch
 from tqdm import tqdm

-from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
-from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params
+from . import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
+from .configuration_pegasus import DEFAULTS, task_specific_params


 PATTERNS = [
...
@@ -19,8 +19,6 @@ import argparse

 import torch

-from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
-
 # transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
 # original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
 from transformers_old.modeling_prophetnet import (
@@ -30,6 +28,8 @@ from transformers_old.modeling_xlm_prophetnet import (
     XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
 )

+from . import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
+
 logger = logging.get_logger(__name__)
 logging.set_verbosity_info()
...
@@ -21,8 +21,8 @@ import pickle

 import numpy as np
 import torch

-from transformers import ReformerConfig, ReformerModelWithLMHead
-from transformers.utils import logging
+from ...utils import logging
+from . import ReformerConfig, ReformerModelWithLMHead

 logging.set_verbosity_info()
...
@@ -24,19 +24,9 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
 from packaging import version

-from transformers.models.bert.modeling_bert import (
-    BertIntermediate,
-    BertLayer,
-    BertOutput,
-    BertSelfAttention,
-    BertSelfOutput,
-)
-from transformers.models.roberta.modeling_roberta import (
-    RobertaConfig,
-    RobertaForMaskedLM,
-    RobertaForSequenceClassification,
-)
-from transformers.utils import logging
+from ...models.bert.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
+from ...utils import logging
+from .modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification


 if version.parse(fairseq.__version__) < version.parse("0.9.0"):
...
@@ -17,8 +17,8 @@

 import argparse

-from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
-from transformers.utils import logging
+from ...utils import logging
+from . import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5


 logging.set_verbosity_info()
...
@@ -22,8 +22,6 @@ from typing import Tuple

 import tensorflow as tf

-from transformers.modeling_tf_utils import TFWrappedEmbeddings
-
 from ...activations_tf import get_tf_activation
 from ...file_utils import (
     DUMMY_INPUTS,
@@ -42,6 +40,7 @@ from ...modeling_tf_utils import (
     TFCausalLanguageModelingLoss,
     TFPreTrainedModel,
     TFSharedEmbeddings,
+    TFWrappedEmbeddings,
     input_processing,
     keras_serializable,
     shape_list,
...
@@ -17,16 +17,16 @@

 import argparse

-from transformers.models.tapas.modeling_tapas import (
+from ...utils import logging
+from . import (
     TapasConfig,
     TapasForMaskedLM,
     TapasForQuestionAnswering,
     TapasForSequenceClassification,
     TapasModel,
+    TapasTokenizer,
     load_tf_weights_in_tapas,
 )
-from transformers.models.tapas.tokenization_tapas import TapasTokenizer
-from transformers.utils import logging


 logging.set_verbosity_info()
...
@@ -28,9 +28,7 @@ from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union

 import numpy as np

-from transformers import add_end_docstrings
-
-from ...file_utils import is_pandas_available
+from ...file_utils import add_end_docstrings, is_pandas_available
 from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
 from ...tokenization_utils_base import (
     ENCODE_KWARGS_DOCSTRING,
...
@@ -22,16 +22,11 @@ import sys

 import torch

-import transformers.models.transfo_xl.tokenization_transfo_xl as data_utils
-from transformers import (
-    CONFIG_NAME,
-    WEIGHTS_NAME,
-    TransfoXLConfig,
-    TransfoXLLMHeadModel,
-    load_tf_weights_in_transfo_xl,
-)
-from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
-from transformers.utils import logging
+from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
+from ...utils import logging
+from . import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
+from . import tokenization_transfo_xl as data_utils
+from .tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES


 logging.set_verbosity_info()
...
@@ -21,9 +21,9 @@ import json

 import numpy
 import torch

-from transformers import CONFIG_NAME, WEIGHTS_NAME
-from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES
-from transformers.utils import logging
+from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
+from ...utils import logging
+from .tokenization_xlm import VOCAB_FILES_NAMES

 logging.set_verbosity_info()
...
@@ -20,16 +20,15 @@ import os

 import torch

-from transformers import (
-    CONFIG_NAME,
-    WEIGHTS_NAME,
+from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
+from ...utils import logging
+from . import (
     XLNetConfig,
     XLNetForQuestionAnswering,
     XLNetForSequenceClassification,
     XLNetLMHeadModel,
     load_tf_weights_in_xlnet,
 )
-from transformers.utils import logging


 GLUE_TASKS_NUM_LABELS = {
...
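Every conversion-script hunk above applies the same recipe: absolute `transformers.*` imports become package-relative ones, with `from . import X` served by the model subpackage's own `__init__` (note how `TapasTokenizer` joins the `from . import (...)` list in the TAPAS hunk above). This only works if each subpackage init re-exports those names; a hypothetical minimal `models/mobilebert/__init__.py` along these lines, while the real one also handles TF classes, tokenizers, and availability checks:

from .configuration_mobilebert import MobileBertConfig
from .modeling_mobilebert import MobileBertForPreTraining, load_tf_weights_in_mobilebert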
@@ -24,9 +24,78 @@
 ##
 ## Put '## COMMENT' to comment on the file.

+# To replace in: "src/transformers/__init__.py"
+# Below: "    # PyTorch models structure" if generating PyTorch
+# Replace with:
+{% if cookiecutter.is_encoder_decoder_model == "False" %}
+    _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
+        [
+            "{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "{{cookiecutter.camelcase_modelname}}ForMaskedLM",
+            "{{cookiecutter.camelcase_modelname}}ForCausalLM",
+            "{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
+            "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
+            "{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
+            "{{cookiecutter.camelcase_modelname}}ForTokenClassification",
+            "{{cookiecutter.camelcase_modelname}}Layer",
+            "{{cookiecutter.camelcase_modelname}}Model",
+            "{{cookiecutter.camelcase_modelname}}PreTrainedModel",
+            "load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
+        ]
+    )
+{% else %}
+    _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
+        [
+            "{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
+            "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
+            "{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
+            "{{cookiecutter.camelcase_modelname}}Model",
+        ]
+    )
+{% endif -%}
+# End.
+
+# Below: "    # TensorFlow models structure" if generating TensorFlow
+# Replace with:
+{% if cookiecutter.is_encoder_decoder_model == "False" %}
+    _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
+        [
+            "TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
+            "TF{{cookiecutter.camelcase_modelname}}ForCausalLM",
+            "TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
+            "TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
+            "TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
+            "TF{{cookiecutter.camelcase_modelname}}ForTokenClassification",
+            "TF{{cookiecutter.camelcase_modelname}}Layer",
+            "TF{{cookiecutter.camelcase_modelname}}Model",
+            "TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
+        ]
+    )
+{% else %}
+    _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
+        [
+            "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
+            "TF{{cookiecutter.camelcase_modelname}}Model",
+            "TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
+        ]
+    )
+{% endif -%}
+# End.
+
+# Below: "    # Fast tokenizers"
+# Replace with:
+    _import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast")
+# End.
+
+# Below: "    # Models"
+# Replace with:
+    "models.{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP", "{{cookiecutter.camelcase_modelname}}Config", "{{cookiecutter.camelcase_modelname}}Tokenizer"],
+# End.
+
 # To replace in: "src/transformers/__init__.py"
-# Below: "if is_torch_available():" if generating PyTorch
+# Below: "    if is_torch_available():" if generating PyTorch
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" %}
     from .models.{{cookiecutter.lowercase_modelname}} import (
@@ -53,7 +122,7 @@
 {% endif -%}
 # End.

-# Below: "if is_tf_available():" if generating TensorFlow
+# Below: "    if is_tf_available():" if generating TensorFlow
 # Replace with:
 {% if cookiecutter.is_encoder_decoder_model == "False" %}
     from .models.{{cookiecutter.lowercase_modelname}} import (
@@ -77,18 +146,25 @@
 {% endif -%}
 # End.

-# Below: "if is_tokenizers_available():"
+# Below: "    if is_tokenizers_available():"
 # Replace with:
     from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
 # End.

-# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
+# Below: "    from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
 # Replace with:
     from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer
 # End.

+# To replace in: "src/transformers/models/__init__.py"
+# Below: "from . import ("
+# Replace with:
+    {{cookiecutter.lowercase_modelname}},
+# End.
+
 # To replace in: "src/transformers/models/auto/configuration_auto.py"
 # Below: "# Add configs here"
 # Replace with:
...
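The template above is consumed by a simple find-and-insert protocol: each block names a target file (`# To replace in:`), an anchor line (`# Below:`), and the lines to place beneath it, terminated by `# End.`. A minimal sketch of an engine that could apply such directives; this is illustrative only, the real logic lives in the transformers templating utilities, and per the commit message it ignores whitespace when matching anchors:

def parse_directives(template_text):
    """Yield (target_file, anchor, replacement_lines) triples from a template."""
    target, anchor, body, in_body = None, None, [], False
    for line in template_text.splitlines():
        if line.startswith('# To replace in: "'):
            target = line.split('"')[1]
        elif line.startswith('# Below: "'):
            anchor = line.split('"')[1]
        elif line.startswith("# Replace with:"):
            in_body, body = True, []
        elif line.startswith("# End."):
            in_body = False
            yield target, anchor, body
        elif in_body:
            body.append(line)


def apply_directive(file_lines, anchor, replacement_lines):
    """Insert the replacement right below the first line matching the anchor."""
    for i, line in enumerate(file_lines):
        # Whitespace is ignored when matching ("Spaces are ignored in replace").
        if line.strip() == anchor.strip():
            return file_lines[: i + 1] + [l + "\n" for l in replacement_lines] + file_lines[i + 1 :]
    return file_lines  # anchor not found: leave the file unchanged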
@@ -23,237 +23,79 @@ import re
 PATH_TO_TRANSFORMERS = "src/transformers"

 _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
+_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")

-DUMMY_CONSTANT = """
-{0} = None
-"""
-
-DUMMY_PT_PRETRAINED_CLASS = """
-class {0}:
-    def __init__(self, *args, **kwargs):
-        requires_pytorch(self)
-
-    @classmethod
-    def from_pretrained(self, *args, **kwargs):
-        requires_pytorch(self)
-"""
-
-DUMMY_PT_CLASS = """
-class {0}:
-    def __init__(self, *args, **kwargs):
-        requires_pytorch(self)
-"""
-
-DUMMY_PT_FUNCTION = """
-def {0}(*args, **kwargs):
-    requires_pytorch({0})
-"""
-
-DUMMY_TF_PRETRAINED_CLASS = """
-class {0}:
-    def __init__(self, *args, **kwargs):
-        requires_tf(self)
-
-    @classmethod
-    def from_pretrained(self, *args, **kwargs):
-        requires_tf(self)
-"""
-
-DUMMY_TF_CLASS = """
-class {0}:
-    def __init__(self, *args, **kwargs):
-        requires_tf(self)
-"""
+BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers"]

-DUMMY_TF_FUNCTION = """
-def {0}(*args, **kwargs):
-    requires_tf({0})
-"""
-
-DUMMY_FLAX_PRETRAINED_CLASS = """
-class {0}:
-    def __init__(self, *args, **kwargs):
-        requires_flax(self)
-
-    @classmethod
-    def from_pretrained(self, *args, **kwargs):
-        requires_flax(self)
-"""
-
-DUMMY_FLAX_CLASS = """
-class {0}:
-    def __init__(self, *args, **kwargs):
-        requires_flax(self)
-"""
-
-DUMMY_FLAX_FUNCTION = """
-def {0}(*args, **kwargs):
-    requires_flax({0})
-"""
-
-DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """
-class {0}:
-    def __init__(self, *args, **kwargs):
-        requires_sentencepiece(self)
-
-    @classmethod
-    def from_pretrained(self, *args, **kwargs):
-        requires_sentencepiece(self)
-"""
-
-DUMMY_SENTENCEPIECE_CLASS = """
-class {0}:
-    def __init__(self, *args, **kwargs):
-        requires_sentencepiece(self)
-"""
-
-DUMMY_SENTENCEPIECE_FUNCTION = """
-def {0}(*args, **kwargs):
-    requires_sentencepiece({0})
-"""
+DUMMY_CONSTANT = """
+{0} = None
+"""

-DUMMY_TOKENIZERS_PRETRAINED_CLASS = """
+DUMMY_PRETRAINED_CLASS = """
 class {0}:
     def __init__(self, *args, **kwargs):
-        requires_tokenizers(self)
+        requires_{1}(self)

     @classmethod
     def from_pretrained(self, *args, **kwargs):
-        requires_tokenizers(self)
+        requires_{1}(self)
 """

-DUMMY_TOKENIZERS_CLASS = """
+DUMMY_CLASS = """
 class {0}:
     def __init__(self, *args, **kwargs):
-        requires_tokenizers(self)
+        requires_{1}(self)
 """

-DUMMY_TOKENIZERS_FUNCTION = """
+DUMMY_FUNCTION = """
 def {0}(*args, **kwargs):
-    requires_tokenizers({0})
+    requires_{1}({0})
 """

-# Map all these to dummy type
-DUMMY_PRETRAINED_CLASS = {
-    "pt": DUMMY_PT_PRETRAINED_CLASS,
-    "tf": DUMMY_TF_PRETRAINED_CLASS,
-    "flax": DUMMY_FLAX_PRETRAINED_CLASS,
-    "sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS,
-    "tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS,
-}
-DUMMY_CLASS = {
-    "pt": DUMMY_PT_CLASS,
-    "tf": DUMMY_TF_CLASS,
-    "flax": DUMMY_FLAX_CLASS,
-    "sentencepiece": DUMMY_SENTENCEPIECE_CLASS,
-    "tokenizers": DUMMY_TOKENIZERS_CLASS,
-}
-DUMMY_FUNCTION = {
-    "pt": DUMMY_PT_FUNCTION,
-    "tf": DUMMY_TF_FUNCTION,
-    "flax": DUMMY_FLAX_FUNCTION,
-    "sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION,
-    "tokenizers": DUMMY_TOKENIZERS_FUNCTION,
-}
-

 def read_init():
     """ Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
     with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()

+    # Get to the point we do the actual imports for type checking
     line_index = 0
+    while not lines[line_index].startswith("if TYPE_CHECKING"):
+        line_index += 1

-    # Find where the SentencePiece imports begin
-    sentencepiece_objects = []
-    while not lines[line_index].startswith("if is_sentencepiece_available():"):
-        line_index += 1
-    line_index += 1
-
-    # Until we unindent, add SentencePiece objects to the list
-    while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
-        line = lines[line_index]
-        search = _re_single_line_import.search(line)
-        if search is not None:
-            sentencepiece_objects += search.groups()[0].split(", ")
-        elif line.startswith("        "):
-            sentencepiece_objects.append(line[8:-2])
-        line_index += 1
-
-    # Find where the Tokenizers imports begin
-    tokenizers_objects = []
-    while not lines[line_index].startswith("if is_tokenizers_available():"):
-        line_index += 1
-    line_index += 1
-
-    # Until we unindent, add Tokenizers objects to the list
-    while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
-        line = lines[line_index]
-        search = _re_single_line_import.search(line)
-        if search is not None:
-            tokenizers_objects += search.groups()[0].split(", ")
-        elif line.startswith("        "):
-            tokenizers_objects.append(line[8:-2])
-        line_index += 1
-
-    # Find where the PyTorch imports begin
-    pt_objects = []
-    while not lines[line_index].startswith("if is_torch_available():"):
-        line_index += 1
-    line_index += 1
-
-    # Until we unindent, add PyTorch objects to the list
-    while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
-        line = lines[line_index]
-        search = _re_single_line_import.search(line)
-        if search is not None:
-            pt_objects += search.groups()[0].split(", ")
-        elif line.startswith("        "):
-            pt_objects.append(line[8:-2])
-        line_index += 1
-
-    # Find where the TF imports begin
-    tf_objects = []
-    while not lines[line_index].startswith("if is_tf_available():"):
-        line_index += 1
-    line_index += 1
-
-    # Until we unindent, add PyTorch objects to the list
-    while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
-        line = lines[line_index]
-        search = _re_single_line_import.search(line)
-        if search is not None:
-            tf_objects += search.groups()[0].split(", ")
-        elif line.startswith("        "):
-            tf_objects.append(line[8:-2])
-        line_index += 1
-
-    # Find where the FLAX imports begin
-    flax_objects = []
-    while not lines[line_index].startswith("if is_flax_available():"):
-        line_index += 1
-    line_index += 1
-
-    # Until we unindent, add PyTorch objects to the list
-    while len(lines[line_index]) <= 1 or lines[line_index].startswith("    "):
-        line = lines[line_index]
-        search = _re_single_line_import.search(line)
-        if search is not None:
-            flax_objects += search.groups()[0].split(", ")
-        elif line.startswith("        "):
-            flax_objects.append(line[8:-2])
-        line_index += 1
-
-    return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects
+    backend_specific_objects = {}
+    # Go through the end of the file
+    while line_index < len(lines):
+        # If the line is an if is_backend_available, we grab all objects associated.
+        if _re_test_backend.search(lines[line_index]) is not None:
+            backend = _re_test_backend.search(lines[line_index]).groups()[0]
+            line_index += 1
+
+            # Ignore if backend isn't tracked for dummies.
+            if backend not in BACKENDS:
+                continue
+
+            objects = []
+            # Until we unindent, add backend objects to the list
+            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
+                line = lines[line_index]
+                single_line_import_search = _re_single_line_import.search(line)
+                if single_line_import_search is not None:
+                    objects.extend(single_line_import_search.groups()[0].split(", "))
+                elif line.startswith(" " * 12):
+                    objects.append(line[12:-2])
+                line_index += 1
+
+            backend_specific_objects[backend] = objects
+        else:
+            line_index += 1
+
+    return backend_specific_objects


-def create_dummy_object(name, type="pt"):
+def create_dummy_object(name, backend_name):
     """ Create the code for the dummy object corresponding to `name`."""
     _pretrained = [
         "Config" "ForCausalLM",
@@ -266,11 +108,10 @@ def create_dummy_object(name, type="pt"):
         "Model",
         "Tokenizer",
     ]
-    assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"]
     if name.isupper():
         return DUMMY_CONSTANT.format(name)
     elif name.islower():
-        return (DUMMY_FUNCTION[type]).format(name)
+        return DUMMY_FUNCTION.format(name, backend_name)
     else:
         is_pretrained = False
         for part in _pretrained:
@@ -278,113 +119,60 @@ def create_dummy_object(name, type="pt"):
                 is_pretrained = True
                 break
         if is_pretrained:
-            template = DUMMY_PRETRAINED_CLASS[type]
+            return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
         else:
-            template = DUMMY_CLASS[type]
-
-        return template.format(name)
+            return DUMMY_CLASS.format(name, backend_name)


 def create_dummy_files():
     """ Create the content of the dummy files. """
-    sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init()
-
-    sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
-    sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n"
-    sentencepiece_dummies += "\n".join([create_dummy_object(o, type="sentencepiece") for o in sentencepiece_objects])
-
-    tokenizers_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
-    tokenizers_dummies += "from ..file_utils import requires_tokenizers\n\n"
-    tokenizers_dummies += "\n".join([create_dummy_object(o, type="tokenizers") for o in tokenizers_objects])
-
-    pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
-    pt_dummies += "from ..file_utils import requires_pytorch\n\n"
-    pt_dummies += "\n".join([create_dummy_object(o, type="pt") for o in pt_objects])
-
-    tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
-    tf_dummies += "from ..file_utils import requires_tf\n\n"
-    tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects])
-
-    flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
-    flax_dummies += "from ..file_utils import requires_flax\n\n"
-    flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects])
-
-    return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies
+    backend_specific_objects = read_init()
+    # For special correspondence backend to module name as used in the function requires_modulename
+    module_names = {"torch": "pytorch"}
+    dummy_files = {}
+
+    for backend, objects in backend_specific_objects.items():
+        backend_name = module_names.get(backend, backend)
+        dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
+        dummy_file += f"from ..file_utils import requires_{backend_name}\n\n"
+        dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
+        dummy_files[backend] = dummy_file
+
+    return dummy_files


 def check_dummies(overwrite=False):
     """ Check if the dummy files are up to date and maybe `overwrite` with the right content. """
-    sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files()
-    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
-    sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py")
-    tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py")
-    pt_file = os.path.join(path, "dummy_pt_objects.py")
-    tf_file = os.path.join(path, "dummy_tf_objects.py")
-    flax_file = os.path.join(path, "dummy_flax_objects.py")
-
-    with open(sentencepiece_file, "r", encoding="utf-8", newline="\n") as f:
-        actual_sentencepiece_dummies = f.read()
-    with open(tokenizers_file, "r", encoding="utf-8", newline="\n") as f:
-        actual_tokenizers_dummies = f.read()
-    with open(pt_file, "r", encoding="utf-8", newline="\n") as f:
-        actual_pt_dummies = f.read()
-    with open(tf_file, "r", encoding="utf-8", newline="\n") as f:
-        actual_tf_dummies = f.read()
-    with open(flax_file, "r", encoding="utf-8", newline="\n") as f:
-        actual_flax_dummies = f.read()
-
-    if sentencepiece_dummies != actual_sentencepiece_dummies:
-        if overwrite:
-            print("Updating transformers.utils.dummy_sentencepiece_objects.py as the main __init__ has new objects.")
-            with open(sentencepiece_file, "w", encoding="utf-8", newline="\n") as f:
-                f.write(sentencepiece_dummies)
-        else:
-            raise ValueError(
-                "The main __init__ has objects that are not present in transformers.utils.dummy_sentencepiece_objects.py.",
-                "Run `make fix-copies` to fix this.",
-            )
-
-    if tokenizers_dummies != actual_tokenizers_dummies:
-        if overwrite:
-            print("Updating transformers.utils.dummy_tokenizers_objects.py as the main __init__ has new objects.")
-            with open(tokenizers_file, "w", encoding="utf-8", newline="\n") as f:
-                f.write(tokenizers_dummies)
-        else:
-            raise ValueError(
-                "The main __init__ has objects that are not present in transformers.utils.dummy_tokenizers_objects.py.",
-                "Run `make fix-copies` to fix this.",
-            )
-
-    if pt_dummies != actual_pt_dummies:
-        if overwrite:
-            print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.")
-            with open(pt_file, "w", encoding="utf-8", newline="\n") as f:
-                f.write(pt_dummies)
-        else:
-            raise ValueError(
-                "The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
-                "Run `make fix-copies` to fix this.",
-            )
-
-    if tf_dummies != actual_tf_dummies:
-        if overwrite:
-            print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.")
-            with open(tf_file, "w", encoding="utf-8", newline="\n") as f:
-                f.write(tf_dummies)
-        else:
-            raise ValueError(
-                "The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
-                "Run `make fix-copies` to fix this.",
-            )
-
-    if flax_dummies != actual_flax_dummies:
-        if overwrite:
-            print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.")
-            with open(flax_file, "w", encoding="utf-8", newline="\n") as f:
-                f.write(flax_dummies)
-        else:
-            raise ValueError(
-                "The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.",
-                "Run `make fix-copies` to fix this.",
-            )
+    dummy_files = create_dummy_files()
+    # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
+    short_names = {"torch": "pt"}
+
+    # Locate actual dummy modules and read their content.
+    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
+    dummy_file_paths = {
+        backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
+        for backend in dummy_files.keys()
+    }
+
+    actual_dummies = {}
+    for backend, file_path in dummy_file_paths.items():
+        with open(file_path, "r", encoding="utf-8", newline="\n") as f:
+            actual_dummies[backend] = f.read()
+
+    for backend in dummy_files.keys():
+        if dummy_files[backend] != actual_dummies[backend]:
+            if overwrite:
+                print(
+                    f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
+                    "__init__ has new objects."
+                )
+                with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
+                    f.write(dummy_files[backend])
+            else:
+                raise ValueError(
+                    "The main __init__ has objects that are not present in "
+                    f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
+                    "to fix this."
+                )
...
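Net effect of the rewrite above: five hand-written extraction loops and fifteen per-backend templates collapse into one regex-driven loop and three parameterized templates. A standalone illustration (safe to run, it only matches and formats strings):

import re

_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")

# The single regex recognizes every backend guard under `if TYPE_CHECKING:`:
print(_re_test_backend.search("    if is_torch_available():\n").groups()[0])       # torch
print(_re_test_backend.search("    if is_tokenizers_available():\n").groups()[0])  # tokenizers

# One template, parameterized by backend, replaces the five per-backend copies;
# module_names maps "torch" to "pytorch" before formatting, so the generated stub
# calls requires_pytorch, matching the helpers in file_utils:
DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
    requires_{1}({0})
"""
print(DUMMY_FUNCTION.format("load_tf_weights_in_albert", "pytorch"))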
@@ -413,9 +413,6 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
 def ignore_undocumented(name):
     """Rules to determine if `name` should be undocumented."""
     # NOT DOCUMENTED ON PURPOSE.
-    # Magic attributes are not documented.
-    if name.startswith("__"):
-        return True
     # Constants uppercase are not documented.
     if name.isupper():
         return True
@@ -459,7 +456,9 @@ def ignore_undocumented(name):
 def check_all_objects_are_documented():
     """ Check all models are properly documented."""
     documented_objs = find_all_documented_objects()
-    undocumented_objs = [c for c in dir(transformers) if c not in documented_objs and not ignore_undocumented(c)]
+    modules = transformers._modules
+    objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
+    undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
     if len(undocumented_objs) > 0:
         raise Exception(
             "The following objects are in the public init so should be documented:\n - "
...
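The check_repo change accounts for the lazy init: `dir(transformers)` now also lists submodule names, which the init exposes as `transformers._modules`, so the documentation check filters those (and underscore-prefixed names) out first; that also makes the old magic-attribute rule redundant. A toy illustration with hypothetical values:

# What dir(transformers) and transformers._modules might contain:
dir_result = ["BertModel", "models", "_modules"]
modules = {"models"}

objects = [c for c in dir_result if c not in modules and not c.startswith("_")]
print(objects)  # ['BertModel'] -- submodules are no longer flagged as undocumented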