# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is responsible for making sure the dummies in utils/dummy_xxx_objects.py are up to date with the main init.

Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't
have the necessary extra libs installed. Those objects will then raise a helpful error message whenever the user
tries to access one of their methods.

Usage (from the root of the repo):

Check that the dummy files are up to date (used in `make repo-consistency`):

```bash
python utils/check_dummies.py
```

Update the dummy files if needed (used in `make fix-copies`):

```bash
python utils/check_dummies.py --fix_and_overwrite
```
"""
36
37
38
import argparse
import os
import re
Sylvain Gugger's avatar
Sylvain Gugger committed
39
from typing import Dict, List, Optional
40
41
42
43
44
45


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_dummies.py
PATH_TO_TRANSFORMERS = "src/transformers"

46
# Matches is_xxx_available()
47
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
48
# Matches from xxx import bla
49
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
Sylvain Gugger's avatar
Sylvain Gugger committed
50
# Matches if not is_xxx_available()
51
_re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)")
52

53

Sylvain Gugger's avatar
Sylvain Gugger committed
54
# Templates for the dummy objects. `{0}` is replaced by the object name and `{1}` by the source of a list literal
# naming the required backends (see `create_dummy_object`).

# Dummy for an all-uppercase name (a constant).
DUMMY_CONSTANT = """
{0} = None
"""


# Dummy for a mixed-case name (a class): `requires_backends` is called on instantiation.
DUMMY_CLASS = """
class {0}(metaclass=DummyObject):
    _backends = {1}

    def __init__(self, *args, **kwargs):
        requires_backends(self, {1})
"""


# Dummy for an all-lowercase name (a function): `requires_backends` is called on every call.
DUMMY_FUNCTION = """
def {0}(*args, **kwargs):
    requires_backends({0}, {1})
"""


def find_backend(line: str) -> Optional[str]:
    """
    Find one (or multiple) backends in a code line of the init.

    Args:
        line (`str`): A code line in an init file.

    Returns:
        Optional[`str`]: `None` if the line is not an (indented) `if not is_xxx_available()` test. Otherwise the
        backend name. When the line tests several backends (e.g. `if not (is_xxx_available() and
        is_yyy_available())`), the names are sorted and joined on `_and_` (so `xxx_and_yyy` for instance).
    """
    if _re_test_backend.search(line) is None:
        return None
    # `_re_backend` has a second, empty capture group, so `findall` yields tuples: keep only the backend name.
    backends = sorted(match[0] for match in _re_backend.findall(line))
    return "_and_".join(backends)


def _skip_to_prefix(lines: List[str], line_index: int, prefix: str) -> int:
    """
    Return the index of the first line at or after `line_index` that starts with `prefix`.

    Raises:
        ValueError: If no such line exists, instead of running off the end of `lines` with an opaque `IndexError`.
    """
    while line_index < len(lines):
        if lines[line_index].startswith(prefix):
            return line_index
        line_index += 1
    raise ValueError(f"Could not find a line starting with {prefix!r} in the main init of Transformers.")


def read_init() -> Dict[str, List[str]]:
    """
    Read the init and extract backend-specific objects.

    Only the part of `__init__.py` starting at `if TYPE_CHECKING` is scanned. Each `if not is_xxx_available()` test
    found there starts a backend section. The objects imported in the `    else:` branch that follows (lines indented
    by at least 8 spaces) are attributed to that backend.

    Returns:
        Dict[str, List[str]]: A dictionary mapping backend name to the list of object names requiring that backend.

    Raises:
        ValueError: If the init has no `if TYPE_CHECKING` line, or a backend test is not followed by an `else:`.
    """
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Get to the point we do the actual imports for type checking
    line_index = _skip_to_prefix(lines, 0, "if TYPE_CHECKING")

    backend_specific_objects = {}
    # Go through the end of the file
    while line_index < len(lines):
        # If the line is an `if not is_backend_available()` test, we grab all objects associated.
        backend = find_backend(lines[line_index])
        if backend is None:
            line_index += 1
            continue

        # The backend objects are imported in the `else:` branch following the availability test.
        line_index = _skip_to_prefix(lines, line_index, "    else:") + 1

        objects = []
        # Until we unindent (or hit the end of the file), add backend objects to the list.
        # Lines of length <= 1 are blank lines inside the branch and do not end it.
        while line_index < len(lines) and (len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8)):
            line = lines[line_index]
            single_line_import_search = _re_single_line_import.search(line)
            if single_line_import_search is not None:
                # Single-line imports
                objects.extend(single_line_import_search.groups()[0].split(", "))
            elif line.startswith(" " * 12):
                # Multiple-line imports (with 3 indent level): one name per line, usually with a trailing comma.
                # Strip instead of slicing `[12:-2]` so a name without trailing comma is not truncated, and
                # comments / whitespace-only lines do not turn into bogus object names.
                name = line.split("#")[0].strip().rstrip(",").strip()
                if name:
                    objects.append(name)
            line_index += 1

        backend_specific_objects[backend] = objects

    return backend_specific_objects


def create_dummy_object(name: str, backend_name: str) -> str:
    """
    Create the code for a dummy object.

    The kind of dummy is chosen from the casing of `name`. All-uppercase names (e.g. `SOME_CONSTANT`) become
    constants, all-lowercase names (e.g. `some_function`) become functions, and anything else becomes a class.

    Args:
        name (`str`): The name of the object.
        backend_name (`str`): The required backend(s), as the source text of a list literal (e.g. `'["torch"]'`,
            as built by `create_dummy_files`). It is inserted verbatim into the generated code.

    Returns:
        `str`: The code of the dummy object.
    """
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    if name.islower():
        return DUMMY_FUNCTION.format(name, backend_name)
    return DUMMY_CLASS.format(name, backend_name)


def create_dummy_files(backend_specific_objects: Optional[Dict[str, List[str]]] = None) -> Dict[str, str]:
    """
    Create the content of the dummy files.

    Args:
        backend_specific_objects (`Dict[str, List[str]]`, *optional*):
            The mapping backend name to list of backend-specific objects. If not passed, will be obtained by calling
            `read_init()`.

    Returns:
        `Dict[str, str]`: A dictionary mapping backend name to code of the corresponding backend file.
    """
    if backend_specific_objects is None:
        backend_specific_objects = read_init()

    dummy_files = {}
    for backend, objects in backend_specific_objects.items():
        # "xxx_and_yyy" -> '["xxx", "yyy"]': the source of the list passed to `requires_backends` in the dummies.
        backend_name = "[" + ", ".join(f'"{b}"' for b in backend.split("_and_")) + "]"
        dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
        dummy_file += "from ..utils import DummyObject, requires_backends\n\n"
        dummy_file += "\n".join(create_dummy_object(o, backend_name) for o in objects)
        dummy_files[backend] = dummy_file

    return dummy_files


def check_dummies(overwrite: bool = False):
    """
    Check if the dummy files are up to date and maybe `overwrite` with the right content.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether or not to overwrite the content of the dummy files. Will raise an error if they are not up to date
            when `overwrite=False`.

    Raises:
        ValueError: If a dummy file differs from its expected content and `overwrite=False`.
    """
    dummy_files = create_dummy_files()
    # For special correspondence backend name to shortcut as used in utils/dummy_xxx_objects.py
    short_names = {"torch": "pt"}

    # Locate actual dummy modules and read their content.
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
    dummy_file_paths = {
        backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
        for backend in dummy_files
    }

    actual_dummies = {}
    for backend, file_path in dummy_file_paths.items():
        if os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8", newline="\n") as f:
                actual_dummies[backend] = f.read()
        else:
            # A missing dummy file counts as empty. Expected content always has a header, so it is created (or
            # reported) below.
            actual_dummies[backend] = ""

    # Compare actual with what they should be.
    for backend, expected_content in dummy_files.items():
        if expected_content == actual_dummies[backend]:
            continue
        module_name = f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py"
        if overwrite:
            print(f"Updating {module_name} as the main __init__ has new objects.")
            with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
                f.write(expected_content)
        else:
            raise ValueError(
                f"The main __init__ has objects that are not present in {module_name}. Run `make fix-copies` "
                "to fix this."
            )


if __name__ == "__main__":
    # Command-line entry point: only check the dummies, or rewrite them when `--fix_and_overwrite` is given.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    cli_args = cli_parser.parse_args()
    check_dummies(overwrite=cli_args.fix_and_overwrite)