dynamic_module_utils.py 21.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Utilities to dynamically load objects from the Hub."""
16
import filecmp
17
18
19
20
21
22
23
24
import importlib
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Dict, Optional, Union

25
26
27
28
29
30
31
from .utils import (
    HF_MODULES_CACHE,
    TRANSFORMERS_DYNAMIC_MODULE_NAME,
    cached_file,
    extract_commit_hash,
    is_offline_mode,
    logging,
32
    try_to_load_from_cache,
33
)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def init_hf_modules():
    """
    Create the cache directory for dynamic modules (with an `__init__.py`) and register it on `sys.path`.

    Idempotent: if `HF_MODULES_CACHE` is already on `sys.path`, the setup was done earlier in this process.
    """
    if HF_MODULES_CACHE in sys.path:
        # Already initialized in this interpreter, nothing to do.
        return

    sys.path.append(HF_MODULES_CACHE)
    os.makedirs(HF_MODULES_CACHE, exist_ok=True)
    init_file = Path(HF_MODULES_CACHE) / "__init__.py"
    if not init_file.exists():
        init_file.touch()
        # A brand-new package appeared on sys.path: the import machinery must refresh its caches.
        importlib.invalidate_caches()
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67


def create_dynamic_module(name: Union[str, os.PathLike]):
    """
    Create the package `name` (if needed) inside the dynamic-module cache, including all parent packages.

    Args:
        name (`str` or `os.PathLike`): Path of the module to create, relative to the cache root.
    """
    init_hf_modules()
    module_dir = Path(HF_MODULES_CACHE) / name
    # Parent packages must exist before the leaf can be created, so recurse upward first.
    if not module_dir.parent.exists():
        create_dynamic_module(module_dir.parent)
    os.makedirs(module_dir, exist_ok=True)
    init_file = module_dir / "__init__.py"
    if not init_file.exists():
        init_file.touch()
        # New package file on disk: force the import system to pick it up.
        importlib.invalidate_caches()
69
70


71
72
73
74
75
76
77
78
79
80
81
def get_relative_imports(module_file):
    """
    Get the list of modules that are relatively imported in a module file.

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The deduplicated names of the relative imports found in the file.
    """
    with open(module_file, "r", encoding="utf-8") as fp:
        source = fp.read()

    # One pattern per supported form: `import .xxx` and `from .xxx import yyy`.
    patterns = (
        r"^\s*import\s+\.(\S+)\s*$",
        r"^\s*from\s+\.(\S+)\s+import",
    )
    found = set()
    for pattern in patterns:
        found.update(re.findall(pattern, source, flags=re.MULTILINE))
    # Unique-ify
    return list(found)


def get_relative_import_files(module_file):
    """
    Get the list of all files that are needed for a given module. Note that this function recurses through the relative
    imports (if a imports b and b imports c, it will return module files for b and c).

    Args:
        module_file (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: Paths (with `.py` extension) of every transitively relative-imported file, resolved next to
        `module_file`.
    """
    files_to_check = [module_file]
    all_relative_imports = []
    module_path = Path(module_file).parent

    # Breadth-first walk over relative imports until no new file shows up.
    while files_to_check:
        new_imports = []
        for f in files_to_check:
            new_imports.extend(get_relative_imports(f))

        # Fix: compare full `.py` paths against the collected `.py` paths (the previous code compared
        # extension-less paths against `.py` paths, so nothing was ever filtered out — producing duplicate
        # entries and an infinite loop on circular relative imports).
        new_import_files = [f"{module_path / m}.py" for m in new_imports]
        files_to_check = [f for f in new_import_files if f not in all_relative_imports]
        all_relative_imports.extend(files_to_check)

    return all_relative_imports


Sylvain Gugger's avatar
Sylvain Gugger committed
118
def get_imports(filename):
    """
    Extracts all the libraries that are imported in a file.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The deduplicated top-level names of the absolute imports found (relative imports are
        excluded; they are handled by `get_relative_imports`).
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    # Filter out try/except blocks so in custom code we can have try/except imports.
    # Fix: `re.DOTALL` is required for the non-greedy `.*?` to span a multi-line `try` body, and
    # `except\s*.*?:` also matches qualified clauses such as `except ImportError:` — without these, guarded
    # optional imports leaked into the required list.
    content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL)

    # Imports of the form `import xxx`
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
    # Imports of the form `from xxx import yyy`
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
    # Only keep the top-level module
    imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
    return list(set(imports))

136

Sylvain Gugger's avatar
Sylvain Gugger committed
137
138
139
140
141
def check_imports(filename):
    """
    Check if the current Python environment contains all the libraries that are imported in a file.

    Args:
        filename (`str` or `os.PathLike`): The module file to inspect.

    Returns:
        `List[str]`: The relative imports of the file (see `get_relative_imports`).

    Raises:
        ImportError: If any absolute import of the file cannot be resolved in the current environment.
    """
    missing_packages = []
    for package in get_imports(filename):
        try:
            importlib.import_module(package)
        except ImportError:
            missing_packages.append(package)

    if missing_packages:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    return get_relative_imports(filename)
156

157
158
159
160
161

def get_class_in_module(class_name, module_path):
    """
    Import a module on the cache directory for modules and extract a class from it.

    Args:
        class_name (`str`): Name of the attribute (class) to fetch from the imported module.
        module_path (`str`): Module path using `os.path.sep` separators (converted to a dotted name here).

    Returns:
        The attribute named `class_name` of the imported module.
    """
    dotted_name = module_path.replace(os.path.sep, ".")
    return getattr(importlib.import_module(dotted_name), class_name)
165
166


167
def get_cached_module_file(
    pretrained_model_name_or_path: Union[str, os.PathLike],
    module_file: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    _commit_hash: Optional[str] = None,
):
    """
    Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
    Transformers module.

    Args:
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
              under a user or organization name, like `dbmdz/bert-base-german-cased`.
            - a path to a *directory* containing a configuration file saved using the
              [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.

        module_file (`str`):
            The name of the module file containing the class to look for.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).
        _commit_hash (`str`, *optional*):
            Private argument: commit hash already resolved by the caller, forwarded to the cache lookups and to the
            recursive calls for relative imports.

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `str`: The path to the module inside the cache.
    """
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    is_local = os.path.isdir(pretrained_model_name_or_path)
    if is_local:
        submodule = pretrained_model_name_or_path.split(os.path.sep)[-1]
    else:
        submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
        # Remember what is currently in the local cache, to detect below whether a new version was downloaded.
        cached_module = try_to_load_from_cache(
            pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
        )

    new_files = []
    try:
        # Load from URL or cache if already cached
        resolved_module_file = cached_file(
            pretrained_model_name_or_path,
            module_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            repo_type=repo_type,
            _commit_hash=_commit_hash,
        )
        # A path differing from the pre-existing cache entry means a fresh version was just downloaded.
        if not is_local and cached_module != resolved_module_file:
            new_files.append(module_file)

    except EnvironmentError:
        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
        raise

    # Check we have all the requirements in our environment
    modules_needed = check_imports(resolved_module_file)

    # Now we move the module inside our cached dynamic modules.
    full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
    create_dynamic_module(full_submodule)
    submodule_path = Path(HF_MODULES_CACHE) / full_submodule
    # This condition holds exactly when `pretrained_model_name_or_path` was a local directory (see `submodule` above).
    if submodule == pretrained_model_name_or_path.split(os.path.sep)[-1]:
        # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
        # has changed since last copy.
        if not (submodule_path / module_file).exists() or not filecmp.cmp(
            resolved_module_file, str(submodule_path / module_file)
        ):
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        # Copy every relatively-imported module next to the main file, refreshing any stale copy.
        for module_needed in modules_needed:
            module_needed = f"{module_needed}.py"
            module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
            if not (submodule_path / module_needed).exists() or not filecmp.cmp(
                module_needed_file, str(submodule_path / module_needed)
            ):
                shutil.copy(module_needed_file, submodule_path / module_needed)
                importlib.invalidate_caches()
    else:
        # Get the commit hash
        commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)

        # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
        # benefit of versioning.
        submodule_path = submodule_path / commit_hash
        full_submodule = full_submodule + os.path.sep + commit_hash
        create_dynamic_module(full_submodule)

        if not (submodule_path / module_file).exists():
            shutil.copy(resolved_module_file, submodule_path / module_file)
            importlib.invalidate_caches()
        # Make sure we also have every file with relative imports, by recursing on each of them.
        for module_needed in modules_needed:
            if not (submodule_path / f"{module_needed}.py").exists():
                get_cached_module_file(
                    pretrained_model_name_or_path,
                    f"{module_needed}.py",
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    local_files_only=local_files_only,
                    _commit_hash=commit_hash,
                )
                new_files.append(f"{module_needed}.py")

    if len(new_files) > 0:
        # Warn the user about freshly downloaded code files: executing remote code is a security-sensitive action.
        new_files = "\n".join([f"- {f}" for f in new_files])
        repo_type_str = "" if repo_type is None else f"{repo_type}/"
        url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
        logger.warning(
            f"A new version of the following files was downloaded from {url}:\n{new_files}"
            "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
            "versions of the code file, you can pin a revision."
        )

    return os.path.join(full_submodule, module_file)


def get_class_from_dynamic_module(
    class_reference: str,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
    repo_type: Optional[str] = None,
    **kwargs,
):
    """
    Extracts a class from a module file, present in the local folder or repository of a model.

    <Tip warning={true}>

    Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
    therefore only be called on trusted repos.

    </Tip>

    Args:
        class_reference (`str`):
            The full name of the class to load, as `module.ClassName`, optionally prefixed with the repo it lives in,
            as `repo_id--module.ClassName`.
        pretrained_model_name_or_path (`str` or `os.PathLike`):
            Either the *model id* of a model repo on huggingface.co (root-level like `bert-base-uncased`, or
            namespaced like `dbmdz/bert-base-german-cased`) or the path to a local *directory*. Used only when
            `class_reference` does not specify another repo.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or `bool`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`). Passing `use_auth_token=True` is
            required when you want to use a private model.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the tokenizer configuration from local files.
        repo_type (`str`, *optional*):
            Specify the repo type (useful when downloading from a space for instance).

    Returns:
        `type`: The class, dynamically imported from the module.

    Examples:

    ```python
    # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")

    # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
    # module.
    cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
    ```"""
    # By default the class lives in `pretrained_model_name_or_path`, unless `class_reference` carries its own
    # repo id in the form `repo_id--module.ClassName`.
    repo_id = pretrained_model_name_or_path
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
        # The passed-in revision applied to `pretrained_model_name_or_path`, not to this explicit repo.
        revision = "main"
    module_file, class_name = class_reference.split(".")

    # Materialize the module file (and its relative imports) inside the dynamic-module cache...
    cached_module_path = get_cached_module_file(
        repo_id,
        module_file + ".py",
        cache_dir=cache_dir,
        force_download=force_download,
        resume_download=resume_download,
        proxies=proxies,
        use_auth_token=use_auth_token,
        revision=revision,
        local_files_only=local_files_only,
        repo_type=repo_type,
    )
    # ...then import it and pull the requested class out of it.
    return get_class_in_module(class_name, cached_module_path.replace(".py", ""))
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456


def custom_object_save(obj, folder, config=None):
    """
    Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
    adds the proper fields in a config.

    Args:
        obj (`Any`): The object for which to save the module files.
        folder (`str` or `os.PathLike`): The folder where to save.
        config (`PretrainedConfig` or dictionary, `optional`):
            A config in which to register the auto_map corresponding to this custom object.

    Returns:
        `List[Path]`: The files copied into `folder` (or `None` when `obj` was defined in `__main__`).
    """
    if obj.__module__ == "__main__":
        # Code defined in __main__ has no importable module file we could copy to the save folder.
        logger.warning(
            f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
            "this code in a separate module so we can include it in the saved folder and make it easier to share via "
            "the Hub."
        )
        return

    def _set_auto_map_in_config(_config):
        # Build the `module.ClassName` reference under which this custom object is registered.
        last_module = obj.__class__.__module__.split(".")[-1]
        full_name = f"{last_module}.{obj.__class__.__name__}"
        if "Tokenizer" in full_name:
            # Tokenizers are registered as a (slow_class, fast_class) pair instead of a single name.
            slow_tokenizer_class = None
            fast_tokenizer_class = None
            if obj.__class__.__name__.endswith("Fast"):
                # Fast tokenizer: the slow counterpart may be available as an attribute.
                fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
                slow_tokenizer = getattr(obj, "slow_tokenizer_class", None)
                if slow_tokenizer is not None:
                    last_slow_tok_module = slow_tokenizer.__module__.split(".")[-1]
                    slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
            else:
                # Slow tokenizer: there is no way to recover the fast class from here.
                slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"

            full_name = (slow_tokenizer_class, fast_tokenizer_class)

        # Register under `auto_map`, supporting both plain dicts and config objects.
        if isinstance(_config, dict):
            auto_map = _config.get("auto_map", {})
            auto_map[obj._auto_class] = full_name
            _config["auto_map"] = auto_map
        elif getattr(_config, "auto_map", None) is not None:
            _config.auto_map[obj._auto_class] = full_name
        else:
            _config.auto_map = {obj._auto_class: full_name}

    # Add object class to the config auto_map (one or several configs may be passed).
    if isinstance(config, (list, tuple)):
        for cfg in config:
            _set_auto_map_in_config(cfg)
    elif config is not None:
        _set_auto_map_in_config(config)

    # Copy the module file defining the object, then every file it relatively imports (recursively).
    object_file = sys.modules[obj.__module__].__file__
    result = []
    for source_file in [object_file, *get_relative_import_files(object_file)]:
        dest_file = Path(folder) / Path(source_file).name
        shutil.copy(source_file, dest_file)
        result.append(dest_file)

    return result