# check_dummies.py
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"

25
26
27
# Matches is_xxx_available()
_re_backend = re.compile(r"is\_([a-z]*)_available()")
# Matches from xxx import bla
28
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
29
_re_test_backend = re.compile(r"^\s+if\s+is\_[a-z]*\_available\(\)")
30

31

32
33
DUMMY_CONSTANT = """
{0} = None
34
35
"""

36
DUMMY_PRETRAINED_CLASS = """
37
38
class {0}:
    def __init__(self, *args, **kwargs):
39
        requires_backends(self, {1})
40
41

    @classmethod
42
43
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, {1})
44
45
"""

46
DUMMY_CLASS = """
47
48
class {0}:
    def __init__(self, *args, **kwargs):
49
        requires_backends(self, {1})
50
51
"""

52
DUMMY_FUNCTION = """
53
def {0}(*args, **kwargs):
54
    requires_backends({0}, {1})
55
56
57
"""


58
59
60
61
62
63
64
65
66
def find_backend(line):
    """
    Return the backend key guarded by `line` when it is an indented `if is_xxx_available()` test, else `None`.

    When several backends appear on the line, their names are sorted alphabetically and joined with `"_and_"`
    (e.g. `"sentencepiece_and_tokenizers"`).
    """
    if not _re_test_backend.search(line):
        return None
    names = sorted(found[0] for found in _re_backend.findall(line))
    return "_and_".join(names)


def read_init():
    """
    Read the main `__init__.py` and collect, per backend, the objects imported under that backend's check.

    Only the part of the init from the first `if TYPE_CHECKING` line onwards is scanned. Each line recognized by
    `find_backend` opens a section. That section's objects are collected from the following lines for as long as
    they are blank or indented by at least 8 spaces.

    Returns:
        `Dict[str, List[str]]`: maps a backend key (e.g. `"torch"` or `"sentencepiece_and_tokenizers"`) to the
        list of object names found in its section.

    Raises:
        ValueError: if the init contains no line starting with `if TYPE_CHECKING`.
    """
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Get to the point we do the actual imports for type checking
    line_index = 0
    while line_index < len(lines) and not lines[line_index].startswith("if TYPE_CHECKING"):
        line_index += 1
    if line_index == len(lines):
        raise ValueError(f"Could not find an `if TYPE_CHECKING` line in {PATH_TO_TRANSFORMERS}/__init__.py.")

    backend_specific_objects = {}
    # Go through the end of the file
    while line_index < len(lines):
        # If the line is an if is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        line_index += 1
        if backend is None:
            continue

        objects = []
        # Until we unindent, add backend objects to the list. Lines of length <= 1 (empty or just "\n") do not end
        # the section. The length bound stops a section that runs to the end of the file from raising IndexError.
        while line_index < len(lines) and (len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8)):
            line = lines[line_index]
            single_line_import_search = _re_single_line_import.search(line)
            if single_line_import_search is not None:
                # `from xxx import a, b, c` written on a single line.
                objects.extend(single_line_import_search.groups()[0].split(", "))
            elif line.startswith(" " * 12):
                # One name per line: drop the 12-space indent and the last two characters (expected to be ",\n").
                objects.append(line[12:-2])
            line_index += 1

        backend_specific_objects[backend] = objects

    return backend_specific_objects


def create_dummy_object(name, backend_name):
    """
    Create the code for the dummy object corresponding to `name`.

    Args:
        name (`str`): Name of the object exported by the main init.
        backend_name (`str`): Python list literal of the required backends (e.g. `'["torch"]'`), inserted verbatim
            into the generated `requires_backends` call.

    Returns:
        `str`: Source code of the dummy object:
        - all-uppercase names become a constant set to `None`;
        - all-lowercase names become a function;
        - anything else becomes a class. The class also gets a `from_pretrained` classmethod when the name
          contains one of the substrings in `_pretrained`.
    """
    _pretrained = (
        "Config",
        "ForCausalLM",
        "ForConditionalGeneration",
        "ForMaskedLM",
        "ForMultipleChoice",
        "ForObjectDetection",
        "ForQuestionAnswering",
        "ForSegmentation",
        "ForSequenceClassification",
        "ForTokenClassification",
        "Model",
        "Tokenizer",
    )
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    if name.islower():
        return DUMMY_FUNCTION.format(name, backend_name)
    if any(part in name for part in _pretrained):
        return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
    return DUMMY_CLASS.format(name, backend_name)


def create_dummy_files():
    """
    Create the content of the dummy files.

    Returns:
        `Dict[str, str]`: maps each backend key returned by `read_init` to the complete source of its dummy
        module. Each module consists of a header comment, the `requires_backends` import and one dummy per object.
    """
    backend_specific_objects = read_init()
    dummy_files = {}

    for backend, objects in backend_specific_objects.items():
        # Render the backend key ("a_and_b") as a Python list literal ('["a", "b"]') for `requires_backends`.
        backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
        dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
        dummy_file += "from ..file_utils import requires_backends\n\n"
        dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
        dummy_files[backend] = dummy_file

    return dummy_files


def check_dummies(overwrite=False):
    """
    Check if the dummy files are up to date and maybe `overwrite` with the right content.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to rewrite out-of-date dummy files in place instead of raising.

    Raises:
        ValueError: if a dummy file differs from the expected content and `overwrite` is `False`.
    """
    dummy_files = create_dummy_files()
    # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
    short_names = {"torch": "pt"}

    # Locate actual dummy modules and read their content.
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
    dummy_file_paths = {
        backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
        for backend in dummy_files
    }

    actual_dummies = {}
    for backend, file_path in dummy_file_paths.items():
        if os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8", newline="\n") as f:
                actual_dummies[backend] = f.read()
        else:
            # A missing dummy file is treated as empty, so it is always reported (or written) as out of date.
            actual_dummies[backend] = ""

    # NOTE: only backends found in the current init are checked; dummy files for removed backends are not flagged.
    for backend, expected_content in dummy_files.items():
        if expected_content == actual_dummies[backend]:
            continue
        if overwrite:
            print(
                f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
                "__init__ has new objects."
            )
            with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
                f.write(expected_content)
        else:
            raise ValueError(
                "The main __init__ has objects that are not present in "
                f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
                "to fix this."
            )


if __name__ == "__main__":
    # By default only verify the dummies; pass --fix_and_overwrite to regenerate the out-of-date ones.
    cli = argparse.ArgumentParser()
    cli.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    check_dummies(cli.parse_args().fix_and_overwrite)