tests_fetcher.py 30.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import collections
Yih-Dar's avatar
Yih-Dar committed
18
import json
19
20
21
22
23
24
25
26
27
28
29
import os
import re
from contextlib import contextmanager
from pathlib import Path

from git import Repo


# This script is intended to be run from the root of the repo but you can adapt this constant if you need to.
# It is the base directory used below to build the paths to `src/transformers`, `tests` and `utils`.
# NOTE(review): the spelling "TRANFORMERS" is the name referenced throughout this file; renaming it means updating
# every use.
PATH_TO_TRANFORMERS = "."

30
31
32
33
34
# A temporary way to trigger all pipeline tests contained in model test files after PR #21516
# Both lists are built once, when this module is executed, by globbing paths that are relative to the current working
# directory (not to `PATH_TO_TRANFORMERS`). Every entry therefore already begins with the `tests` folder.
all_model_test_files = [str(x) for x in Path("tests/models/").glob("**/**/test_modeling_*.py")]

# Every `test_pipelines_*.py` file found at any depth under `tests/pipelines/`.
all_pipeline_test_files = [str(x) for x in Path("tests/pipelines/").glob("**/test_pipelines_*.py")]

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73

@contextmanager
def checkout_commit(repo, commit_id):
    """
    Temporarily check out `commit_id` in `repo`.

    On exit (normal or through an exception) the repo is checked out again on what HEAD pointed to when the context
    was entered: the commit itself if HEAD was detached, the branch reference otherwise.
    """
    head = repo.head
    if head.is_detached:
        restore_target = head.commit
    else:
        restore_target = head.ref

    try:
        repo.git.checkout(commit_id)
        yield
    finally:
        # Always go back to the starting point, even if the body raised.
        repo.git.checkout(restore_target)


def clean_code(content):
    """
    Strip docstrings, comments and blank lines from `content` and return what is left.

    Docstrings are removed by splitting on each triple-quote delimiter and keeping every other chunk; comments are
    removed by dropping everything from the first `#` to the end of each line.
    """
    # fmt: off
    # Triple double quotes are handled first, then triple single quotes.
    for delimiter in ('\"\"\"', "\'\'\'"):
        content = "".join(content.split(delimiter)[::2])
    # fmt: on

    # Drop the comment part of every line, then discard lines that are empty or whitespace only.
    without_comments = (re.sub("#.*$", "", raw_line) for raw_line in content.split("\n"))
    return "\n".join(line for line in without_comments if line and not line.isspace())


Yih-Dar's avatar
Yih-Dar committed
74
75
76
77
78
79
80
81
82
83
84
85
def get_all_tests():
    """
    Return a list of paths to all test folders and files under `tests`. All paths are rooted at `tests`.

    - folders under `tests`: `tokenization`, `pipelines`, etc. The folder `models` is excluded.
    - folders under `tests/models`: `bert`, `gpt2`, etc.
    - test files under `tests`: `test_modeling_common.py`, `test_tokenization_common.py`, etc.

    The model folders come first in the returned list, followed by the other entries; each group is sorted.

    Raises:
        ValueError: if there is no `tests/models` folder to exclude.
    """
    test_root_dir = os.path.join(PATH_TO_TRANFORMERS, "tests")

    def _is_repo_dir(path):
        # The directory listings below are taken relative to `PATH_TO_TRANFORMERS`, so the folder check has to be
        # resolved against the same root (and not against the current working directory).
        return os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, path))

    # test folders/files directly under `tests` folder
    top_level = [f"tests/{name}" for name in os.listdir(test_root_dir)]
    tests = sorted(path for path in top_level if _is_repo_dir(path) or path.startswith("tests/test_"))

    # model specific test folders
    model_entries = [f"tests/models/{name}" for name in os.listdir(os.path.join(test_root_dir, "models"))]
    model_test_folders = sorted(path for path in model_entries if _is_repo_dir(path))

    # `tests/models` is replaced by its individual sub-folders.
    tests.remove("tests/models")
    return model_test_folders + tests


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def diff_is_docstring_only(repo, branching_point, filename):
    """
    Tell whether `filename` differs between `branching_point` and the current checkout only in its docstrings,
    comments or blank lines.

    The file is read once with `branching_point` checked out and once in the current state; both versions are passed
    through `clean_code` and compared.
    """

    def _read_cleaned():
        with open(filename, "r", encoding="utf-8") as handle:
            return clean_code(handle.read())

    with checkout_commit(repo, branching_point):
        cleaned_before = _read_cleaned()
    cleaned_after = _read_cleaned()

    return cleaned_before == cleaned_after


115
def get_modified_python_files(diff_with_last_commit=False):
    """
    Return a list of python files that have been modified between:

    - the current head and the main branch if `diff_with_last_commit=False` (default)
    - the current head and its parent commit otherwise.

    Args:
        diff_with_last_commit (`bool`, *optional*, defaults to `False`):
            Whether to diff against the parent commit(s) of the current head instead of the merge base with `main`.

    Returns:
        The list built by `get_diff` for the selected commits.
    """
    repo = Repo(PATH_TO_TRANFORMERS)

    if not diff_with_last_commit:
        print(f"main is at {repo.refs.main.commit}")
        print(f"Current head is at {repo.head.commit}")

        branching_commits = repo.merge_base(repo.refs.main, repo.head)
        for commit in branching_commits:
            print(f"Branching commit: {commit}")
        return get_diff(repo, repo.head.commit, branching_commits)

    # The commit printed here is the current head, which is only `main` when the script happens to run on that
    # branch, so it is labelled as the head.
    print(f"Current head is at {repo.head.commit}")
    parent_commits = repo.head.commit.parents
    for commit in parent_commits:
        print(f"Parent commit: {commit}")
    return get_diff(repo, repo.head.commit, parent_commits)
138
139


140
141
142
143
def get_diff(repo, base_commit, commits):
    """
    Collect the python files that differ between each commit in `commits` and `base_commit`.

    Added and deleted `.py` files are always reported. Renamed files are reported under both their old and new
    names. Files modified in place are reported unless `diff_is_docstring_only` says the change only touches
    docstrings or comments. The returned list may contain duplicates when several commits are given.
    """
    print("\n### DIFF ###\n")
    changed_files = []

    for commit in commits:
        for entry in commit.diff(base_commit):
            kind = entry.change_type

            if kind == "A":
                # We always add new python files
                if entry.b_path.endswith(".py"):
                    changed_files.append(entry.b_path)
            elif kind == "D":
                # We check that deleted python files won't break corresponding tests.
                if entry.a_path.endswith(".py"):
                    changed_files.append(entry.a_path)
            elif kind in ["M", "R"] and entry.b_path.endswith(".py"):
                if entry.a_path != entry.b_path:
                    # In case of renames, we'll look at the tests using both the old and new name.
                    changed_files += [entry.a_path, entry.b_path]
                elif diff_is_docstring_only(repo, commit, entry.b_path):
                    # Same path on both sides: only keep it when actual code changed.
                    print(f"Ignoring diff in {entry.b_path} as it only concerns docstrings or comments.")
                else:
                    changed_files.append(entry.a_path)

    return changed_files


def get_module_dependencies(module_fname):
    """
    List the files inside the repo that `module_fname` imports from.

    Both relative imports (`from .x import y`) and absolute imports from `transformers.` sub-modules are looked at.
    An import whose imported-names part contains `# tests_ignore` is skipped. Each result is either a `.py` file or
    the `__init__.py` of a package folder, expressed relative to `PATH_TO_TRANFORMERS`; candidates matching neither
    are dropped.
    """
    with open(os.path.join(PATH_TO_TRANFORMERS, module_fname), "r", encoding="utf-8") as f:
        content = f.read()

    module_parts = module_fname.split(os.path.sep)
    candidates = []

    # Let's start with relative imports
    for target, imported_names in re.findall(r"from\s+(\.+\S+)\s+import\s+([^\n]+)\n", content):
        if "# tests_ignore" in imported_names:
            continue

        # The number of leading dots tells how many path components to climb from the module file.
        dotted_name = target.lstrip(".")
        level = len(target) - len(dotted_name)
        parent_parts = module_parts[: len(module_parts) - level]

        if len(dotted_name) > 0:
            tail_parts = dotted_name.split(".")
        else:
            tail_parts = ["__init__.py"]
        candidate = os.path.sep.join(parent_parts + tail_parts)

        # We ignore the main init import as it's only for the __version__ that it's done
        # and it would add everything as a dependency.
        if not candidate.endswith("transformers/__init__.py"):
            candidates.append(candidate)

    # Let's continue with direct imports
    # The import from the transformers module are ignored for the same reason we ignored the
    # main init before.
    for target, imported_names in re.findall(r"from\s+transformers\.(\S+)\s+import\s+([^\n]+)\n", content):
        if "# tests_ignore" in imported_names:
            continue
        candidates.append(os.path.sep.join(["src", "transformers"] + target.split(".")))

    # Now let's just check that we have proper module files, or append an init for submodules
    dependencies = []
    for candidate in candidates:
        module_file = f"{candidate}.py"
        package_init = os.path.sep.join([candidate, "__init__.py"])
        if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, module_file)):
            dependencies.append(module_file)
        elif os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, candidate)) and os.path.isfile(
            os.path.sep.join([PATH_TO_TRANFORMERS, candidate, "__init__.py"])
        ):
            dependencies.append(package_init)
    return dependencies


Sylvain Gugger's avatar
Sylvain Gugger committed
220
221
222
223
224
225
226
227
def get_test_dependencies(test_fname):
    """
    List the test files that `test_fname` imports through relative imports.

    Imports whose imported-names part contains `# tests_ignore` are skipped, and only candidates that exist as files
    under `PATH_TO_TRANFORMERS` are returned.
    """
    with open(os.path.join(PATH_TO_TRANFORMERS, test_fname), "r", encoding="utf-8") as f:
        content = f.read()

    test_parts = test_fname.split(os.path.sep)
    dependencies = []

    # Tests only have relative imports for other test files
    # TODO Sylvain: handle relative imports cleanly
    for target, imported_names in re.findall(r"from\s+(\.\S+)\s+import\s+([^\n]+)\n", content):
        if "# tests_ignore" in imported_names:
            continue

        # One leading dot means "same folder as the test file", each extra dot climbs one folder up.
        dotted_name = target.lstrip(".")
        level = len(target) - len(dotted_name)
        directory = os.path.sep.join(test_parts[:-level])
        candidate = os.path.join(directory, f"{dotted_name.replace('.', os.path.sep)}.py")

        if os.path.isfile(os.path.join(PATH_TO_TRANFORMERS, candidate)):
            dependencies.append(candidate)

    return dependencies
Sylvain Gugger's avatar
Sylvain Gugger committed
243
244


245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def create_reverse_dependency_tree():
    """
    Build the list of edges `(a, b)` meaning "a change in `a` impacts `b`".

    `b` ranges over every python file under `src/transformers` (using `get_module_dependencies`) and then every
    python file under `tests` (using `get_test_dependencies`); `a` ranges over the dependencies found for `b`.
    """
    root = Path(PATH_TO_TRANFORMERS)

    modules = [str(path.relative_to(PATH_TO_TRANFORMERS)) for path in (root / "src/transformers").glob("**/*.py")]
    edges = []
    for module in modules:
        for dependency in get_module_dependencies(module):
            edges.append((dependency, module))

    tests = [str(path.relative_to(PATH_TO_TRANFORMERS)) for path in (root / "tests").glob("**/*.py")]
    for test in tests:
        for dependency in get_test_dependencies(test):
            edges.append((dependency, test))

    return edges


def get_tree_starting_at(module, edges):
    """
    Returns the tree starting at a given module following all edges in the following format: [module, [list of edges
    starting at module], [list of edges starting at the preceding level], ...]

    Args:
        module: The vertex the traversal starts from.
        edges: A list of `(start, end)` pairs. Vertices must be hashable.

    Returns:
        A list whose first element is `module`, followed by one list of edges per level. An edge is only kept when
        its end vertex was not reached at an earlier level, so cycles do not loop forever. Edges inside a level keep
        the order they have in `edges`.
    """
    # Sets make the membership tests below constant time instead of scanning a growing list for every edge.
    vertices_seen = {module}
    new_edges = [edge for edge in edges if edge[0] == module and edge[1] != module]
    tree = [module]

    while len(new_edges) > 0:
        tree.append(new_edges)
        frontier = {edge[1] for edge in new_edges}
        vertices_seen |= frontier
        # Next level: edges leaving the vertices just reached and pointing to vertices not visited yet.
        new_edges = [edge for edge in edges if edge[0] in frontier and edge[1] not in vertices_seen]

    return tree


def print_tree_deps_of(module, all_edges=None):
    """
    Prints the tree of modules depending on a given module.

    Args:
        module: The file whose dependents are printed.
        all_edges (*optional*):
            The edges to follow. When `None`, they are computed with `create_reverse_dependency_tree()`.

    Each dependent is printed on its own line, indented by two spaces per level, right below the file it depends
    on. Siblings are printed in sorted order so that the output is the same from one run to the next.
    """
    if all_edges is None:
        all_edges = create_reverse_dependency_tree()
    tree = get_tree_starting_at(module, all_edges)

    # The list of lines is a list of tuples (line_to_be_printed, module)
    # Keeping the modules lets us know where to insert each new lines in the list.
    lines = [(tree[0], tree[0])]
    for index, edges in enumerate(tree[1:], start=1):
        indent = " " * (2 * index)
        # Sets remove duplicates; sorting them gives a stable order (set iteration order of strings is not).
        for start in sorted({edge[0] for edge in edges}):
            ends = sorted({edge[1] for edge in edges if edge[0] == start})
            # We will insert all those edges just after the (first) line showing start.
            pos = next(i for i, (_, shown_module) in enumerate(lines) if shown_module == start)
            lines[pos + 1 : pos + 1] = [(indent + end, end) for end in ends]

    for text, _ in lines:
        # We don't print the refs that where just here to help build lines.
        print(text)


306
307
def create_reverse_dependency_map():
    """
    Build the map from each module/test filename to the list of modules/tests that depend on it, directly or through
    a chain of dependencies.

    Returns:
        A `collections.defaultdict(list)` keyed by filename.

    Raises:
        ValueError: if a dependency is not itself one of the scanned module or test files.
    """
    root = Path(PATH_TO_TRANFORMERS)

    # We grab all the dependencies of each module.
    modules = [str(path.relative_to(PATH_TO_TRANFORMERS)) for path in (root / "src/transformers").glob("**/*.py")]
    direct_deps = {}
    for module in modules:
        direct_deps[module] = get_module_dependencies(module)

    # We add all the dependencies of each test file
    tests = [str(path.relative_to(PATH_TO_TRANFORMERS)) for path in (root / "tests").glob("**/*.py")]
    for test in tests:
        direct_deps[test] = get_test_dependencies(test)

    all_files = modules + tests

    # This recurses the dependencies: keep sweeping over all files until a full sweep adds nothing new.
    keep_going = True
    while keep_going:
        keep_going = False
        for fname in all_files:
            known_deps = direct_deps[fname]
            # `known_deps` can grow while it is being iterated; the entries appended are visited by this same loop.
            for dependency in known_deps:
                if dependency not in direct_deps:
                    raise ValueError(f"KeyError:{dependency}. From {fname}")
                for transitive in direct_deps[dependency]:
                    if transitive not in known_deps:
                        known_deps.append(transitive)
                        keep_going = True

    # Finally we can build the reverse map.
    reverse_map = collections.defaultdict(list)
    for fname in all_files:
        dependencies = direct_deps[fname]
        # An init file is additionally recorded as impacting everything it depends on.
        if fname.endswith("__init__.py"):
            reverse_map[fname].extend(dependencies)
        for dependency in dependencies:
            reverse_map[dependency].append(fname)

    return reverse_map


# Any module file that has a test name which can't be inferred automatically from its name should go here. A better
# approach is to (re-)name the test file accordingly, and second best to add the correspondence map here.
# Keys are module paths with their first two path components removed (see `module_to_test_file`); values are either a
# single test path or a list of test paths.
SPECIAL_MODULE_TO_TEST_MAP = {
    "commands/add_new_model_like.py": "utils/test_add_new_model_like.py",
    "configuration_utils.py": "test_configuration_common.py",
    "convert_graph_to_onnx.py": "onnx/test_onnx.py",
    "data/data_collator.py": "trainer/test_data_collator.py",
    "deepspeed.py": "deepspeed/",
    "feature_extraction_sequence_utils.py": "test_sequence_feature_extraction_common.py",
    "feature_extraction_utils.py": "test_feature_extraction_common.py",
    "file_utils.py": ["utils/test_file_utils.py", "utils/test_model_output.py"],
    "image_processing_utils.py": ["test_image_processing_common.py", "utils/test_image_processing_utils.py"],
    "image_transforms.py": "test_image_transforms.py",
    "utils/generic.py": ["utils/test_file_utils.py", "utils/test_model_output.py", "utils/test_generic.py"],
    "utils/hub.py": "utils/test_hub_utils.py",
    "modelcard.py": "utils/test_model_card.py",
    "modeling_flax_utils.py": "test_modeling_flax_common.py",
    "modeling_tf_utils.py": ["test_modeling_tf_common.py", "utils/test_modeling_tf_core.py"],
    "modeling_utils.py": ["test_modeling_common.py", "utils/test_offline.py"],
    "models/auto/modeling_auto.py": [
        "models/auto/test_modeling_auto.py",
        "models/auto/test_modeling_tf_pytorch.py",
        "models/bort/test_modeling_bort.py",
        "models/dit/test_modeling_dit.py",
    ],
    "models/auto/modeling_flax_auto.py": "models/auto/test_modeling_flax_auto.py",
    "models/auto/modeling_tf_auto.py": [
        "models/auto/test_modeling_tf_auto.py",
        "models/auto/test_modeling_tf_pytorch.py",
        "models/bort/test_modeling_tf_bort.py",
    ],
    "models/gpt2/modeling_gpt2.py": [
        "models/gpt2/test_modeling_gpt2.py",
        "models/megatron_gpt2/test_modeling_megatron_gpt2.py",
    ],
    "models/dpt/modeling_dpt.py": [
        "models/dpt/test_modeling_dpt.py",
        "models/dpt/test_modeling_dpt_hybrid.py",
    ],
    "optimization.py": "optimization/test_optimization.py",
    "optimization_tf.py": "optimization/test_optimization_tf.py",
    "pipelines/__init__.py": all_pipeline_test_files + all_model_test_files,
    "pipelines/base.py": all_pipeline_test_files + all_model_test_files,
    "pipelines/text2text_generation.py": [
        "pipelines/test_pipelines_text2text_generation.py",
        "pipelines/test_pipelines_summarization.py",
        "pipelines/test_pipelines_translation.py",
    ],
    "pipelines/zero_shot_classification.py": "pipelines/test_pipelines_zero_shot.py",
    "testing_utils.py": "utils/test_skip_decorators.py",
    "tokenization_utils.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
    "tokenization_utils_base.py": ["test_tokenization_common.py", "tokenization/test_tokenization_utils.py"],
    "tokenization_utils_fast.py": [
        "test_tokenization_common.py",
        "tokenization/test_tokenization_utils.py",
        "tokenization/test_tokenization_fast.py",
    ],
    "trainer.py": [
        "trainer/test_trainer.py",
        "extended/test_trainer_ext.py",
        "trainer/test_trainer_distributed.py",
        "trainer/test_trainer_tpu.py",
    ],
    # NOTE(review): this key used to read "train_pt_utils.py"; the module is presumably
    # `src/transformers/trainer_pt_utils.py` - confirm against the package layout.
    "trainer_pt_utils.py": "trainer/test_trainer_utils.py",
    "utils/versions.py": "utils/test_versions_utils.py",
}


def module_to_test_file(module_fname):
    """
    Returns the name of the file(s) where `module_fname` is tested.

    Args:
        module_fname (`str`): Path of a module or utils file, split on `os.path.sep`.

    Returns:
        A single path (`str`), a list of paths, or `None` when no existing test file could be inferred. Entries
        coming from `SPECIAL_MODULE_TO_TEST_MAP` and the hard-coded special cases are returned without checking that
        they exist.
    """
    splits = module_fname.split(os.path.sep)

    def _under_tests(path):
        # Some values of the special map (`all_pipeline_test_files`, `all_model_test_files`) are built from paths
        # that already begin with the `tests` folder: prefixing them again would give `tests/tests/...`.
        if path.startswith(("tests/", f"tests{os.path.sep}")):
            return path
        return f"tests/{path}"

    # Special map has priority
    short_name = os.path.sep.join(splits[2:])
    if short_name in SPECIAL_MODULE_TO_TEST_MAP:
        test_file = SPECIAL_MODULE_TO_TEST_MAP[short_name]
        if isinstance(test_file, str):
            return _under_tests(test_file)
        return [_under_tests(f) for f in test_file]

    module_name = splits[-1]
    # Fast tokenizers are tested in the same file as the slow ones.
    if module_name.endswith("_fast.py"):
        module_name = module_name.replace("_fast.py", ".py")

    parent_folder = splits[-2] if len(splits) >= 2 else None

    # Special case for pipelines submodules
    if parent_folder == "pipelines":
        default_test_file = f"tests/pipelines/test_pipelines_{module_name}"
        return [default_test_file] + all_model_test_files
    # Special case for benchmarks submodules
    if parent_folder == "benchmark":
        return ["tests/benchmark/test_benchmark.py", "tests/benchmark/test_benchmark_tf.py"]
    # Special case for commands submodules
    if parent_folder == "commands":
        return "tests/utils/test_cli.py"
    # Special case for onnx submodules
    if parent_folder == "onnx":
        return ["tests/onnx/test_features.py", "tests/onnx/test_onnx.py", "tests/onnx/test_onnx_v2.py"]

    # Special case for utils (not the one in src/transformers, the ones at the root of the repo).
    if len(splits) > 0 and splits[0] == "utils":
        default_test_file = f"tests/repo_utils/test_{module_name}"
    elif len(splits) > 4 and splits[2] == "models":
        default_test_file = f"tests/models/{splits[3]}/test_{module_name}"
    elif len(splits) > 2 and splits[2].startswith("generation"):
        default_test_file = f"tests/generation/test_{module_name}"
    elif len(splits) > 2 and splits[2].startswith("trainer"):
        default_test_file = f"tests/trainer/test_{module_name}"
    else:
        default_test_file = f"tests/utils/test_{module_name}"

    if os.path.isfile(default_test_file):
        return default_test_file

    # Processing -> processor
    if "processing" in default_test_file:
        test_file = default_test_file.replace("processing", "processor")
        if os.path.isfile(test_file):
            return test_file

    # No matching test file exists.
    return None


# This list contains the list of test files we expect never to be launched from a change in a module/util. Those are
# launched separately.
EXPECTED_TEST_FILES_NEVER_TOUCHED = [
    # Mixins inherited by actual test classes
    "tests/generation/test_framework_agnostic.py",
    # Mixed-int8 bitsandbytes test
    "tests/mixed_int8/test_mixed_int8.py",
    # Actually checked by the pipeline based file
    "tests/pipelines/test_pipelines_common.py",
    # SageMaker tests
    "tests/sagemaker/test_single_node_gpu.py",
    "tests/sagemaker/test_multi_node_model_parallel.py",
    "tests/sagemaker/test_multi_node_data_parallel.py",
    # Contains no test of its own (only the common tester class)
    "tests/test_pipeline_mixin.py",
    # Doc tests
    "tests/utils/test_doc_samples.py",
]


def _print_list(l):
    """
    Format the items of `l` as a bullet list: one `- item` line per element, joined by newlines.
    """
    return "\n".join(f"- {item}" for item in l)


def sanity_check():
    """
    Verify that every existing test file can be reached from at least one module or utils file.

    Every python file under `src/transformers` and `utils` is mapped to its test file(s) with `module_to_test_file`
    (folders are expanded to the `test*.py` files they contain). Test files under `tests` that are neither reached
    that way nor listed in `EXPECTED_TEST_FILES_NEVER_TOUCHED` make the check fail.

    Raises:
        ValueError: if some test files are not associated with any module or utils file.
    """
    root = Path(PATH_TO_TRANFORMERS)

    def _relative_files(folder, pattern):
        return [str(p.relative_to(PATH_TO_TRANFORMERS)) for p in (root / folder).glob(pattern)]

    # Grab all module and utils
    all_files = _relative_files("src/transformers", "**/*.py") + _relative_files("utils", "**/*.py")

    # Compute all the test files we get from those.
    test_files_found = []
    for source_file in all_files:
        mapped = module_to_test_file(source_file)
        if mapped is None:
            continue
        if isinstance(mapped, str):
            test_files_found.append(mapped)
        else:
            test_files_found.extend(mapped)

    # Some of the test files might actually be subfolders so we grab the tests inside.
    test_files = []
    for candidate in test_files_found:
        if os.path.isdir(os.path.join(PATH_TO_TRANFORMERS, candidate)):
            test_files.extend(_relative_files(candidate, "**/test*.py"))
        else:
            test_files.append(candidate)

    # Compare to existing test files
    reachable = set(test_files)
    not_touched_test_files = [f for f in _relative_files("tests", "**/test*.py") if f not in reachable]

    should_be_tested = set(not_touched_test_files) - set(EXPECTED_TEST_FILES_NEVER_TOUCHED)
    if len(should_be_tested) == 0:
        return

    raise ValueError(
        "The following test files are not currently associated with any module or utils files, which means they "
        f"will never get run by the CI:\n{_print_list(should_be_tested)}\n. Make sure the names of these test "
        "files match the name of the module or utils they are testing, or adapt the constant "
        "`SPECIAL_MODULE_TO_TEST_MAP` in `utils/tests_fetcher.py` to add them. If your test file is triggered "
        "separately and is not supposed to be run by the regular CI, add it to the "
        "`EXPECTED_TEST_FILES_NEVER_TOUCHED` constant instead."
    )


Yih-Dar's avatar
Yih-Dar committed
543
def infer_tests_to_run(output_file, diff_with_last_commit=False, filters=None, json_output_file=None):
    """
    Work out which tests to run from the modified files and write the result to disk.

    Args:
        output_file (`str` or `os.PathLike`):
            Where the space-separated list of tests is written (only when that list is not empty). A
            `test_repo_utils.txt` file is written next to it when repo utils tests have to run.
        diff_with_last_commit (`bool`, *optional*, defaults to `False`):
            Forwarded to `get_modified_python_files`.
        filters (`List[str]`, *optional*):
            When given, only test paths starting with one of these prefixes are kept.
        json_output_file (`str` or `os.PathLike`, *optional*):
            When given (and some tests were found), where the map from test category to space-separated test paths
            is written as JSON.
    """
    modified_files = get_modified_python_files(diff_with_last_commit=diff_with_last_commit)
    print(f"\n### MODIFIED FILES ###\n{_print_list(modified_files)}")

    # Create the map that will give us all impacted modules.
    impacted_modules_map = create_reverse_dependency_map()
    impacted = set(modified_files)
    for modified in modified_files:
        if modified in impacted_modules_map:
            impacted.update(impacted_modules_map[modified])

    # The set removed duplicates; sort for a stable order.
    impacted_files = sorted(impacted)
    print(f"\n### IMPACTED FILES ###\n{_print_list(impacted_files)}")

    # Grab the corresponding test files:
    if "setup.py" in impacted_files:
        test_files_to_run = ["tests"]
        repo_utils_launch = True
    else:
        collected = []
        for impacted_file in impacted_files:
            if impacted_file.startswith("tests/"):
                # Modified test files are always added
                collected.append(impacted_file)
            elif impacted_file.startswith("examples/pytorch"):
                # Example files are tested separately
                collected.append("examples/pytorch/test_pytorch_examples.py")
                collected.append("examples/pytorch/test_accelerate_examples.py")
            elif impacted_file.startswith("examples/tensorflow"):
                collected.append("examples/tensorflow/test_tensorflow_examples.py")
            elif impacted_file.startswith("examples/flax"):
                collected.append("examples/flax/test_flax_examples.py")
            else:
                new_tests = module_to_test_file(impacted_file)
                if new_tests is None:
                    continue
                if isinstance(new_tests, str):
                    new_tests = [new_tests]
                collected.extend(new_tests)

        # Remove duplicates
        test_files_to_run = sorted(set(collected))
        # Make sure we did not end up with a test file that was removed
        test_files_to_run = [f for f in test_files_to_run if os.path.isfile(f) or os.path.isdir(f)]
        if filters is not None:
            test_files_to_run = [f for prefix in filters for f in test_files_to_run if f.startswith(prefix)]

        repo_utils_launch = any(f.split(os.path.sep)[1] == "repo_utils" for f in test_files_to_run)

    if repo_utils_launch:
        repo_util_file = Path(output_file).parent / "test_repo_utils.txt"
        with open(repo_util_file, "w", encoding="utf-8") as f:
            f.write("tests/repo_utils")

    print(f"\n### TEST TO RUN ###\n{_print_list(test_files_to_run)}")
    if len(test_files_to_run) == 0:
        # Nothing to run: neither output file is written.
        return

    with open(output_file, "w", encoding="utf-8") as f:
        f.write(" ".join(test_files_to_run))

    # Create a map that maps test categories to test files, i.e. `models/bert` -> [...test_modeling_bert.py, ...]

    # Get all test directories (and some common test files) under `tests` and `tests/models` if `test_files_to_run`
    # contains `tests` (i.e. when `setup.py` is changed).
    if "tests" in test_files_to_run:
        test_files_to_run = get_all_tests()

    if json_output_file is None:
        return

    test_map = {}
    for test_file in test_files_to_run:
        # `test_file` is a path to a test folder/file, starting with `tests/`. For example,
        #   - `tests/models/bert/test_modeling_bert.py` or `tests/models/bert`
        #   - `tests/trainer/test_trainer.py` or `tests/trainer`
        #   - `tests/test_modeling_common.py`
        names = test_file.split(os.path.sep)
        if names[1] == "models":
            # take the part like `models/bert` for modeling tests
            key = "/".join(names[1:3])
        elif len(names) > 2 or not test_file.endswith(".py"):
            # test folders under `tests` or python files under them
            # take the part like tokenization, `pipeline`, etc. for other test categories
            key = "/".join(names[1:2])
        else:
            # common test files directly under `tests/`
            key = "common"
        test_map.setdefault(key, []).append(test_file)

    # sort the keys & values
    test_map = {key: " ".join(sorted(test_map[key])) for key in sorted(test_map)}
    with open(json_output_file, "w", encoding="UTF-8") as fp:
        json.dump(test_map, fp, ensure_ascii=False)

642

Sylvain Gugger's avatar
Sylvain Gugger committed
643
644
645
646
647
648
649
650
def filter_tests(output_file, filters):
    """
    Reads the content of the output file and filters out all the tests in a list of given folders.

    The file is rewritten in place with the remaining entries. Nothing is written when the file does not exist or
    contains no test.

    Args:
        output_file (`str` or `os.PathLike`): The path to the output file of the tests fetcher.
        filters (`List[str]`): A list of folders to filter.
    """
    if not os.path.isfile(output_file):
        print("No test file found.")
        return
    with open(output_file, "r", encoding="utf-8") as f:
        test_files = f.read().split(" ")

    if len(test_files) == 0 or test_files == [""]:
        print("No tests to filter.")
        return

    if test_files == ["tests"]:
        # The whole `tests` folder was requested: expand it to its entries so the filtered ones can be left out.
        excluded = ["__init__.py", *filters]
        test_files = [os.path.join("tests", f) for f in os.listdir("tests") if f not in excluded]
    else:

        def _is_kept(test_file):
            parts = test_file.split(os.path.sep)
            # An entry without a second path component (e.g. a bare `tests` or `examples`) cannot be inside one of
            # the filtered folders, so it is kept instead of raising an IndexError.
            return len(parts) < 2 or parts[1] not in filters

        test_files = [f for f in test_files if _is_kept(f)]

    with open(output_file, "w", encoding="utf-8") as f:
        f.write(" ".join(test_files))


670
671
672
673
674
675
676
677
if __name__ == "__main__":
    # Command-line entry point. Exactly one action runs, picked in this order: print a dependency tree, run the
    # sanity check, filter an existing test list, or (default) infer the tests to run from the git diff.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sanity_check",
        action="store_true",
        help="Only test that all tests and modules are accounted for.",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="test_list.txt",
        help="Where to store the list of tests to run",
    )
    parser.add_argument(
        "--json_output_file",
        type=str,
        default="test_map.json",
        help="Where to store the tests to run in a dictionary format mapping test categories to test files",
    )
    parser.add_argument(
        "--diff_with_last_commit",
        action="store_true",
        help="To fetch the tests between the current commit and the last commit",
    )
    parser.add_argument(
        "--filters",
        type=str,
        nargs="*",
        default=["tests"],
        help="Only keep the test files matching one of those filters.",
    )
    parser.add_argument(
        "--filter_tests",
        action="store_true",
        help="Will filter the pipeline/repo utils tests outside of the generated list of tests.",
    )
    parser.add_argument(
        "--print_dependencies_of",
        type=str,
        help="Will only print the tree of modules depending on the file passed.",
        default=None,
    )
    args = parser.parse_args()

    if args.print_dependencies_of is not None:
        print_tree_deps_of(args.print_dependencies_of)
    elif args.sanity_check:
        sanity_check()
    elif args.filter_tests:
        filter_tests(args.output_file, ["pipelines", "repo_utils"])
    else:
        repo = Repo(PATH_TO_TRANFORMERS)

        diff_with_last_commit = args.diff_with_last_commit
        if not diff_with_last_commit:
            # On the main branch itself there is nothing to compare against but the previous commit.
            if not repo.head.is_detached and repo.head.ref == repo.refs.main:
                print("main branch detected, fetching tests against last commit.")
                diff_with_last_commit = True

        try:
            infer_tests_to_run(
                args.output_file,
                diff_with_last_commit=diff_with_last_commit,
                filters=args.filters,
                json_output_file=args.json_output_file,
            )
            filter_tests(args.output_file, ["repo_utils"])
        except Exception as e:
            # Best effort: if anything goes wrong while inferring the tests, fall back to running everything
            # selected by the filters.
            print(f"\nError when trying to grab the relevant tests: {e}\n\nRunning all tests.")
            with open(args.output_file, "w", encoding="utf-8") as f:
                fallback = "./tests/" if args.filters is None else " ".join(args.filters)
                f.write(fallback)