pipeline_utils.py 95.2 KB
Newer Older
1
# coding=utf-8
2
# Copyright 2024 The HuggingFace Inc. team.
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
import fnmatch
17
18
19
import importlib
import inspect
import os
20
import re
21
import sys
22
23
from dataclasses import dataclass
from pathlib import Path
24
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
25
26

import numpy as np
Anh71me's avatar
Anh71me committed
27
import PIL.Image
28
import requests
29
import torch
30
31
32
33
34
35
36
from huggingface_hub import (
    ModelCard,
    create_repo,
    hf_hub_download,
    model_info,
    snapshot_download,
)
37
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
38
from packaging import version
39
from requests.exceptions import HTTPError
40
41
from tqdm.auto import tqdm

42
from .. import __version__
43
from ..configuration_utils import ConfigMixin
44
45
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
46
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
47
48
49
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
    CONFIG_NAME,
50
    DEPRECATED_REVISION_ARGS,
51
    BaseOutput,
52
    PushToHubMixin,
53
54
    deprecate,
    is_accelerate_available,
55
    is_accelerate_version,
Mengqing Cao's avatar
Mengqing Cao committed
56
    is_torch_npu_available,
57
58
    is_torch_version,
    logging,
Patrick von Platen's avatar
Patrick von Platen committed
59
    numpy_to_pil,
60
)
61
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
Dhruv Nair's avatar
Dhruv Nair committed
62
from ..utils.torch_utils import is_compiled_module
Mengqing Cao's avatar
Mengqing Cao committed
63
64
65
66
67
68


if is_torch_npu_available():
    import torch_npu  # noqa: F401


69
70
71
72
73
74
from .pipeline_loading_utils import (
    ALL_IMPORTABLE_CLASSES,
    CONNECTED_PIPES_KEYS,
    CUSTOM_PIPELINE_FILE_NAME,
    LOADABLE_CLASSES,
    _fetch_class_library_tuple,
75
    _get_custom_pipeline_class,
76
    _get_final_device_map,
77
78
79
80
81
82
83
84
    _get_pipeline_class,
    _unwrap_model,
    is_safetensors_compatible,
    load_sub_model,
    maybe_raise_or_warn,
    variant_compatible_siblings,
    warn_deprecated_model_variant,
)
85
86


87
88
89
90
if is_accelerate_available():
    import accelerate


91
92
93
# Keys of `LOADABLE_CLASSES`, i.e. the library names that pipeline components are looked up in.
LIBRARIES = list(LOADABLE_CLASSES)

# Values accepted for a pipeline-level device map — NOTE(review): inferred from the name, confirm against the loader.
SUPPORTED_DEVICE_MAP = ["balanced"]

logger = logging.get_logger(__name__)


@dataclass
class ImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]


@dataclass
class AudioPipelineOutput(BaseOutput):
    """
    Output class for audio pipelines.

    Args:
        audios (`np.ndarray`)
            Denoised audio samples as a NumPy array of shape `(batch_size, num_channels, sample_rate)`.
    """

    audios: np.ndarray


127
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
128
    r"""
Steven Liu's avatar
Steven Liu committed
129
    Base class for all pipelines.
130

Steven Liu's avatar
Steven Liu committed
131
132
    [`DiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and
    provides methods for loading, downloading and saving models. It also includes methods to:
133
134

        - move all PyTorch modules to the device of your choice
135
        - enable/disable the progress bar for the denoising iteration
136
137
138

    Class attributes:

Steven Liu's avatar
Steven Liu committed
139
140
        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
          diffusion pipeline's components.
141
        - **_optional_components** (`List[str]`) -- List of all optional components that don't have to be passed to the
Steven Liu's avatar
Steven Liu committed
142
          pipeline to function (should be overridden by subclasses).
143
    """
144

145
    config_name = "model_index.json"
146
    model_cpu_offload_seq = None
147
    hf_device_map = None
148
    _optional_components = []
149
    _exclude_from_cpu_offload = []
150
    _load_connected_pipes = False
151
    _is_onnx = False
152
153
154
155

    def register_modules(self, **kwargs):
        """
        Record every keyword argument as a pipeline component.

        For each `name=module` pair, a `(library, class_name)` entry is written to the pipeline config and the module
        is then set as an attribute on the pipeline. A module that is `None`, or a tuple/list whose first element is
        `None`, is recorded as `(None, None)`.
        """
        for component_name, component in kwargs.items():
            is_placeholder = component is None or (
                isinstance(component, (tuple, list)) and component[0] is None
            )

            if is_placeholder:
                config_entry = (None, None)
            else:
                library_name, class_name = _fetch_class_library_tuple(component)
                config_entry = (library_name, class_name)

            # The config entry is written first, then the attribute itself is assigned.
            self.register_to_config(**{component_name: config_entry})
            setattr(self, component_name, component)

168
    def __setattr__(self, name: str, value: Any):
        """
        Assign an attribute and keep the pipeline config in sync with it.

        When `name` is already an instance attribute and also present in `self.config`, the config entry is
        refreshed before the attribute is set: tuple/list entries are replaced by the new value's
        `(library, class_name)` pair (or `(None, None)` if the new value or the current entry is empty), any other
        entry is replaced by `value` itself.
        """
        is_tracked = name in self.__dict__ and hasattr(self.config, name)

        if is_tracked:
            current_entry = getattr(self.config, name)

            if not isinstance(current_entry, (tuple, list)):
                updated_entry = value
            elif value is not None and self.config[name][0] is not None:
                updated_entry = _fetch_class_library_tuple(value)
            else:
                updated_entry = (None, None)

            self.register_to_config(**{name: updated_entry})

        super().__setattr__(name, value)

183
184
185
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
        class implements both a save and loading method. The pipeline is easily reloaded using the
        [`~DiffusionPipeline.from_pretrained`] class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save a pipeline to. Will be created if it doesn't exist.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
            variant (`str`, *optional*):
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        model_index_dict = dict(self.config)
        # These config entries are bookkeeping, not pipeline components.
        for private_key in ("_class_name", "_diffusers_version", "_module", "_name_or_path"):
            model_index_dict.pop(private_key, None)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", False)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # `os.fspath` makes the default work for `pathlib.Path` (any `os.PathLike`) as well as for `str`.
            default_repo_id = os.fspath(save_directory).split(os.path.sep)[-1]
            repo_id = kwargs.pop("repo_id", default_repo_id)
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        expected_modules, _ = self._get_signature_keys(self)

        def is_saveable_module(name, value):
            if name not in expected_modules:
                return False
            if name in self._optional_components and value[0] is None:
                return False
            return True

        model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
        for pipeline_component_name in model_index_dict:
            sub_model = getattr(self, pipeline_component_name)
            model_cls = sub_model.__class__

            # Dynamo wraps the original model in a private class.
            # I didn't find a public API to get the original class.
            if is_compiled_module(sub_model):
                sub_model = _unwrap_model(sub_model)
                model_cls = sub_model.__class__

            save_method_name = None
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
                if library_name not in sys.modules:
                    logger.info(
                        f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
                    )
                    # Without this, the lookup below would run against a stale (or unbound) `library`.
                    continue

                library = importlib.import_module(library_name)
                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class, None)
                    if class_candidate is not None and issubclass(model_cls, class_candidate):
                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                        save_method_name = save_load_methods[0]
                        break
                if save_method_name is not None:
                    break

            if save_method_name is None:
                logger.warning(
                    f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved."
                )
                # make sure that unsaveable components are not tried to be loaded afterward
                self.register_to_config(**{pipeline_component_name: (None, None)})
                continue

            save_method = getattr(sub_model, save_method_name)

            # Pass `safe_serialization` / `variant` only to save methods whose signature accepts them.
            save_method_signature = inspect.signature(save_method)
            save_kwargs = {}
            if "safe_serialization" in save_method_signature.parameters:
                save_kwargs["safe_serialization"] = safe_serialization
            if "variant" in save_method_signature.parameters:
                save_kwargs["variant"] = variant

            save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

        # finally save the config
        self.save_config(save_directory)

        if push_to_hub:
            # Create a new empty model card and eventually tag it
            model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
            model_card = populate_model_card(model_card)
            model_card.save(os.path.join(save_directory, "README.md"))

            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
    def to(self, *args, **kwargs):
        r"""
        Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
        arguments of `self.to(*args, **kwargs).`

        <Tip>

            The conversion is applied in place to the pipeline's `torch.nn.Module` components and the pipeline itself
            (`self`) is returned.

        </Tip>


        Here are the ways to call `to`:

        - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
          [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
        - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
          [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
        - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
          specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
          [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)

        Arguments:
            dtype (`torch.dtype`, *optional*):
                Returns a pipeline with the specified
                [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
            device (`torch.device`, *optional*):
                Returns a pipeline with the specified
                [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
            silence_dtype_warnings (`bool`, *optional*, defaults to `False`):
                Whether to omit warnings if the target `dtype` is not compatible with the target `device`.

        Returns:
            [`DiffusionPipeline`]: The pipeline converted to specified `device` and/or `dtype`.

        Raises:
            `ValueError`: If more than two positional arguments are passed, if two positional arguments are passed
                with a `torch.dtype` first, if `dtype` or `device` is given both positionally and as a keyword, if a
                sequentially offloaded pipeline is moved to a CUDA device, or if the pipeline has a device map with
                more than one entry.
        """
        dtype = kwargs.pop("dtype", None)
        device = kwargs.pop("device", None)
        silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)

        # Positional form is either `(dtype,)`, `(device,)` or `(device, dtype)`.
        dtype_arg = None
        device_arg = None
        if len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype_arg = args[0]
            else:
                device_arg = torch.device(args[0]) if args[0] is not None else None
        elif len(args) == 2:
            if isinstance(args[0], torch.dtype):
                raise ValueError(
                    "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
                )
            device_arg = torch.device(args[0]) if args[0] is not None else None
            dtype_arg = args[1]
        elif len(args) > 2:
            raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")

        if dtype is not None and dtype_arg is not None:
            raise ValueError(
                "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
            )

        dtype = dtype or dtype_arg

        if device is not None and device_arg is not None:
            raise ValueError(
                "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
            )

        device = device or device_arg

        # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
        def module_is_sequentially_offloaded(module):
            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                return False

            return hasattr(module, "_hf_hook") and (
                isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
                or (
                    hasattr(module._hf_hook, "hooks")
                    and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
                )
            )

        def module_is_offloaded(module):
            if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
                return False

            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
        pipeline_is_sequentially_offloaded = any(
            module_is_sequentially_offloaded(module) for module in self.components.values()
        )
        if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
            raise ValueError(
                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
            )

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
            )

        # Display a warning in this case (the operation succeeds but the benefits are lost)
        pipeline_is_offloaded = any(module_is_offloaded(module) for module in self.components.values())
        if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
            logger.warning(
                f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
            )

        # Only `torch.nn.Module` components are converted; everything else is left untouched.
        module_names, _ = self._get_signature_keys(self)
        modules = [getattr(self, n, None) for n in module_names]
        modules = [m for m in modules if isinstance(m, torch.nn.Module)]

        is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
        for module in modules:
            is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit

            if is_loaded_in_8bit and dtype is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
                )

            if is_loaded_in_8bit and device is not None:
                # Report the requested *device* here; the message previously interpolated `dtype` by mistake.
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {device} via `.to()` is not yet supported. Module is still on {module.device}."
                )
            else:
                module.to(device, dtype)

            if (
                module.dtype == torch.float16
                and str(device) in ["cpu"]
                and not silence_dtype_warnings
                and not is_offloaded
            ):
                logger.warning(
                    "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
                    " is not recommended to move them to `cpu` as running them will fail. Please make"
                    " sure to use an accelerator to run the pipeline in inference, due to the lack of"
                    " support for`float16` operations on this device in PyTorch. Please, remove the"
                    " `torch_dtype=torch.float16` argument, or use another device for inference."
                )
        return self

    @property
    def device(self) -> torch.device:
        r"""
        Returns:
            `torch.device`: The device of the first `torch.nn.Module` component of the pipeline, or the CPU device
            when the pipeline holds no such component.
        """
        component_names, _ = self._get_signature_keys(self)
        components = [getattr(self, component_name, None) for component_name in component_names]
        torch_modules = [component for component in components if isinstance(component, torch.nn.Module)]

        if torch_modules:
            return torch_modules[0].device

        return torch.device("cpu")

463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
    @property
    def dtype(self) -> torch.dtype:
        r"""
        Returns:
            `torch.dtype`: The dtype of the first `torch.nn.Module` component of the pipeline, or `torch.float32`
            when the pipeline holds no such component.
        """
        component_names, _ = self._get_signature_keys(self)
        components = [getattr(self, component_name, None) for component_name in component_names]
        torch_modules = [component for component in components if isinstance(component, torch.nn.Module)]

        if torch_modules:
            return torch_modules[0].dtype

        return torch.float32

478
    @classmethod
479
    @validate_hf_hub_args
480
481
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
Steven Liu's avatar
Steven Liu committed
482
        Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.
483

Steven Liu's avatar
Steven Liu committed
484
        The pipeline is set in evaluation mode (`model.eval()`) by default.
485

Steven Liu's avatar
Steven Liu committed
486
        If you get the error message below, you need to finetune the weights for your downstream task:
487

Steven Liu's avatar
Steven Liu committed
488
489
490
491
492
        ```
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
493
494
495
496
497

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

Steven Liu's avatar
Steven Liu committed
498
499
500
501
502
                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                      saved using
                    [`~DiffusionPipeline.save_pretrained`].
503
            torch_dtype (`str` or `torch.dtype`, *optional*):
Steven Liu's avatar
Steven Liu committed
504
505
                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                dtype is automatically derived from the model's weights.
506
507
508
509
            custom_pipeline (`str`, *optional*):

                <Tip warning={true}>

Steven Liu's avatar
Steven Liu committed
510
                🧪 This is an experimental feature and may change in the future.
511
512
513
514
515

                </Tip>

                Can be either:

Steven Liu's avatar
Steven Liu committed
516
517
518
                    - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom
                      pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines
                      the custom pipeline.
519
                    - A string, the *file name* of a community pipeline hosted on GitHub under
Steven Liu's avatar
Steven Liu committed
520
521
522
523
524
525
                      [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
                      names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
                      instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
                      current main branch of GitHub.
                    - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. The directory
                      must contain a file called `pipeline.py` that defines the custom pipeline.
526
527
528
529
530
531
532

                For more information on how to load and create custom pipelines, please have a look at [Loading and
                Adding Custom
                Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
533
            cache_dir (`Union[str, os.PathLike]`, *optional*):
Steven Liu's avatar
Steven Liu committed
534
535
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
536

537
            proxies (`Dict[str, str]`, *optional*):
Steven Liu's avatar
Steven Liu committed
538
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
539
540
541
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
Steven Liu's avatar
Steven Liu committed
542
543
544
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
545
            token (`str` or *bool*, *optional*):
Steven Liu's avatar
Steven Liu committed
546
547
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
548
            revision (`str`, *optional*, defaults to `"main"`):
Steven Liu's avatar
Steven Liu committed
549
550
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
551
            custom_revision (`str`, *optional*):
552
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
553
554
                `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
                version.
555
            mirror (`str`, *optional*):
Steven Liu's avatar
Steven Liu committed
556
557
558
                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
559
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
Steven Liu's avatar
Steven Liu committed
560
561
                A map that specifies where each submodule should go. It doesn’t need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
562
563
                same device.

Steven Liu's avatar
Steven Liu committed
564
                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
565
566
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
567
            max_memory (`Dict`, *optional*):
Steven Liu's avatar
Steven Liu committed
568
569
                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                each GPU and the available CPU RAM if unset.
570
            offload_folder (`str` or `os.PathLike`, *optional*):
Steven Liu's avatar
Steven Liu committed
571
                The path to offload weights if device_map contains the value `"disk"`.
572
            offload_state_dict (`bool`, *optional*):
Steven Liu's avatar
Steven Liu committed
573
574
575
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
                when there is some disk offload.
576
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Steven Liu's avatar
Steven Liu committed
577
578
579
580
                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
581
            use_safetensors (`bool`, *optional*, defaults to `None`):
Steven Liu's avatar
Steven Liu committed
582
583
584
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
585
586
587
588
589
            use_onnx (`bool`, *optional*, defaults to `None`):
                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
                `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
                with `.onnx` and `.pb`.
590
            kwargs (remaining dictionary of keyword arguments, *optional*):
Steven Liu's avatar
Steven Liu committed
591
592
593
                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                class). The overwritten components are passed directly to the pipelines `__init__` method. See example
                below for more information.
594
            variant (`str`, *optional*):
Steven Liu's avatar
Steven Liu committed
595
596
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
597
598
599

        <Tip>

Steven Liu's avatar
Steven Liu committed
600
601
        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
        `huggingface-cli login`.
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624

        </Tip>

        Examples:

        ```py
        >>> from diffusers import DiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        >>> # Download pipeline that requires an authorization token
        >>> # For more information on access tokens, please refer to this section
        >>> # of the [documentation](https://huggingface.co/docs/hub/security-tokens)
        >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

        >>> # Use a different scheduler
        >>> from diffusers import LMSDiscreteScheduler

        >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
        >>> pipeline.scheduler = scheduler
        ```
        """
625
        cache_dir = kwargs.pop("cache_dir", None)
626
627
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
628
629
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
630
        revision = kwargs.pop("revision", None)
631
        from_flax = kwargs.pop("from_flax", False)
632
633
634
635
636
637
        torch_dtype = kwargs.pop("torch_dtype", None)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        provider = kwargs.pop("provider", None)
        sess_options = kwargs.pop("sess_options", None)
        device_map = kwargs.pop("device_map", None)
638
639
640
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
641
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
642
        variant = kwargs.pop("variant", None)
643
        use_safetensors = kwargs.pop("use_safetensors", None)
644
        use_onnx = kwargs.pop("use_onnx", None)
645
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
646

647
648
649
650
651
652
653
654
655
        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

656
657
658
659
660
661
        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

662
663
664
665
666
667
        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

668
        if device_map is not None and not is_accelerate_available():
669
            raise NotImplementedError(
670
671
672
673
674
675
676
677
678
                "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
            )

        if device_map is not None and not isinstance(device_map, str):
            raise ValueError("`device_map` must be a string.")

        if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
            raise NotImplementedError(
                f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
679
680
            )

681
682
683
684
        if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
            if is_accelerate_version("<", "0.28.0"):
                raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")

685
686
687
688
689
690
        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

691
692
693
        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
Patrick von Platen's avatar
Patrick von Platen committed
694
695
696
697
698
            if pretrained_model_name_or_path.count("/") > 1:
                raise ValueError(
                    f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
                    " is neither a valid local path nor a valid repo id. Please check the parameter."
                )
699
            cached_folder = cls.download(
700
701
702
703
704
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
705
                token=token,
706
                revision=revision,
707
                from_flax=from_flax,
708
                use_safetensors=use_safetensors,
709
                use_onnx=use_onnx,
710
                custom_pipeline=custom_pipeline,
711
                custom_revision=custom_revision,
712
                variant=variant,
713
                load_connected_pipeline=load_connected_pipeline,
714
                **kwargs,
715
716
717
718
            )
        else:
            cached_folder = pretrained_model_name_or_path

719
720
        config_dict = cls.load_config(cached_folder)

Patrick von Platen's avatar
Patrick von Platen committed
721
722
723
        # pop out "_ignore_files" as it is only needed for download
        config_dict.pop("_ignore_files", None)

724
725
726
        # 2. Define which model components should load variants
        # We retrieve the information by matching whether variant
        # model checkpoints exist in the subfolders
727
728
729
730
731
        model_variants = {}
        if variant is not None:
            for folder in os.listdir(cached_folder):
                folder_path = os.path.join(cached_folder, folder)
                is_folder = os.path.isdir(folder_path) and folder in config_dict
732
733
734
                variant_exists = is_folder and any(
                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
                )
735
736
737
                if variant_exists:
                    model_variants[folder] = variant

738
        # 3. Load the pipeline class, if using custom module then load it from the hub
739
        # if we load from explicit class, let's use it
740
741
742
743
744
745
746
747
748
        custom_class_name = None
        if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
            custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
        elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
            os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
        ):
            custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
            custom_class_name = config_dict["_class_name"][1]

749
        pipeline_class = _get_pipeline_class(
750
751
752
753
            cls,
            config_dict,
            load_connected_pipeline=load_connected_pipeline,
            custom_pipeline=custom_pipeline,
754
            class_name=custom_class_name,
755
756
            cache_dir=cache_dir,
            revision=custom_revision,
757
        )
758

759
760
761
        if device_map is not None and pipeline_class._load_connected_pipes:
            raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")

762
        # DEPRECATED: To be removed in 1.0.0
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
            version.parse(config_dict["_diffusers_version"]).base_version
        ) <= version.parse("0.5.1"):
            from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

            pipeline_class = StableDiffusionInpaintPipelineLegacy

            deprecation_message = (
                "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
                f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
                " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
                " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
                f" checkpoint {pretrained_model_name_or_path} to the format of"
                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
                " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
            )
            deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)

781
782
783
        # 4. Define expected modules given pipeline signature
        # and define non-None initialized modules (=`init_kwargs`)

784
785
786
787
788
789
790
791
792
        # some modules can be passed directly to the init
        # in this case they are already instantiated in `kwargs`
        # extract them here
        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

793
794
795
796
797
798
        # define init kwargs and make sure that optional component modules are filtered out
        init_kwargs = {
            k: init_dict.pop(k)
            for k in optional_kwargs
            if k in init_dict and k not in pipeline_class._optional_components
        }
799
800
801
802
803
804
805
806
807
808
809
810
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

        # remove `null` components
        def load_module(name, value):
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            return True

        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

811
812
813
814
815
816
817
818
        # Special case: safety_checker must be loaded separately when using `from_flax`
        if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
            raise NotImplementedError(
                "The safety checker cannot be automatically loaded when loading weights `from_flax`."
                " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
                " separately if you need it."
            )

819
        # 5. Throw nice warnings / errors for fast accelerate loading
820
821
822
823
824
825
826
827
        if len(unused_kwargs) > 0:
            logger.warning(
                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
            )

        # import it here to avoid circular import
        from diffusers import pipelines

828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
        # 6. device map delegation
        final_device_map = None
        if device_map is not None:
            final_device_map = _get_final_device_map(
                device_map=device_map,
                pipeline_class=pipeline_class,
                passed_class_obj=passed_class_obj,
                init_dict=init_dict,
                library=library,
                max_memory=max_memory,
                torch_dtype=torch_dtype,
                cached_folder=cached_folder,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
            )

        # 7. Load each module in the pipeline
        current_device_map = None
849
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
850
851
852
853
854
855
856
857
            if final_device_map is not None and len(final_device_map) > 0:
                component_device = final_device_map.get(name, None)
                if component_device is not None:
                    current_device_map = {"": component_device}
                else:
                    current_device_map = None

            # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
858
            class_name = class_name[4:] if class_name.startswith("Flax") else class_name
859

860
            # 7.2 Define all importable classes
861
            is_pipeline_module = hasattr(pipelines, library_name)
862
            importable_classes = ALL_IMPORTABLE_CLASSES
863
864
            loaded_sub_model = None

865
            # 7.3 Use passed sub model or load class_name from library_name
866
            if name in passed_class_obj:
867
868
869
870
871
                # if the model is in a pipeline module, then we load it from the pipeline
                # check that passed_class_obj has correct parent class
                maybe_raise_or_warn(
                    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
                )
872
873
874

                loaded_sub_model = passed_class_obj[name]
            else:
875
876
877
878
879
880
881
882
883
884
885
                # load sub model
                loaded_sub_model = load_sub_model(
                    library_name=library_name,
                    class_name=class_name,
                    importable_classes=importable_classes,
                    pipelines=pipelines,
                    is_pipeline_module=is_pipeline_module,
                    pipeline_class=pipeline_class,
                    torch_dtype=torch_dtype,
                    provider=provider,
                    sess_options=sess_options,
886
                    device_map=current_device_map,
887
888
889
                    max_memory=max_memory,
                    offload_folder=offload_folder,
                    offload_state_dict=offload_state_dict,
890
891
892
893
894
895
                    model_variants=model_variants,
                    name=name,
                    from_flax=from_flax,
                    variant=variant,
                    low_cpu_mem_usage=low_cpu_mem_usage,
                    cached_folder=cached_folder,
896
                )
897
898
899
                logger.info(
                    f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
                )
900
901
902

            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

903
904
905
906
907
908
909
910
        if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
            modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
            connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
            load_kwargs = {
                "cache_dir": cache_dir,
                "force_download": force_download,
                "proxies": proxies,
                "local_files_only": local_files_only,
911
                "token": token,
912
913
914
915
916
917
918
919
920
921
922
923
924
925
                "revision": revision,
                "torch_dtype": torch_dtype,
                "custom_pipeline": custom_pipeline,
                "custom_revision": custom_revision,
                "provider": provider,
                "sess_options": sess_options,
                "device_map": device_map,
                "max_memory": max_memory,
                "offload_folder": offload_folder,
                "offload_state_dict": offload_state_dict,
                "low_cpu_mem_usage": low_cpu_mem_usage,
                "variant": variant,
                "use_safetensors": use_safetensors,
            }
926
927
928
929
930
931
932
933
934
935
936
937

            def get_connected_passed_kwargs(prefix):
                connected_passed_class_obj = {
                    k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
                }
                connected_passed_pipe_kwargs = {
                    k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
                }

                connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
                return connected_passed_kwargs

938
            connected_pipes = {
939
940
941
                prefix: DiffusionPipeline.from_pretrained(
                    repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
                )
942
943
944
945
946
947
948
949
950
951
                for prefix, repo_id in connected_pipes.items()
                if repo_id is not None
            }

            for prefix, connected_pipe in connected_pipes.items():
                # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
                init_kwargs.update(
                    {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
                )

952
        # 8. Potentially add passed objects if expected
953
954
955
956
957
958
959
960
961
962
963
964
        missing_modules = set(expected_modules) - set(init_kwargs.keys())
        passed_modules = list(passed_class_obj.keys())
        optional_modules = pipeline_class._optional_components
        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
            for module in missing_modules:
                init_kwargs[module] = passed_class_obj.get(module, None)
        elif len(missing_modules) > 0:
            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
            raise ValueError(
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
            )

965
        # 10. Instantiate the pipeline
966
        model = pipeline_class(**init_kwargs)
967

968
        # 11. Save where the model was instantiated from
969
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
970
971
        if device_map is not None:
            setattr(model, "hf_device_map", final_device_map)
972
973
        return model

974
975
976
977
    @property
    def name_or_path(self) -> Optional[str]:
        r"""
        Returns the `_name_or_path` entry of the pipeline config.

        Returns:
            `str` or `None`: The value stored under `_name_or_path` in `self.config`, or `None` when the config has
            no such attribute.
        """
        # The `None` default makes a missing entry a normal return value rather than an `AttributeError`,
        # which is why the annotation is `Optional[str]`.
        return getattr(self.config, "_name_or_path", None)

978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. Once
        [`~DiffusionPipeline.enable_sequential_cpu_offload`] has been called, that device can only be read back from
        the hooks Accelerate attached to the modules, so those hooks are inspected first and `self.device` is the
        fallback.
        """
        for component_name, component in self.components.items():
            is_candidate = (
                isinstance(component, torch.nn.Module) and component_name not in self._exclude_from_cpu_offload
            )
            if not is_candidate:
                continue

            if not hasattr(component, "_hf_hook"):
                # The first eligible model carries no hook: nothing to infer, use the pipeline device.
                return self.device

            for submodule in component.modules():
                attached_hook = getattr(submodule, "_hf_hook", None)
                hook_device = getattr(attached_hook, "execution_device", None)
                if hook_device is not None:
                    return torch.device(hook_device)

        return self.device

1000
1001
1002
1003
1004
1005
    def remove_all_hooks(self):
        r"""
        Strips the hooks installed by `enable_sequential_cpu_offload` or `enable_model_cpu_offload` from every
        `torch.nn.Module` component that carries one, then empties the pipeline's list of tracked hooks.
        """
        hooked_components = (
            component
            for component in self.components.values()
            if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook")
        )
        for hooked_component in hooked_components:
            accelerate.hooks.remove_hook_from_module(hooked_component, recurse=True)

        self._all_hooks = []

1009
    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Arguments:
            gpu_id (`int`, *optional*):
                The ID of the accelerator that shall be used in inference. If not specified, it will default to the
                index contained in `device`, then to the ID used by a previous offload call, then to 0.
            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                default to "cuda".

        Raises:
            `ValueError`: If the pipeline has a device map with more than one entry, if no `model_cpu_offload_seq`
                class attribute is set, or if both `gpu_id` and an index inside `device` are passed.
            `ImportError`: If `accelerate` is not available or older than `0.17.0`.
        """
        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
            )

        if self.model_cpu_offload_seq is None:
            raise ValueError(
                "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
            )

        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        # Parse and validate the arguments before touching any hook, so that a rejected call leaves the
        # pipeline exactly as it was.
        torch_device = torch.device(device)
        device_index = torch_device.index

        if gpu_id is not None and device_index is not None:
            raise ValueError(
                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}. "
                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
            )

        self.remove_all_hooks()

        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id
        # or default to 0. The checks are against `None` on purpose: index 0 is a valid explicit choice and must
        # not fall through to a previously stored id.
        if gpu_id is not None:
            self._offload_gpu_id = gpu_id
        elif device_index is not None:
            self._offload_gpu_id = device_index
        else:
            self._offload_gpu_id = getattr(self, "_offload_gpu_id", 0)

        device_type = torch_device.type
        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
        self._offload_device = device

        self.to("cpu", silence_dtype_warnings=True)
        device_mod = getattr(torch, device.type, None)
        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
            device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}

        self._all_hooks = []
        hook = None
        for model_str in self.model_cpu_offload_seq.split("->"):
            # Names in the sequence that are not `torch.nn.Module` components of this pipeline are skipped.
            model = all_model_components.pop(model_str, None)
            if model is None:
                continue

            # Chain each hook to the previous one via `prev_module_hook`.
            _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
            self._all_hooks.append(hook)

        # CPU offload models that are not in the seq chain unless they are explicitly excluded
        # these models will stay on CPU until maybe_free_model_hooks is called
        # some models cannot be in the seq chain because they are iteratively called, such as controlnet
        for name, model in all_model_components.items():
            if name in self._exclude_from_cpu_offload:
                model.to(device)
            else:
                _, hook = cpu_offload_with_hook(model, device)
                self._all_hooks.append(hook)

    def maybe_free_model_hooks(self):
        r"""
        Re-applies model CPU offloading so the pipeline ends up in the same state it was in before it ran: calling
        `enable_model_cpu_offload` again offloads all components and re-installs the hooks. When no offload hooks
        are registered this function is a no-op. Call it at the end of your pipeline's `__call__` so that
        `enable_model_cpu_offload` keeps working across calls.
        """
        has_registered_hooks = hasattr(self, "_all_hooks") and len(self._all_hooks) > 0
        if not has_registered_hooks:
            # `enable_model_cpu_offload` has not been called, so silently do nothing
            return

        # make sure the model is in the same state as before calling it
        previous_offload_device = getattr(self, "_offload_device", "cuda")
        self.enable_model_cpu_offload(device=previous_offload_device)
1100

1101
    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
        dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
        and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
        method called. Offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.

        Arguments:
            gpu_id (`int`, *optional*):
                The ID of the accelerator that shall be used in inference. If not specified, it will default to the
                index contained in `device`, then to the ID used by a previous offload call, then to 0.
            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                default to "cuda".

        Raises:
            `ImportError`: If `accelerate` is not available or older than `0.14.0`.
            `ValueError`: If the pipeline has a device map with more than one entry, or if both `gpu_id` and an
                index inside `device` are passed.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        # Validate everything before removing hooks: a rejected call must not leave the pipeline stripped of
        # the hooks it had.
        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
            )

        torch_device = torch.device(device)
        device_index = torch_device.index

        if gpu_id is not None and device_index is not None:
            raise ValueError(
                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}. "
                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
            )

        self.remove_all_hooks()

        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id
        # or default to 0. The checks are against `None` on purpose: index 0 is a valid explicit choice and must
        # not fall through to a previously stored id.
        if gpu_id is not None:
            self._offload_gpu_id = gpu_id
        elif device_index is not None:
            self._offload_gpu_id = device_index
        else:
            self._offload_gpu_id = getattr(self, "_offload_gpu_id", 0)

        device_type = torch_device.type
        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
        self._offload_device = device

        # Remember where the weights currently live *before* moving them: the cache that needs emptying belongs
        # to that device, not to the CPU they are about to be moved to.
        previous_device_type = self.device.type
        if previous_device_type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            device_mod = getattr(torch, previous_device_type, None)
            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for name, model in self.components.items():
            if not isinstance(model, torch.nn.Module):
                continue

            if name in self._exclude_from_cpu_offload:
                model.to(device)
            else:
                # make sure to offload buffers if not all high level weights
                # are of type nn.Module
                offload_buffers = len(model._parameters) > 0
                cpu_offload(model, device, offload_buffers=offload_buffers)

1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
    def reset_device_map(self):
        r"""
        Clears the pipeline's device map, if one is set. All hooks are removed, every `torch.nn.Module` component is
        moved to CPU and `hf_device_map` is set back to `None`. Does nothing when no device map is set.
        """
        if self.hf_device_map is None:
            return

        self.remove_all_hooks()
        torch_components = [
            component for component in self.components.values() if isinstance(component, torch.nn.Module)
        ]
        for torch_component in torch_components:
            torch_component.to("cpu")
        self.hf_device_map = None

1175
    @classmethod
    @validate_hf_hub_args
    def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
        r"""
        Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights.

        Parameters:
            pretrained_model_name (`str` or `os.PathLike`, *optional*):
                A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                hosted on the Hub.
            custom_pipeline (`str`, *optional*):
                Can be either:

                    - A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained
                      pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines
                      the custom pipeline.

                    - A string, the *file name* of a community pipeline hosted on GitHub under
                      [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
                      names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
                      instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
                      current `main` branch of GitHub.

                    - A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory
                      must contain a file called `pipeline.py` that defines the custom pipeline.

                <Tip warning={true}>

                🧪 This is an experimental feature and may change in the future.

                </Tip>

                For more information on how to load and create custom pipelines, take a look at [How to contribute a
                community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline).

            cache_dir (`str` or `os.PathLike`, *optional*):
                Cache directory forwarded to the Hub download calls.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            custom_revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            variant (`str`, *optional*):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            from_flax (`bool`, *optional*, defaults to `False`):
                If set to `True`, files matching `*.bin`, `*.safetensors`, `*.onnx` and `*.pb` are excluded from the
                download.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            use_onnx (`bool`, *optional*, defaults to `False`):
                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
                `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
                with `.onnx` and `.pb`.
            load_connected_pipeline (`bool`, *optional*, defaults to `False`):
                Forwarded to the pipeline-class lookup that decides which pipeline class the download is for.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
                option should only be set to `True` for repositories you trust and in which you have read the code, as
                it will execute code present on the Hub on your local machine.

        Returns:
            `os.PathLike`:
                A path to the downloaded pipeline.

        Raises:
            `ValueError`:
                If a custom component file referenced in `model_index.json` is neither in the repository nor a module
                of `diffusers.pipelines`, or if the repository contains custom code and `trust_remote_code` is not
                set.
            `EnvironmentError`:
                If `use_safetensors=True` was requested but no compatible safetensors weights exist, or if the Hub
                metadata could not be fetched and the pipeline is not available in the local cache.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`.

        </Tip>

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        from_flax = kwargs.pop("from_flax", False)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
        trust_remote_code = kwargs.pop("trust_remote_code", False)

        # `use_safetensors=None` means "prefer safetensors, but fall back to pickled weights"
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        allow_patterns = None
        ignore_patterns = None

        model_info_call_error: Optional[Exception] = None
        if not local_files_only:
            try:
                info = model_info(pretrained_model_name, token=token, revision=revision)
            except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
                logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
                local_files_only = True
                model_info_call_error = e  # save error to reraise it if model is not cached locally

        if not local_files_only:
            config_file = hf_hub_download(
                pretrained_model_name,
                cls.config_name,
                cache_dir=cache_dir,
                revision=revision,
                proxies=proxies,
                force_download=force_download,
                token=token,
            )

            config_dict = cls._dict_from_json_file(config_file)
            ignore_filenames = config_dict.pop("_ignore_files", [])

            # retrieve all folder_names that contain relevant files
            folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

            filenames = {sibling.rfilename for sibling in info.siblings}
            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

            diffusers_module = importlib.import_module(__name__.split(".")[0])
            pipelines = getattr(diffusers_module, "pipelines")

            # optionally create a custom component <> custom file mapping
            custom_components = {}
            for component in folder_names:
                module_candidate = config_dict[component][0]

                if not isinstance(module_candidate, str):
                    continue

                # We compute candidate file path on the Hub. Do not use `os.path.join`.
                candidate_file = f"{component}/{module_candidate}.py"

                if candidate_file in filenames:
                    custom_components[component] = module_candidate
                elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
                    raise ValueError(
                        f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
                    )

            if len(variant_filenames) == 0 and variant is not None:
                deprecation_message = (
                    f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available. "
                    f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}` "
                    "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant "
                    "modeling files is deprecated."
                )
                deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)

            # remove ignored filenames
            model_filenames = set(model_filenames) - set(ignore_filenames)
            variant_filenames = set(variant_filenames) - set(ignore_filenames)

            # a deprecated `revision` value triggers the model-variant deprecation warning (diffusers >= 0.22.0)
            if revision in DEPRECATED_REVISION_ARGS and version.parse(
                version.parse(__version__).base_version
            ) >= version.parse("0.22.0"):
                warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)

            model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

            custom_class_name = None
            if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
                custom_pipeline = config_dict["_class_name"][0]
                custom_class_name = config_dict["_class_name"][1]

            # all filenames compatible with variant will be added
            allow_patterns = list(model_filenames)

            # allow all patterns from non-model folders
            # this enables downloading schedulers, tokenizers, ...
            allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
            # add custom component files
            allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
            # add custom pipeline file
            allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
            # also allow downloading config.json files with the model
            # (Hub paths always use "/", so the pattern must not be built with the OS-specific separator)
            allow_patterns += [f"{k}/config.json" for k in model_folder_names]

            allow_patterns += [
                SCHEDULER_CONFIG_NAME,
                CONFIG_NAME,
                cls.config_name,
                CUSTOM_PIPELINE_FILE_NAME,
            ]

            load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
            load_components_from_hub = len(custom_components) > 0

            if load_pipe_from_hub and not trust_remote_code:
                raise ValueError(
                    f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )

            if load_components_from_hub and not trust_remote_code:
                custom_files = [f"{k}/{v}.py" for k, v in custom_components.items()]
                custom_urls = [f"https://hf.co/{pretrained_model_name}/blob/main/{f}" for f in custom_files]
                raise ValueError(
                    f"The repository for {pretrained_model_name} contains custom code in {', '.join(custom_files)} which must be executed to correctly "
                    f"load the model. You can inspect the repository content at {', '.join(custom_urls)}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )

            # retrieve passed components that should not be downloaded
            pipeline_class = _get_pipeline_class(
                cls,
                config_dict,
                load_connected_pipeline=load_connected_pipeline,
                custom_pipeline=custom_pipeline,
                repo_id=pretrained_model_name if load_pipe_from_hub else None,
                hub_revision=revision,
                class_name=custom_class_name,
                cache_dir=cache_dir,
                revision=custom_revision,
            )
            expected_components, _ = cls._get_signature_keys(pipeline_class)
            passed_components = [k for k in expected_components if k in kwargs]

            if (
                use_safetensors
                and not allow_pickle
                and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
            ):
                raise EnvironmentError(
                    f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
                )
            if from_flax:
                ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
            elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
                ignore_patterns = ["*.bin", "*.msgpack"]

                use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
                if not use_onnx:
                    ignore_patterns += ["*.onnx", "*.pb"]

                safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
                safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
                if (
                    len(safetensors_variant_filenames) > 0
                    and safetensors_model_filenames != safetensors_variant_filenames
                ):
                    loaded_variant = ", ".join(safetensors_variant_filenames)
                    loaded_default = ", ".join(safetensors_model_filenames - safetensors_variant_filenames)
                    logger.warning(
                        f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{loaded_variant}]\nLoaded non-{variant} filenames:\n[{loaded_default}]\nIf this behavior is not expected, please check your folder structure."
                    )
            else:
                ignore_patterns = ["*.safetensors", "*.msgpack"]

                use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
                if not use_onnx:
                    ignore_patterns += ["*.onnx", "*.pb"]

                bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
                bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
                if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
                    loaded_variant = ", ".join(bin_variant_filenames)
                    loaded_default = ", ".join(bin_model_filenames - bin_variant_filenames)
                    logger.warning(
                        f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{loaded_variant}]\nLoaded non-{variant} filenames:\n[{loaded_default}]\nIf this behavior is not expected, please check your folder structure."
                    )

            # Don't download any objects that are passed
            allow_patterns = [
                p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
            ]

            if pipeline_class._load_connected_pipes:
                allow_patterns.append("README.md")

            # Don't download index files of forbidden patterns either
            ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]

            re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
            re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]

            expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)]
            expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)]

            snapshot_folder = Path(config_file).parent
            pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)

            if pipeline_is_cached and not force_download:
                # if the pipeline is cached, we can directly return it
                # else call snapshot_download
                return snapshot_folder

        user_agent = {"pipeline_class": cls.__name__}
        if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
            user_agent["custom_pipeline"] = custom_pipeline

        # download all allow_patterns - ignore_patterns
        # NOTE(review): `force_download` is not forwarded to `snapshot_download` here, so it only affects the
        # config file download and the cached-pipeline shortcut above -- confirm whether that is intended.
        try:
            cached_folder = snapshot_download(
                pretrained_model_name,
                cache_dir=cache_dir,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                user_agent=user_agent,
            )

            # retrieve pipeline class from local file
            cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
            cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name

            diffusers_module = importlib.import_module(__name__.split(".")[0])
            pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None

            if pipeline_class is not None and pipeline_class._load_connected_pipes:
                modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
                connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
                for connected_pipe_repo_id in connected_pipes:
                    download_kwargs = {
                        "cache_dir": cache_dir,
                        "force_download": force_download,
                        "proxies": proxies,
                        "local_files_only": local_files_only,
                        "token": token,
                        "variant": variant,
                        # hand the caller's original choice on: `None` was coerced to `True` above, and passing
                        # the coerced value would silently forbid the pickle fallback for connected pipelines
                        "use_safetensors": None if allow_pickle else use_safetensors,
                    }
                    DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)

            return cached_folder

        except FileNotFoundError:
            # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
            # This can happen in two cases:
            # 1. If the user passed `local_files_only=True`                    => we raise the error directly
            # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
            if model_info_call_error is None:
                # 1. user passed `local_files_only=True`
                raise
            else:
                # 2. we forced `local_files_only=True` when `model_info` failed
                raise EnvironmentError(
                    f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
                    " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
                    " above."
                ) from model_info_call_error
1537

1538
1539
    @classmethod
    def _get_signature_keys(cls, obj):
1540
1541
1542
        parameters = inspect.signature(obj.__init__).parameters
        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
1543
        expected_modules = set(required_parameters.keys()) - {"self"}
1544
1545
1546
1547
1548
1549
1550

        optional_names = list(optional_parameters)
        for name in optional_names:
            if name in cls._optional_components:
                expected_modules.add(name)
                optional_parameters.remove(name)

1551
1552
        return expected_modules, optional_parameters

1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
    @classmethod
    def _get_signature_types(cls):
        signature_types = {}
        for k, v in inspect.signature(cls.__init__).parameters.items():
            if inspect.isclass(v.annotation):
                signature_types[k] = (v.annotation,)
            elif get_origin(v.annotation) == Union:
                signature_types[k] = get_args(v.annotation)
            else:
                logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
        return signature_types

1565
1566
1567
1568
    @property
    def components(self) -> Dict[str, Any]:
        r"""
        The `self.components` property can be useful to run different pipelines with the same weights and
Steven Liu's avatar
Steven Liu committed
1569
1570
1571
1572
        configurations without reallocating additional memory.

        Returns (`dict`):
            A dictionary containing all the modules needed to initialize the pipeline.
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595

        Examples:

        ```py
        >>> from diffusers import (
        ...     StableDiffusionPipeline,
        ...     StableDiffusionImg2ImgPipeline,
        ...     StableDiffusionInpaintPipeline,
        ... )

        >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
        >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
        ```
        """
        expected_modules, optional_parameters = self._get_signature_keys(self)
        components = {
            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
        }

        if set(components.keys()) != expected_modules:
            raise ValueError(
                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
1596
                f" {expected_modules} to be defined, but {components.keys()} are defined."
1597
1598
1599
1600
1601
1602
1603
            )

        return components

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a NumPy image or a batch of images to a PIL image.

        Delegates to the module-level `numpy_to_pil` helper and returns its result unchanged.
        """
        return numpy_to_pil(images)
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625

    def progress_bar(self, iterable=None, total=None):
        r"""
        Build a `tqdm` progress bar configured with `self._progress_bar_config`.

        Parameters:
            iterable (*optional*):
                Iterable to wrap. Takes priority over `total` when both are given.
            total (*optional*):
                Total count for a bar that is not bound to an iterable.

        Returns:
            The `tqdm` object created from `iterable` or `total`.

        Raises:
            `ValueError`:
                If `self._progress_bar_config` exists but is not a `dict`, or if neither `iterable` nor `total` is
                provided.
        """
        if hasattr(self, "_progress_bar_config"):
            if not isinstance(self._progress_bar_config, dict):
                raise ValueError(
                    f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
                )
        else:
            # no configuration was ever set: start from an empty one
            self._progress_bar_config = {}

        if iterable is None and total is None:
            raise ValueError("Either `total` or `iterable` has to be defined.")

        if iterable is not None:
            return tqdm(iterable, **self._progress_bar_config)
        return tqdm(total=total, **self._progress_bar_config)

    def set_progress_bar_config(self, **kwargs):
        r"""
        Store the keyword arguments used to configure progress bars.

        The given keyword arguments replace any previously stored configuration; `progress_bar` forwards them to
        `tqdm` when it creates a bar.
        """
        self._progress_bar_config = kwargs

1626
    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        r"""
        Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this
        option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed
        up during training is not guaranteed.

        <Tip warning={true}>

        ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
        precedence.

        </Tip>

        Parameters:
            attention_op (`Callable`, *optional*):
                Override the default `None` operator for use as `op` argument to the
                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                function of xFormers.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import DiffusionPipeline
        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

        >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")
        >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
        >>> # Workaround for not accepting attention shape using VAE for Flash Attention
        >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
        ```
        """
        self.set_use_memory_efficient_attention_xformers(True, attention_op)
1660
1661
1662

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
        """
        self.set_use_memory_efficient_attention_xformers(False)

1667
1668
1669
    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        r"""
        Propagate the xFormers memory efficient attention setting to the pipeline's modules.

        Parameters:
            valid (`bool`):
                Value forwarded to each module's `set_use_memory_efficient_attention_xformers`.
            attention_op (`Callable`, *optional*):
                Operator forwarded unchanged alongside `valid`.
        """

        # Recursively walk through all the children.
        # Any children which exposes the set_use_memory_efficient_attention_xformers method
        # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        module_names, _ = self._get_signature_keys(self)
        # names that are missing on the pipeline resolve to `None` and are dropped by the type filter
        candidates = (getattr(self, name, None) for name in module_names)
        for module in [m for m in candidates if isinstance(m, torch.nn.Module)]:
            fn_recursive_set_mem_eff(module)
1686
1687
1688

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor
        in slices to compute attention in several steps. For more than one attention head, the computation is performed
        sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.

        <Tip warning={true}>

        ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch
        2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable
        this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs!

        </Tip>

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5",
        ...     torch_dtype=torch.float16,
        ...     use_safetensors=True,
        ... )

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> pipe.enable_attention_slicing()
        >>> image = pipe(prompt).images[0]
        ```
        """
        self.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously called, attention is
        computed in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def set_attention_slice(self, slice_size: Optional[Union[str, int]]):
        r"""
        Forward `slice_size` to every pipeline component that supports attention slicing.

        The component names are taken from `self._get_signature_keys(self)`. Only attributes that are
        `torch.nn.Module` instances and expose a `set_attention_slice` method are called; everything else is skipped.

        Args:
            slice_size (`str` or `int`, *optional*):
                Value passed through unchanged to each module's `set_attention_slice`. `enable_attention_slicing`
                forwards `"auto"`, `"max"` or an integer here, and `disable_attention_slicing` forwards `None`.
        """
        module_names, _ = self._get_signature_keys(self)
        components = (getattr(self, name, None) for name in module_names)

        for module in components:
            # Components without a `set_attention_slice` method (or that are not nn.Modules) are left alone.
            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size)
1742

1743
1744
1745
    @classmethod
    def from_pipe(cls, pipeline, **kwargs):
        r"""
        Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing
        pipeline components without reallocating additional memory.

        Arguments:
            pipeline (`DiffusionPipeline`):
                The pipeline from which to create a new pipeline.
            torch_dtype (*optional*):
                If not `None`, `new_pipeline.to(dtype=torch_dtype)` is called after the new pipeline is created.
            custom_pipeline (*optional*):
                If not `None`, the class to instantiate is resolved from this value instead of using `cls`.
            custom_revision (*optional*):
                Revision forwarded when resolving `custom_pipeline`.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Components or optional `__init__` arguments of the new pipeline. Values given here take precedence
                over the ones taken from `pipeline`; any other entries are forwarded to `__init__` unchanged.

        Returns:
            `DiffusionPipeline`:
                A new pipeline with the same weights and configurations as `pipeline`.

        Raises:
            `ValueError`:
                If a module expected by the new pipeline is neither passed, available from `pipeline`, optional, nor
                given a default value in the new pipeline's `__init__`.

        Examples:

        ```py
        >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
        ```
        """

        original_config = dict(pipeline.config)
        torch_dtype = kwargs.pop("torch_dtype", None)

        # derive the pipeline class to instantiate
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)

        if custom_pipeline is not None:
            pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
        else:
            pipeline_class = cls

        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__`
        # e.g. `image_encoder` for StableDiffusionPipeline
        # The defaults must be read from the class that is actually instantiated (`pipeline_class`), which differs
        # from `cls` when `custom_pipeline` is given. Identity comparison avoids calling `!=` on arbitrary defaults.
        parameters = inspect.signature(pipeline_class.__init__).parameters
        true_optional_modules = {
            k for k, v in parameters.items() if v.default is not inspect.Parameter.empty and k in expected_modules
        }

        # get the class of each component based on its type hint
        # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextModel}
        component_types = pipeline_class._get_signature_types()

        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
        # allow users pass modules in `kwargs` to override the original pipeline's components
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

        original_class_obj = {}
        for name, component in pipeline.components.items():
            if name in expected_modules and name not in passed_class_obj:
                # for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature
                if (
                    not isinstance(component, ModelMixin)
                    or type(component) in component_types[name]
                    or (component is None and name in pipeline_class._optional_components)
                ):
                    original_class_obj[name] = component
                else:
                    logger.warning(
                        f"component {name} is not switched over to new pipeline because type does not match the expected."
                        f" {name} is {type(component)} while the new pipeline expect {component_types[name]}."
                        f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
                    )

        # allow users pass optional kwargs to override the original pipelines config attribute
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
        original_pipe_kwargs = {
            k: original_config[k]
            for k in original_config.keys()
            if k in optional_kwargs and k not in passed_pipe_kwargs
        }

        # config attribute that were not expected by pipeline is stored as its private attribute
        # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config)
        # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline
        additional_pipe_kwargs = [
            k[1:]
            for k in original_config.keys()
            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
        ]
        for k in additional_pipe_kwargs:
            original_pipe_kwargs[k] = original_config.pop(f"_{k}")

        # later entries win on key collisions; whatever is still left in `kwargs` is forwarded as-is
        pipeline_kwargs = {
            **passed_class_obj,
            **original_class_obj,
            **passed_pipe_kwargs,
            **original_pipe_kwargs,
            **kwargs,
        }

        # store unused config as private attribute in the new pipeline
        unused_original_config = {
            f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
        }

        missing_modules = (
            set(expected_modules)
            - set(pipeline._optional_components)
            - set(pipeline_kwargs.keys())
            - set(true_optional_modules)
        )

        if len(missing_modules) > 0:
            raise ValueError(
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
            )

        new_pipeline = pipeline_class(**pipeline_kwargs)
        if pretrained_model_name_or_path is not None:
            new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path)
        new_pipeline.register_to_config(**unused_original_config)

        if torch_dtype is not None:
            # NOTE(review): the components are shared with `pipeline`, so this presumably also changes the dtype
            # seen by the original pipeline - confirm against `to()`.
            new_pipeline.to(dtype=torch_dtype)

        return new_pipeline

1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928

class StableDiffusionMixin:
    r"""
    Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion)

    The methods below delegate to `self.vae` and `self.unet`, which the class mixing this in must provide.
    """

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stages where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.

        Raises:
            `ValueError`: If the pipeline has no `unet` attribute.
        """
        if not hasattr(self, "unet"):
            raise ValueError("The pipeline must have `unet` for using FreeU.")
        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

    def disable_freeu(self):
        """Disables the FreeU mechanism if enabled."""
        self.unet.disable_freeu()

    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
            vae (`bool`, defaults to `True`): To apply fusion on the VAE.

        Raises:
            `ValueError`: If `vae` is `True` and `self.vae` is not an `AutoencoderKL`.
        """
        # Keep the state recorded by an earlier call for a component that is not touched now, so that a later
        # `unfuse_qkv_projections()` still knows that component is fused.
        self.fusing_unet = getattr(self, "fusing_unet", False)
        self.fusing_vae = getattr(self, "fusing_vae", False)

        if unet:
            self.fusing_unet = True
            self.unet.fuse_qkv_projections()
            self.unet.set_attn_processor(FusedAttnProcessor2_0())

        if vae:
            if not isinstance(self.vae, AutoencoderKL):
                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")

            self.fusing_vae = True
            self.vae.fuse_qkv_projections()
            self.vae.set_attn_processor(FusedAttnProcessor2_0())

    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
        """Disable QKV projection fusion if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            unet (`bool`, defaults to `True`): To undo fusion on the UNet.
            vae (`bool`, defaults to `True`): To undo fusion on the VAE.

        """
        # The flags only exist once `fuse_qkv_projections()` has run; treat a missing flag as "not fused" so that
        # the warning below is emitted instead of an `AttributeError`.
        if unet:
            if not getattr(self, "fusing_unet", False):
                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
            else:
                self.unet.unfuse_qkv_projections()
                self.fusing_unet = False

        if vae:
            if not getattr(self, "fusing_vae", False):
                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
            else:
                self.vae.unfuse_qkv_projections()
                self.fusing_vae = False