dynamic_module_utils.py 20.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Utilities to dynamically load objects from the Hub."""
16
import filecmp
17
18
19
20
21
22
23
24
import importlib
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Dict, Optional, Union

25
26
27
28
29
30
31
from .utils import (
    HF_MODULES_CACHE,
    TRANSFORMERS_DYNAMIC_MODULE_NAME,
    cached_file,
    extract_commit_hash,
    is_offline_mode,
    logging,
32
    try_to_load_from_cache,
33
)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def init_hf_modules() -> None:
    """
    Creates the cache directory for modules with an init, and adds it to the Python path.

    The function returns immediately when `HF_MODULES_CACHE` is already listed in `sys.path`, which is taken as the
    sign that a previous call already did the setup.
    """
    # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
    if HF_MODULES_CACHE in sys.path:
        return

    sys.path.append(HF_MODULES_CACHE)
    os.makedirs(HF_MODULES_CACHE, exist_ok=True)
    init_path = Path(HF_MODULES_CACHE) / "__init__.py"
    if not init_path.exists():
        init_path.touch()
        # A new `__init__.py` was just written to disk: let the import system's finders notice it.
        importlib.invalidate_caches()
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67


def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
    """
    Creates a dynamic module in the cache directory for modules.

    Args:
        name (`str` or `os.PathLike`):
            Location of the module to create; it is joined onto `HF_MODULES_CACHE`. Missing parent packages are
            created first, each with its own `__init__.py`.
    """
    init_hf_modules()
    dynamic_module_path = Path(HF_MODULES_CACHE) / name
    # If the parent module does not exist yet, recursively create it.
    if not dynamic_module_path.parent.exists():
        create_dynamic_module(dynamic_module_path.parent)
    os.makedirs(dynamic_module_path, exist_ok=True)
    init_path = dynamic_module_path / "__init__.py"
    if not init_path.exists():
        init_path.touch()
        # A new `__init__.py` was just written to disk: let the import system's finders notice it.
        importlib.invalidate_caches()
69
70


71
72
73
74
75
76
77
78
79
80
81
def get_relative_imports(module_file: Union[str, os.PathLike]):
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The relatively imported module names with their first leading dot removed. Each name appears
        once, in the order it was first matched (matches of the `import .xxx` form come before matches of the
        `from .xxx import yyy` form).
    """
    with open(module_file, "r", encoding="utf-8") as f:
        content = f.read()

    # Imports of the form `import .xxx`
    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from .xxx import yyy`
    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
    # Unique-ify. `dict.fromkeys` keeps first-seen order, so the result is the same on every run (a `set` would
    # yield a hash-dependent order).
    return list(dict.fromkeys(relative_imports))


def get_relative_import_files(module_file: Union[str, os.PathLike]):
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The paths (ending in `.py`) of the files reached through relative imports, in discovery order.
        Every path is built inside the folder containing `module_file`. Each file is listed once and `module_file`
        itself is never part of the result, even when the imports are circular.
    """
    module_path = Path(module_file).parent
    # Paths already scheduled for inspection. It is seeded with the entry file so that a cycle leading back to it is
    # neither re-scanned nor reported.
    seen = {str(Path(module_file))}
    files_to_check = [module_file]
    all_relative_imports = []

    # Let's recurse through all relative imports, one level per iteration, until nothing new shows up.
    while files_to_check:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        files_to_check = []
        for m in new_imports:
            import_file = f"{module_path / m}.py"
            # Compare full `.py` paths on both sides; otherwise the filter never matches, shared imports get
            # duplicated and circular imports never terminate.
            if import_file not in seen:
                seen.add(import_file)
                files_to_check.append(import_file)

        all_relative_imports.extend(files_to_check)

    return all_relative_imports


118
119
120
121
122
123
124
def check_imports(filename: Union[str, os.PathLike]):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.

    Args:
        filename (`str` or `os.PathLike`): The module file to check.

    Returns:
        The value of `get_relative_imports(filename)`, i.e. the modules relatively imported by the file.

    Raises:
        ImportError: If at least one top-level package imported outside of a `try`/`except` block cannot be imported.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    # filter out try/except block so in custom code we can have try/except imports.
    # `re.DOTALL` lets the body span several lines, and `except\b.*?:` also covers typed clauses such as
    # `except ImportError:`. Everything from `try:` up to the colon of the first following `except` is dropped.
    content = re.sub(r"^[ \t]*try[ \t]*:.*?^[ \t]*except\b.*?:", "", content, flags=re.MULTILINE | re.DOTALL)

    # Imports of the form `import xxx`
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # Only keep the top-level module
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]

    # Unique-ify (keeping first-seen order so the error message is stable) and test we got them all
    missing_packages = []
    for imp in dict.fromkeys(imports):
        try:
            importlib.import_module(imp)
        except ImportError:
            missing_packages.append(imp)

    if missing_packages:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    return get_relative_imports(filename)
151

152
153
154
155
156

def get_class_in_module(class_name: str, module_path: Union[str, os.PathLike]):
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (`str`):
            The name of the attribute to fetch from the imported module.
        module_path (`str` or `os.PathLike`):
            The module location written as a path without file extension. Every `os.path.sep` is replaced by a dot
            to build the dotted name handed to `importlib.import_module`.

    Returns:
        The attribute named `class_name` of the imported module.
    """
    dotted_name = os.fspath(module_path).replace(os.path.sep, ".")
    module = importlib.import_module(dotted_name)
    return getattr(module, class_name)
160
161


162
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    _commit_hash: Optional[str] = None,
) -> str:
    """
    Downloads a module from a local folder or a distant repo and returns its path inside the cached Transformers
    module.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
              under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        _commit_hash (`str`, *optional*):
            Internal argument. It is used as the revision of the cache lookup, forwarded to `cached_file` and
            `extract_commit_hash`, and set by this function when it calls itself for relatively imported files.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `str`: The path to the module inside the cache.
    """
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if is_local:
        # Use the folder's own name. Going through `abspath` keeps it non-empty for inputs such as
        # `./my_model_directory/` (trailing separator), for which a plain split on the separator yields "".
        submodule = os.path.basename(os.path.abspath(pretrained_model_name_or_path))
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash
        )

    new_files = []
    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            _commit_hash=_commit_hash,
        )
        # `cached_module` only exists for remote repos, hence the `is_local` guard first.
        if not is_local and cached_module != resolved_module_file:
            new_files.append(module_file)

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    # Branch on `is_local` directly: comparing `submodule` with the last path component also matches remote ids
    # without a namespace (e.g. `bert-base-uncased`) and would send them through the local-folder copy.
    if is_local:
        # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
        # has changed since last copy.
        if not (submodule_path / module_file).exists() or not filecmp.cmp(
            resolved_module_file, str(submodule_path / module_file)
        ):
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
                module_needed_file, str(submodule_path / module_needed)
            ):
                shutil.copy(module_needed_file, submodule_path / module_needed)
                importlib.invalidate_caches()
    else:
        # Get the commit hash
        commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        # Make sure we also have every file reached through relative imports
        for module_needed in modules_needed:
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    local_files_only=local_files_only,
                    _commit_hash=commit_hash,
                )
                new_files.append(f"{module_needed}.py")

    if len(new_files) > 0:
        new_files = "\n".join([f"- {f}" for f in new_files])
        logger.warning(
            f"A new version of the following files was downloaded from {pretrained_model_name_or_path}:\n{new_files}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

    return os.path.join(full_submodule, module_file)


def get_class_from_dynamic_module(
    class_reference: str,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    **kwargs,
):
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    <Tip warning={true}>

    Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
    therefore only be called on trusted repos.

    </Tip>

    Args:
        class_reference (`str`):
            The full name of the class to load, including its module and optionally its repo. The expected shape is
            `"module.ClassName"` or `"repo_id--module.ClassName"`.
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
              under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

            This is used when `class_reference` does not specify another repo.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git. It is replaced by `"main"` when `class_reference` names its own repo.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        kwargs:
            Extra keyword arguments are accepted and not used by this function.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `type`: The class, dynamically imported from the module.

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")

    # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
    ```"""
    # Catch the name of the repo if it's specified in `class_reference`
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
        # Invalidate revision since it's not relevant for this repo
        revision = "main"
    else:
        repo_id = pretrained_model_name_or_path
    module_file, class_name = class_reference.split(".")

    # And lastly we get the class inside our newly created module
    final_module = get_cached_module_file(
        repo_id,
        module_file + ".py",
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
    )
    # Drop only the trailing extension: a blanket `.replace(".py", "")` would also eat a ".py" found earlier in the
    # path (e.g. in a repo name) and point the import at a different module.
    module_path = final_module[: -len(".py")] if final_module.endswith(".py") else final_module
    return get_class_in_module(class_name, module_path)
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441


def custom_object_save(obj, folder, config=None):
    """
    Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
    adds the proper fields in a config.

    Nothing is registered or copied when `obj` is defined in `__main__`; a warning is logged instead.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, `optional`):
            A config in which to register the auto_map corresponding to this custom object. A list or tuple of
            configs is also accepted, in which case each of them is updated.
    """
    if obj.__module__ == "__main__":
        logger.warning(
            f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        # Stop here, as the warning announces: registering `__main__.<Class>` in the auto_map or copying the file
        # behind `__main__` (which may not even have a `__file__`) would not give a loadable result.
        return

    def _set_auto_map_in_config(_config):
        # Register `<last module component>.<class name>` under the key `obj._auto_class`.
        module_name = obj.__class__.__module__
        last_module = module_name.split(".")[-1]
        full_name = f"{last_module}.{obj.__class__.__name__}"
        # Special handling for tokenizers
        if "Tokenizer" in full_name:
            slow_tokenizer_class = None
            fast_tokenizer_class = None
            if obj.__class__.__name__.endswith("Fast"):
                # Fast tokenizer: we have the fast tokenizer class and we may have the slow one as an attribute.
                fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
                if getattr(obj, "slow_tokenizer_class", None) is not None:
                    slow_tokenizer = obj.slow_tokenizer_class
                    slow_tok_module_name = slow_tokenizer.__module__
                    last_slow_tok_module = slow_tok_module_name.split(".")[-1]
                    slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
            else:
                # Slow tokenizer: no way to have the fast class
                slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"

            # Tokenizers are stored as a `(slow, fast)` pair, either entry possibly `None`.
            full_name = (slow_tokenizer_class, fast_tokenizer_class)

        if isinstance(_config, dict):
            auto_map = _config.get("auto_map", {})
            auto_map[obj._auto_class] = full_name
            _config["auto_map"] = auto_map
        elif getattr(_config, "auto_map", None) is not None:
            _config.auto_map[obj._auto_class] = full_name
        else:
            _config.auto_map = {obj._auto_class: full_name}

    # Add object class to the config auto_map
    if isinstance(config, (list, tuple)):
        for cfg in config:
            _set_auto_map_in_config(cfg)
    elif config is not None:
        _set_auto_map_in_config(config)

    # Copy module file to the output folder.
    object_file = sys.modules[obj.__module__].__file__
    dest_file = Path(folder) / (Path(object_file).name)
    shutil.copy(object_file, dest_file)

    # Gather all relative imports recursively and make sure they are copied as well.
    for needed_file in get_relative_import_files(object_file):
        dest_file = Path(folder) / (Path(needed_file).name)
        shutil.copy(needed_file, dest_file)