"vscode:/vscode.git/clone" did not exist on "088ad458885bb41694deb4e52bea914087a64dad"
test_doc_samples.py 6.94 KB
Newer Older
Lysandre's avatar
Lysandre committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
import doctest
import logging
Lysandre's avatar
Lysandre committed
17
import os
18
import unittest
19
from glob import glob
20
from pathlib import Path
21
22
from typing import List, Union

23
import transformers
24
from transformers.testing_utils import require_tf, require_torch, slow
Lysandre's avatar
Lysandre committed
25

Lysandre's avatar
Lysandre committed
26

27
logger = logging.getLogger()
Lysandre's avatar
Lysandre committed
28
29


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@require_torch
class TestDocLists(unittest.TestCase):
    """
    Checks that the attention-backend support lists in the performance guide mention every PyTorch
    modeling file that enables the corresponding backend.

    All paths are relative, so these tests expect to be run from the repository root.
    """

    # Markdown page holding both support lists.
    _DOC_PATH = "./docs/source/en/perf_infer_gpu_one.md"
    # Without `recursive=True`, `**` behaves like `*`, i.e. this matches `models/<name>/modeling_*.py`.
    _MODELING_GLOB = "./src/transformers/models/**/modeling_*.py"
    # Basename prefixes of the TensorFlow / Flax modeling files, which are excluded from the check.
    _NON_TORCH_PREFIXES = ("modeling_tf_", "modeling_flax_")

    @classmethod
    def _read_doc_section(cls, start_marker, end_marker):
        """
        Returns the part of the documentation page located between `start_marker` and `end_marker`.

        Args:
            start_marker (`str`): Sentence introducing the section of interest.
            end_marker (`str`): Sentence following the section. If it is absent, everything after
                `start_marker` is returned.

        Raises:
            ValueError: If `start_marker` cannot be found in the page.
        """
        with open(cls._DOC_PATH, "r", encoding="utf-8") as f:
            doctext = f.read()

        pieces = doctext.split(start_marker)
        if len(pieces) < 2:
            # Fail with an actionable message instead of a bare IndexError.
            raise ValueError(f"Could not find the sentence {start_marker!r} in {cls._DOC_PATH}.")
        return pieces[1].split(end_marker)[0]

    @classmethod
    def _pytorch_archs_with_flag(cls, flag):
        """
        Returns the sorted names of the PyTorch modeling files whose source contains `flag`.

        Args:
            flag (`str`): Exact source snippet to look for, e.g. `"_supports_sdpa = True"`.
        """
        archs = []
        for filename in glob(cls._MODELING_GLOB):
            basename = os.path.basename(filename)
            if basename.startswith(cls._NON_TORCH_PREFIXES):
                continue
            with open(filename, "r", encoding="utf-8") as f:
                text = f.read()
            if flag in text:
                archs.append(basename.replace(".py", "").replace("modeling_", ""))
        # Sorted so that the first reported architecture does not depend on the glob order.
        return sorted(archs)

    def _check_archs_documented(self, flag, start_marker, end_marker, doc_name):
        """
        Raises a `ValueError` for the first architecture enabling `flag` that is not mentioned in the
        documentation section delimited by `start_marker` and `end_marker`.

        Args:
            flag (`str`): Source snippet marking an architecture as supported.
            start_marker (`str`): Sentence introducing the support list in the documentation.
            end_marker (`str`): Sentence following the support list in the documentation.
            doc_name (`str`): Name of the documentation section, used in the error message.
        """
        doctext = self._read_doc_section(start_marker, end_marker)
        for arch in self._pytorch_archs_with_flag(flag):
            if arch not in doctext:
                raise ValueError(
                    f"{arch} should be listed in the {doc_name} documentation but is not. Please update the documentation."
                )

    def test_flash_support_list(self):
        """Every architecture setting `_supports_flash_attn_2 = True` must appear in the FlashAttention-2 list."""
        self._check_archs_documented(
            flag="_supports_flash_attn_2 = True",
            start_marker="FlashAttention-2 is currently supported for the following architectures:",
            end_marker="You can request to add FlashAttention-2 support",
            doc_name="flash attention",
        )

    def test_sdpa_support_list(self):
        """Every architecture setting `_supports_sdpa = True` must appear in the SDPA list."""
        self._check_archs_documented(
            flag="_supports_sdpa = True",
            start_marker=(
                "For now, Transformers supports inference and training through SDPA for the following architectures:"
            ),
            end_marker="Note that FlashAttention can only be used for models using the",
            doc_name="SDPA",
        )


Lysandre Debut's avatar
Lysandre Debut committed
87
@unittest.skip("Temporarily disable the doc tests.")
@require_torch
@require_tf
@slow
class TestCodeExamples(unittest.TestCase):
    """Runs the doctests found in the `transformers` modules and in the documentation sources."""

    def analyze_directory(
        self,
        directory: Path,
        identifier: Union[str, None] = None,
        ignore_files: Union[List[str], None] = None,
        n_identifier: Union[str, List[str], None] = None,
        only_modules: bool = True,
    ):
        """
        Runs through the specific directory, looking for the files identified with `identifier`. Executes
        the doctests in those files

        Args:
            directory (`Path`): Directory containing the files
            identifier (`str`): Will parse files containing this
            ignore_files (`List[str]`): List of files to skip. The list is not modified; `__init__.py` is
                always skipped in addition to these.
            n_identifier (`str` or `List[str]`): Will not parse files containing this/these identifiers.
            only_modules (`bool`): Whether to only analyze modules. When `True`, each file name (up to its
                first dot) is looked up as an attribute of `transformers` and names that are not found
                are logged and skipped. When `False`, each file is run directly with `doctest.testfile`.
        """
        # Sorted so that the run order does not depend on the order returned by the file system.
        files = sorted(file for file in os.listdir(directory) if os.path.isfile(os.path.join(directory, file)))

        if identifier is not None:
            files = [file for file in files if identifier in file]

        if n_identifier is not None:
            # Normalize to a list so that a single string and a collection are handled the same way.
            excluded = [n_identifier] if isinstance(n_identifier, str) else list(n_identifier)
            files = [file for file in files if not any(n_ in file for n_ in excluded)]

        # Build a fresh collection rather than appending to the caller's list.
        skipped = set(ignore_files or ())
        skipped.add("__init__.py")
        files = [file for file in files if file not in skipped]

        for file in files:
            print("Testing", file)

            if only_modules:
                module_name = file.split(".")[0]
                # Only the attribute lookup is guarded: an AttributeError raised while collecting or
                # running the doctests must not be mistaken for "not a module".
                try:
                    module = getattr(transformers, module_name)
                except AttributeError:
                    logger.info("%s is not a module.", module_name)
                    continue

                suite = doctest.DocTestSuite(module)
                result = unittest.TextTestRunner().run(suite)
                self.assertEqual(len(result.failures), 0)
                # Errors are reported separately from failures by `unittest`; neither is acceptable.
                self.assertEqual(len(result.errors), 0)
            else:
                # NOTE(review): `doctest.testfile` resolves paths relative to the calling module by
                # default, which is presumably why the path is prefixed with ".." - confirm.
                result = doctest.testfile(str(".." / directory / file), optionflags=doctest.ELLIPSIS)
                self.assertEqual(result.failed, 0)

    def test_modeling_examples(self):
        """Runs the doctests of the files whose name contains "modeling", except the CTRL ones."""
        transformers_directory = Path("src/transformers")
        files = "modeling"
        ignore_files = [
            "modeling_ctrl.py",
            "modeling_tf_ctrl.py",
        ]
        self.analyze_directory(transformers_directory, identifier=files, ignore_files=ignore_files)

    def test_tokenization_examples(self):
        """Runs the doctests of the files whose name contains "tokenization"."""
        transformers_directory = Path("src/transformers")
        files = "tokenization"
        self.analyze_directory(transformers_directory, identifier=files)

    def test_configuration_examples(self):
        """Runs the doctests of the files whose name contains "configuration"."""
        transformers_directory = Path("src/transformers")
        files = "configuration"
        self.analyze_directory(transformers_directory, identifier=files)

    def test_remaining_examples(self):
        """Runs the doctests of the files not covered by the configuration/modeling/tokenization tests."""
        transformers_directory = Path("src/transformers")
        n_identifiers = ["configuration", "modeling", "tokenization"]
        self.analyze_directory(transformers_directory, n_identifier=n_identifiers)

    def test_doc_sources(self):
        """Runs the files of the documentation source directory as doctest files."""
        doc_source_directory = Path("docs/source")
        ignore_files = ["favicon.ico"]
        self.analyze_directory(doc_source_directory, ignore_files=ignore_files, only_modules=False)