pipeline_utils.py 94.2 KB
Newer Older
1
# coding=utf-8
2
# Copyright 2024 The HuggingFace Inc. team.
3
4
5
6
7
8
9
10
11
12
13
14
15
# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
import enum
17
import fnmatch
18
19
20
import importlib
import inspect
import os
21
import re
22
import sys
23
24
from dataclasses import dataclass
from pathlib import Path
25
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
26
27

import numpy as np
Anh71me's avatar
Anh71me committed
28
import PIL.Image
29
import requests
30
import torch
31
32
33
34
35
36
37
from huggingface_hub import (
    ModelCard,
    create_repo,
    hf_hub_download,
    model_info,
    snapshot_download,
)
38
from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
39
from packaging import version
40
from requests.exceptions import HTTPError
41
42
from tqdm.auto import tqdm

43
from .. import __version__
44
from ..configuration_utils import ConfigMixin
45
46
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
47
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
48
from ..quantizers.bitsandbytes.utils import _check_bnb_status
49
50
51
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
    CONFIG_NAME,
52
    DEPRECATED_REVISION_ARGS,
53
    BaseOutput,
54
    PushToHubMixin,
55
    is_accelerate_available,
56
    is_accelerate_version,
Mengqing Cao's avatar
Mengqing Cao committed
57
    is_torch_npu_available,
58
    is_torch_version,
59
    is_transformers_version,
60
    logging,
Patrick von Platen's avatar
Patrick von Platen committed
61
    numpy_to_pil,
62
)
63
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
Dhruv Nair's avatar
Dhruv Nair committed
64
from ..utils.torch_utils import is_compiled_module
Mengqing Cao's avatar
Mengqing Cao committed
65
66
67
68
69


if is_torch_npu_available():
    import torch_npu  # noqa: F401

70
71
72
73
74
75
from .pipeline_loading_utils import (
    ALL_IMPORTABLE_CLASSES,
    CONNECTED_PIPES_KEYS,
    CUSTOM_PIPELINE_FILE_NAME,
    LOADABLE_CLASSES,
    _fetch_class_library_tuple,
76
    _get_custom_components_and_folders,
77
    _get_custom_pipeline_class,
78
    _get_final_device_map,
79
    _get_ignore_patterns,
80
    _get_pipeline_class,
81
82
83
    _identify_model_variants,
    _maybe_raise_warning_for_inpainting,
    _resolve_custom_pipeline_and_cls,
84
    _unwrap_model,
85
    _update_init_kwargs_with_connected_pipeline,
86
87
88
89
90
    load_sub_model,
    maybe_raise_or_warn,
    variant_compatible_siblings,
    warn_deprecated_model_variant,
)
91
92


93
94
95
96
if is_accelerate_available():
    import accelerate


97
98
99
# Library names: the keys of `LOADABLE_CLASSES`, in its iteration order.
LIBRARIES = list(LOADABLE_CLASSES)
100

101
102
SUPPORTED_DEVICE_MAP = ["balanced"]

103
104
105
106
107
108
109
110
111
112
logger = logging.get_logger(__name__)


@dataclass
class ImagePipelineOutput(BaseOutput):
    """
    Output class for image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]


@dataclass
class AudioPipelineOutput(BaseOutput):
    """
    Output class for audio pipelines.

    Args:
        audios (`np.ndarray`)
            Denoised audio samples as a NumPy array of shape `(batch_size, num_channels, sample_rate)`.
    """

    audios: np.ndarray


133
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
134
    r"""
    Base class for all pipelines.

    [`DiffusionPipeline`] stores all components (models, schedulers, and processors) for diffusion pipelines and
    provides methods for loading, downloading and saving models. It also includes methods to:

        - move all PyTorch modules to the device of your choice
        - enable/disable the progress bar for the denoising iteration

    Class attributes:

        - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the
          diffusion pipeline's components.
        - **_optional_components** (`List[str]`) -- List of all optional components that don't have to be passed to the
          pipeline to function (should be overridden by subclasses).
    """
150

151
    config_name = "model_index.json"
152
    model_cpu_offload_seq = None
153
    hf_device_map = None
154
    _optional_components = []
155
    _exclude_from_cpu_offload = []
156
    _load_connected_pipes = False
157
    _is_onnx = False
158
159
160
161

    def register_modules(self, **kwargs):
        """
        Register every keyword argument as a pipeline component.

        For each `name=module` pair, a `(library, class_name)` entry is written to the pipeline config under `name`
        and the module itself is stored as the attribute `name` on the pipeline.

        Arguments:
            kwargs:
                Mapping of component name to component. A component that is `None`, or a tuple/list whose first
                element is `None`, is recorded in the config as `(None, None)`.
        """
        for component_name, component in kwargs.items():
            is_placeholder = component is None or (
                isinstance(component, (tuple, list)) and component[0] is None
            )

            if is_placeholder:
                config_entry = (None, None)
            else:
                # Resolve which library/class the component comes from.
                library_name, class_name = _fetch_class_library_tuple(component)
                config_entry = (library_name, class_name)

            # Record the component in the model index config first, then attach it to the pipeline.
            self.register_to_config(**{component_name: config_entry})
            setattr(self, component_name, component)

174
    def __setattr__(self, name: str, value: Any):
        """
        Set an attribute and keep the pipeline config in sync with it.

        When `name` is already an instance attribute and also exists in the config, the config entry is rewritten
        before the attribute is assigned:

            - if the current config entry is a tuple/list, it is replaced by the result of
              `_fetch_class_library_tuple(value)`, or by `(None, None)` when `value` is `None` or the current entry's
              first element is `None`;
            - otherwise the config entry is set to `value` itself.
        """
        tracked_in_config = name in self.__dict__ and hasattr(self.config, name)

        if tracked_in_config:
            if not isinstance(getattr(self.config, name), (tuple, list)):
                # Plain config value: mirror the new value as is.
                config_entry = value
            elif value is not None and self.config[name][0] is not None:
                config_entry = _fetch_class_library_tuple(value)
            else:
                config_entry = (None, None)

            self.register_to_config(**{name: config_entry})

        super().__setattr__(name, value)

189
190
191
    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        safe_serialization: bool = True,
        variant: Optional[str] = None,
        max_shard_size: Optional[Union[int, str]] = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
        class implements both a save and loading method. The pipeline is easily reloaded using the
        [`~DiffusionPipeline.from_pretrained`] class method.

        Each saveable component is written to the sub-folder `save_directory/<component name>` and the pipeline config
        is saved last. Components for which no save method can be found are skipped with a warning and registered as
        `(None, None)` in the config so that they are not loaded afterwards.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save a pipeline to. Will be created if it doesn't exist.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Only
                forwarded to save methods that accept a `safe_serialization` parameter.
            variant (`str`, *optional*):
                If specified, weights are saved in the format `pytorch_model.<variant>.bin`. Only forwarded to save
                methods that accept a `variant` parameter.
            max_shard_size (`int` or `str`, defaults to `None`):
                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
                If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
                period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
                This is to establish a common default size for this argument across different libraries in the Hugging
                Face ecosystem (`transformers`, and `accelerate`, for example). Only forwarded when it is not `None`
                and the save method accepts a `max_shard_size` parameter.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Only read when `push_to_hub=True`: `commit_message`, `private`, `create_pr`, `token` and `repo_id`.
        """
        model_index_dict = dict(self.config)
        for private_key in ("_class_name", "_diffusers_version", "_module", "_name_or_path"):
            model_index_dict.pop(private_key, None)

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", None)
            create_pr = kwargs.pop("create_pr", False)
            token = kwargs.pop("token", None)
            # `save_directory` may be an `os.PathLike`, which has no `split` method, so convert it to `str` first.
            repo_id = kwargs.pop("repo_id", os.fspath(save_directory).split(os.path.sep)[-1])
            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

        expected_modules, _ = self._get_signature_keys(self)

        def is_saveable_module(name, value):
            if name not in expected_modules:
                return False
            if name in self._optional_components and value[0] is None:
                return False
            return True

        model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
        for pipeline_component_name in model_index_dict:
            sub_model = getattr(self, pipeline_component_name)
            model_cls = sub_model.__class__

            # Dynamo wraps the original model in a private class.
            # I didn't find a public API to get the original class.
            if is_compiled_module(sub_model):
                sub_model = _unwrap_model(sub_model)
                model_cls = sub_model.__class__

            save_method_name = None
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
                if library_name not in sys.modules:
                    logger.info(
                        f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
                    )
                    # Without this `continue`, `library` would be unbound (first iteration) or would still point
                    # to the previously inspected library, so the lookup below would use the wrong module.
                    continue

                library = importlib.import_module(library_name)

                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class, None)
                    if class_candidate is not None and issubclass(model_cls, class_candidate):
                        # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
                        save_method_name = save_load_methods[0]
                        break
                if save_method_name is not None:
                    break

            if save_method_name is None:
                logger.warning(
                    f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved."
                )
                # make sure that unsaveable components are not tried to be loaded afterward
                self.register_to_config(**{pipeline_component_name: (None, None)})
                continue

            save_method = getattr(sub_model, save_method_name)

            # Only pass the optional arguments that the component's save method actually supports.
            save_method_signature = inspect.signature(save_method)
            save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
            save_method_accept_variant = "variant" in save_method_signature.parameters
            save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters

            save_kwargs = {}
            if save_method_accept_safe:
                save_kwargs["safe_serialization"] = safe_serialization
            if save_method_accept_variant:
                save_kwargs["variant"] = variant
            if save_method_accept_max_shard_size and max_shard_size is not None:
                # max_shard_size is expected to not be None in ModelMixin
                save_kwargs["max_shard_size"] = max_shard_size

            save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

        # finally save the config
        self.save_config(save_directory)

        if push_to_hub:
            # Create a new empty model card and eventually tag it
            model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
            model_card = populate_model_card(model_card)
            model_card.save(os.path.join(save_directory, "README.md"))

            self._upload_folder(
                save_directory,
                repo_id,
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
            )

321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
    def to(self, *args, **kwargs):
        r"""
        Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
        arguments of `self.to(*args, **kwargs).`

        The `torch.nn.Module` components of the pipeline are converted in place and the pipeline itself (`self`) is
        returned.

        Here are the ways to call `to`:

        - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
          [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
        - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
          [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
        - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
          specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
          [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)

        Arguments:
            dtype (`torch.dtype`, *optional*):
                Returns a pipeline with the specified
                [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
            device (`torch.device`, *optional*):
                Returns a pipeline with the specified
                [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
            silence_dtype_warnings (`bool`, *optional*, defaults to `False`):
                Whether to omit warnings if the target `dtype` is not compatible with the target `device`.

        Returns:
            [`DiffusionPipeline`]: The pipeline converted to specified `device` and/or `dtype`.

        Raises:
            `ValueError`: If more than two positional arguments are passed, if two positional arguments are passed
            with a `torch.dtype` first, if `dtype` or `device` is passed both positionally and as a keyword, if the
            pipeline has a device map with more than one entry, or if the target device is CUDA while the pipeline
            is sequentially offloaded (without `bitsandbytes` models) or contains `bitsandbytes` models with
            `accelerate` older than `1.1.0.dev0`.
        """
        dtype = kwargs.pop("dtype", None)
        device = kwargs.pop("device", None)
        silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)

        # Decode the positional arguments: `(dtype,)`, `(device,)` or `(device, dtype)`.
        positional_dtype = None
        positional_device = None
        num_positional = len(args)
        if num_positional == 1:
            (single,) = args
            if isinstance(single, torch.dtype):
                positional_dtype = single
            elif single is not None:
                positional_device = torch.device(single)
        elif num_positional == 2:
            first, second = args
            if isinstance(first, torch.dtype):
                raise ValueError(
                    "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
                )
            if first is not None:
                positional_device = torch.device(first)
            positional_dtype = second
        elif num_positional > 2:
            raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")

        # Each of `dtype` / `device` may come from exactly one source (positional or keyword).
        if dtype is not None and positional_dtype is not None:
            raise ValueError(
                "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
            )
        dtype = dtype or positional_dtype

        if device is not None and positional_device is not None:
            raise ValueError(
                "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
            )
        device = device or positional_device

        pipeline_has_bnb = any(any(_check_bnb_status(component)) for component in self.components.values())

        def module_is_sequentially_offloaded(module):
            # Detects an `AlignDevicesHook`, either directly on the module or first in a list of hooks.
            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                return False
            if not hasattr(module, "_hf_hook"):
                return False

            hook = module._hf_hook
            if isinstance(hook, accelerate.hooks.AlignDevicesHook):
                return True
            return hasattr(hook, "hooks") and isinstance(hook.hooks[0], accelerate.hooks.AlignDevicesHook)

        def module_is_offloaded(module):
            # Detects a `CpuOffload` hook on the module.
            if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
                return False
            if not hasattr(module, "_hf_hook"):
                return False
            return isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
        pipeline_is_sequentially_offloaded = any(
            module_is_sequentially_offloaded(component) for component in self.components.values()
        )

        if self.hf_device_map is not None and len(self.hf_device_map) > 1:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
            )

        moving_to_cuda = bool(device) and torch.device(device).type == "cuda"

        if moving_to_cuda:
            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
                raise ValueError(
                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
                )
            # PR: https://github.com/huggingface/accelerate/pull/3223/
            if pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
                raise ValueError(
                    "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
                )

        # Display a warning in this case (the operation succeeds but the benefits are lost)
        pipeline_is_offloaded = any(module_is_offloaded(component) for component in self.components.values())
        if pipeline_is_offloaded and moving_to_cuda:
            logger.warning(
                f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
            )

        # Only the `torch.nn.Module` components named in the pipeline signature are converted.
        module_names, _ = self._get_signature_keys(self)
        candidates = [getattr(self, module_name, None) for module_name in module_names]
        modules = [candidate for candidate in candidates if isinstance(candidate, torch.nn.Module)]

        is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
        for module in modules:
            _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
            is_bnb_quantized = is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb

            if is_bnb_quantized and dtype is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
                )

            if is_loaded_in_8bit_bnb and device is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
                )

            # This can happen for `transformer` models. CPU placement was added in
            # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
            if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                module.to(device=device)
            elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
                module.to(device, dtype)

            if (
                module.dtype == torch.float16
                and str(device) == "cpu"
                and not silence_dtype_warnings
                and not is_offloaded
            ):
                logger.warning(
                    "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
                    " is not recommended to move them to `cpu` as running them will fail. Please make"
                    " sure to use an accelerator to run the pipeline in inference, due to the lack of"
                    " support for`float16` operations on this device in PyTorch. Please, remove the"
                    " `torch_dtype=torch.float16` argument, or use another device for inference."
                )

        return self

    @property
    def device(self) -> torch.device:
        r"""
        Returns:
            `torch.device`: The device of the first `torch.nn.Module` component of the pipeline, or
            `torch.device("cpu")` when the pipeline has no such component.
        """
        component_names, _ = self._get_signature_keys(self)
        components = [getattr(self, component_name, None) for component_name in component_names]

        first_module = next(
            (component for component in components if isinstance(component, torch.nn.Module)),
            None,
        )
        if first_module is None:
            return torch.device("cpu")
        return first_module.device

494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
    @property
    def dtype(self) -> torch.dtype:
        r"""
        Returns:
            `torch.dtype`: The dtype of the first `torch.nn.Module` component of the pipeline, or `torch.float32`
            when the pipeline has no such component.
        """
        component_names, _ = self._get_signature_keys(self)
        components = [getattr(self, component_name, None) for component_name in component_names]

        first_module = next(
            (component for component in components if isinstance(component, torch.nn.Module)),
            None,
        )
        if first_module is None:
            return torch.float32
        return first_module.dtype

509
    @classmethod
510
    @validate_hf_hub_args
511
512
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
Steven Liu's avatar
Steven Liu committed
513
        Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.
514

Steven Liu's avatar
Steven Liu committed
515
        The pipeline is set in evaluation mode (`model.eval()`) by default.
516

Steven Liu's avatar
Steven Liu committed
517
        If you get the error message below, you need to finetune the weights for your downstream task:
518

Steven Liu's avatar
Steven Liu committed
519
        ```
520
        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
Steven Liu's avatar
Steven Liu committed
521
522
523
        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
524
525
526
527
528

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

Steven Liu's avatar
Steven Liu committed
529
530
531
532
533
                    - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                      saved using
                    [`~DiffusionPipeline.save_pretrained`].
534
            torch_dtype (`str` or `torch.dtype`, *optional*):
Steven Liu's avatar
Steven Liu committed
535
536
                Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                dtype is automatically derived from the model's weights.
537
538
539
540
            custom_pipeline (`str`, *optional*):

                <Tip warning={true}>

Steven Liu's avatar
Steven Liu committed
541
                🧪 This is an experimental feature and may change in the future.
542
543
544
545
546

                </Tip>

                Can be either:

Steven Liu's avatar
Steven Liu committed
547
548
549
                    - A string, the *repo id* (for example `hf-internal-testing/diffusers-dummy-pipeline`) of a custom
                      pipeline hosted on the Hub. The repository must contain a file called pipeline.py that defines
                      the custom pipeline.
550
                    - A string, the *file name* of a community pipeline hosted on GitHub under
Steven Liu's avatar
Steven Liu committed
551
552
553
554
555
556
                      [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
                      names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
                      instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
                      current main branch of GitHub.
                    - A path to a directory (`./my_pipeline_directory/`) containing a custom pipeline. The directory
                      must contain a file called `pipeline.py` that defines the custom pipeline.
557
558
559
560
561
562
563

                For more information on how to load and create custom pipelines, please have a look at [Loading and
                Adding Custom
                Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
564
            cache_dir (`Union[str, os.PathLike]`, *optional*):
Steven Liu's avatar
Steven Liu committed
565
566
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
567

568
            proxies (`Dict[str, str]`, *optional*):
Steven Liu's avatar
Steven Liu committed
569
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
570
571
572
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
Steven Liu's avatar
Steven Liu committed
573
574
575
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
576
            token (`str` or *bool*, *optional*):
Steven Liu's avatar
Steven Liu committed
577
578
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
579
            revision (`str`, *optional*, defaults to `"main"`):
Steven Liu's avatar
Steven Liu committed
580
581
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
582
            custom_revision (`str`, *optional*):
583
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
584
585
                `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
                version.
586
            mirror (`str`, *optional*):
Steven Liu's avatar
Steven Liu committed
587
588
589
                Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
590
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
Steven Liu's avatar
Steven Liu committed
591
592
                A map that specifies where each submodule should go. It doesn’t need to be defined for each
                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
593
594
                same device.

Steven Liu's avatar
Steven Liu committed
595
                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
596
597
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
598
            max_memory (`Dict`, *optional*):
Steven Liu's avatar
Steven Liu committed
599
600
                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory available for
                each GPU and the available CPU RAM if unset.
601
            offload_folder (`str` or `os.PathLike`, *optional*):
Steven Liu's avatar
Steven Liu committed
602
                The path to offload weights if device_map contains the value `"disk"`.
603
            offload_state_dict (`bool`, *optional*):
Steven Liu's avatar
Steven Liu committed
604
605
606
                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
                when there is some disk offload.
607
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Steven Liu's avatar
Steven Liu committed
608
609
610
611
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
612
            use_safetensors (`bool`, *optional*, defaults to `None`):
Steven Liu's avatar
Steven Liu committed
613
614
615
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
616
617
618
619
620
            use_onnx (`bool`, *optional*, defaults to `None`):
                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
                `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
                with `.onnx` and `.pb`.
621
            kwargs (remaining dictionary of keyword arguments, *optional*):
Steven Liu's avatar
Steven Liu committed
622
623
624
                Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                class). The overwritten components are passed directly to the pipelines `__init__` method. See example
                below for more information.
625
            variant (`str`, *optional*):
Steven Liu's avatar
Steven Liu committed
626
627
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
628
629
630

        <Tip>

Steven Liu's avatar
Steven Liu committed
631
632
        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
        `huggingface-cli login`.
633
634
635
636
637
638
639
640
641
642
643
644
645
646

        </Tip>

        Examples:

        ```py
        >>> from diffusers import DiffusionPipeline

        >>> # Download pipeline from huggingface.co and cache.
        >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

        >>> # Download pipeline that requires an authorization token
        >>> # For more information on access tokens, please refer to this section
        >>> # of the documentation: https://huggingface.co/docs/hub/security-tokens
647
        >>> pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
648
649
650
651
652
653
654
655

        >>> # Use a different scheduler
        >>> from diffusers import LMSDiscreteScheduler

        >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
        >>> pipeline.scheduler = scheduler
        ```
        """
656
657
658
        # Copy the kwargs to re-use during loading connected pipeline.
        kwargs_copied = kwargs.copy()

659
        cache_dir = kwargs.pop("cache_dir", None)
660
661
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
662
663
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
664
        revision = kwargs.pop("revision", None)
665
        from_flax = kwargs.pop("from_flax", False)
666
667
668
669
670
671
        torch_dtype = kwargs.pop("torch_dtype", None)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        provider = kwargs.pop("provider", None)
        sess_options = kwargs.pop("sess_options", None)
        device_map = kwargs.pop("device_map", None)
672
673
674
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
675
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
676
        variant = kwargs.pop("variant", None)
677
        use_safetensors = kwargs.pop("use_safetensors", None)
678
        use_onnx = kwargs.pop("use_onnx", None)
679
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
680

681
682
683
684
685
686
687
688
689
        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

690
691
692
693
694
695
        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

696
697
698
699
700
701
        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

702
        if device_map is not None and not is_accelerate_available():
703
            raise NotImplementedError(
704
705
706
707
708
709
710
711
712
                "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
            )

        if device_map is not None and not isinstance(device_map, str):
            raise ValueError("`device_map` must be a string.")

        if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
            raise NotImplementedError(
                f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
713
714
            )

715
716
717
718
        if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
            if is_accelerate_version("<", "0.28.0"):
                raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")

719
720
721
722
723
724
        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

725
726
727
        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
Patrick von Platen's avatar
Patrick von Platen committed
728
729
730
731
732
            if pretrained_model_name_or_path.count("/") > 1:
                raise ValueError(
                    f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
                    " is neither a valid local path nor a valid repo id. Please check the parameter."
                )
733
            cached_folder = cls.download(
734
735
736
737
738
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
739
                token=token,
740
                revision=revision,
741
                from_flax=from_flax,
742
                use_safetensors=use_safetensors,
743
                use_onnx=use_onnx,
744
                custom_pipeline=custom_pipeline,
745
                custom_revision=custom_revision,
746
                variant=variant,
747
                load_connected_pipeline=load_connected_pipeline,
748
                **kwargs,
749
750
751
752
            )
        else:
            cached_folder = pretrained_model_name_or_path

753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
        # The variant filenames can have the legacy sharding checkpoint format that we check and throw
        # a warning if detected.
        if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
            warn_msg = (
                f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
                "Please check your files carefully:\n\n"
                "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
                "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
                "If you find any files in the deprecated format:\n"
                "1. Remove all existing checkpoint files for this variant.\n"
                "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
                "This will ensure you're using the most up-to-date and compatible checkpoint format."
            )
            logger.warning(warn_msg)

768
769
        config_dict = cls.load_config(cached_folder)

Patrick von Platen's avatar
Patrick von Platen committed
770
771
772
        # pop out "_ignore_files" as it is only needed for download
        config_dict.pop("_ignore_files", None)

773
        # 2. Define which model components should load variants
774
775
776
777
        # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
        # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
        # with variant being `"fp16"`.
        model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
778
779
780
        if len(model_variants) == 0 and variant is not None:
            error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
            raise ValueError(error_message)
781

782
        # 3. Load the pipeline class, if using custom module then load it from the hub
783
        # if we load from explicit class, let's use it
784
785
786
        custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
            folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
        )
787
        pipeline_class = _get_pipeline_class(
788
            cls,
789
            config=config_dict,
790
791
            load_connected_pipeline=load_connected_pipeline,
            custom_pipeline=custom_pipeline,
792
            class_name=custom_class_name,
793
794
            cache_dir=cache_dir,
            revision=custom_revision,
795
        )
796

797
798
799
        if device_map is not None and pipeline_class._load_connected_pipes:
            raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")

800
        # DEPRECATED: To be removed in 1.0.0
801
802
803
804
805
806
807
        # we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
        # when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
        _maybe_raise_warning_for_inpainting(
            pipeline_class=pipeline_class,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            config=config_dict,
        )
808

809
810
811
        # 4. Define expected modules given pipeline signature
        # and define non-None initialized modules (=`init_kwargs`)

812
813
814
815
        # some modules can be passed directly to the init
        # in this case they are already instantiated in `kwargs`
        # extract them here
        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
816
        expected_types = pipeline_class._get_signature_types()
817
818
819
820
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

821
822
823
824
825
826
        # define init kwargs and make sure that optional component modules are filtered out
        init_kwargs = {
            k: init_dict.pop(k)
            for k in optional_kwargs
            if k in init_dict and k not in pipeline_class._optional_components
        }
827
828
829
830
831
832
833
834
835
836
837
838
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

        # remove `null` components
        def load_module(name, value):
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            return True

        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
        for key in init_dict.keys():
            if key not in passed_class_obj:
                continue
            if "scheduler" in key:
                continue

            class_obj = passed_class_obj[key]
            _expected_class_types = []
            for expected_type in expected_types[key]:
                if isinstance(expected_type, enum.EnumMeta):
                    _expected_class_types.extend(expected_type.__members__.keys())
                else:
                    _expected_class_types.append(expected_type.__name__)

            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
            if not _is_valid_type:
                logger.warning(
                    f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
                )

859
860
861
862
863
864
865
866
        # Special case: safety_checker must be loaded separately when using `from_flax`
        if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
            raise NotImplementedError(
                "The safety checker cannot be automatically loaded when loading weights `from_flax`."
                " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
                " separately if you need it."
            )

867
        # 5. Throw nice warnings / errors for fast accelerate loading
868
869
870
871
872
873
874
875
        if len(unused_kwargs) > 0:
            logger.warning(
                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
            )

        # import it here to avoid circular import
        from diffusers import pipelines

876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
        # 6. device map delegation
        final_device_map = None
        if device_map is not None:
            final_device_map = _get_final_device_map(
                device_map=device_map,
                pipeline_class=pipeline_class,
                passed_class_obj=passed_class_obj,
                init_dict=init_dict,
                library=library,
                max_memory=max_memory,
                torch_dtype=torch_dtype,
                cached_folder=cached_folder,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
            )

        # 7. Load each module in the pipeline
        current_device_map = None
897
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
898
            # 7.1 device_map shenanigans
899
900
901
902
903
904
905
            if final_device_map is not None and len(final_device_map) > 0:
                component_device = final_device_map.get(name, None)
                if component_device is not None:
                    current_device_map = {"": component_device}
                else:
                    current_device_map = None

906
            # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
907
            class_name = class_name[4:] if class_name.startswith("Flax") else class_name
908

909
            # 7.3 Define all importable classes
910
            is_pipeline_module = hasattr(pipelines, library_name)
911
            importable_classes = ALL_IMPORTABLE_CLASSES
912
913
            loaded_sub_model = None

914
            # 7.4 Use passed sub model or load class_name from library_name
915
            if name in passed_class_obj:
916
917
918
919
920
                # if the model is in a pipeline module, then we load it from the pipeline
                # check that passed_class_obj has correct parent class
                maybe_raise_or_warn(
                    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
                )
921
922
923

                loaded_sub_model = passed_class_obj[name]
            else:
924
925
926
927
928
929
930
931
932
933
934
                # load sub model
                loaded_sub_model = load_sub_model(
                    library_name=library_name,
                    class_name=class_name,
                    importable_classes=importable_classes,
                    pipelines=pipelines,
                    is_pipeline_module=is_pipeline_module,
                    pipeline_class=pipeline_class,
                    torch_dtype=torch_dtype,
                    provider=provider,
                    sess_options=sess_options,
935
                    device_map=current_device_map,
936
937
938
                    max_memory=max_memory,
                    offload_folder=offload_folder,
                    offload_state_dict=offload_state_dict,
939
940
941
942
943
944
                    model_variants=model_variants,
                    name=name,
                    from_flax=from_flax,
                    variant=variant,
                    low_cpu_mem_usage=low_cpu_mem_usage,
                    cached_folder=cached_folder,
945
                    use_safetensors=use_safetensors,
946
                )
947
948
949
                logger.info(
                    f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
                )
950
951
952

            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

953
        # 8. Handle connected pipelines.
954
        if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
955
956
957
958
959
960
961
            init_kwargs = _update_init_kwargs_with_connected_pipeline(
                init_kwargs=init_kwargs,
                passed_pipe_kwargs=passed_pipe_kwargs,
                passed_class_objs=passed_class_obj,
                folder=cached_folder,
                **kwargs_copied,
            )
962

963
        # 9. Potentially add passed objects if expected
964
965
966
967
968
969
970
971
972
973
974
975
        missing_modules = set(expected_modules) - set(init_kwargs.keys())
        passed_modules = list(passed_class_obj.keys())
        optional_modules = pipeline_class._optional_components
        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
            for module in missing_modules:
                init_kwargs[module] = passed_class_obj.get(module, None)
        elif len(missing_modules) > 0:
            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
            raise ValueError(
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
            )

976
        # 10. Instantiate the pipeline
977
        model = pipeline_class(**init_kwargs)
978

979
        # 11. Save where the model was instantiated from
980
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
981
982
        if device_map is not None:
            setattr(model, "hf_device_map", final_device_map)
983
984
        return model

985
986
987
988
    @property
    def name_or_path(self) -> Optional[str]:
        r"""
        Returns the `_name_or_path` entry of the pipeline config (registered by `from_pretrained` with the
        `pretrained_model_name_or_path` the pipeline was loaded from), or `None` if the config has no such entry.
        """
        return getattr(self.config, "_name_or_path", None)

989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
        Accelerate's module hooks.

        Components that are not `torch.nn.Module`s, or whose name is listed in `self._exclude_from_cpu_offload`, are
        ignored. As soon as a considered component has no `_hf_hook` attribute, `self.device` is returned. Otherwise
        the component's submodules are searched for a hook with a non-`None` `execution_device`; the first one found
        is returned as a `torch.device`. If no hook provides a device, `self.device` is returned.
        """
        for component_name, component in self.components.items():
            is_module = isinstance(component, torch.nn.Module)
            if not is_module or component_name in self._exclude_from_cpu_offload:
                continue

            # A component without an Accelerate hook is not offloaded: fall back to the pipeline device.
            if not hasattr(component, "_hf_hook"):
                return self.device

            for submodule in component.modules():
                # Missing hooks and hooks without an `execution_device` both collapse to `None` here.
                submodule_hook = getattr(submodule, "_hf_hook", None)
                hook_device = getattr(submodule_hook, "execution_device", None)
                if hook_device is not None:
                    return torch.device(hook_device)

        return self.device

1011
1012
1013
1014
1015
1016
    def remove_all_hooks(self):
        r"""
        Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.

        Every `torch.nn.Module` component that carries an `_hf_hook` attribute has its hooks removed recursively
        through Accelerate, and `self._all_hooks` is reset to an empty list.
        """
        for model in self.components.values():
            if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
                accelerate.hooks.remove_hook_from_module(model, recurse=True)

        self._all_hooks = []

1020
    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.

        Arguments:
            gpu_id (`int`, *optional*):
                The ID of the accelerator that shall be used in inference. If not specified, the index contained in
                `device` is used; if that is missing too, the previously used ID is kept, falling back to 0.
            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                default to "cuda".

        Raises:
            ValueError: If a device map with more than one entry is active on the pipeline, if the pipeline class
                defines no `model_cpu_offload_seq`, or if both `gpu_id` and an indexed `device` are passed.
            ImportError: If `accelerate` is not available in version 0.17.0 or higher.
        """
        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
            )

        if self.model_cpu_offload_seq is None:
            raise ValueError(
                "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
            )

        # Validate the arguments before any pipeline state (hooks, device placement) is modified.
        torch_device = torch.device(device)
        device_index = torch_device.index

        if gpu_id is not None and device_index is not None:
            raise ValueError(
                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}. "
                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
            )

        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        self.remove_all_hooks()

        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or
        # default to 0. Explicit `is not None` checks are required: index 0 is falsy and would otherwise be
        # overridden by a previously stored non-zero id.
        if gpu_id is not None:
            self._offload_gpu_id = gpu_id
        elif device_index is not None:
            self._offload_gpu_id = device_index
        else:
            self._offload_gpu_id = getattr(self, "_offload_gpu_id", 0)

        device_type = torch_device.type
        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
        self._offload_device = device

        self.to("cpu", silence_dtype_warnings=True)
        device_mod = getattr(torch, device.type, None)
        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
            device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}

        self._all_hooks = []
        hook = None
        # Chain the hooks in the declared order: each hook receives the previous one via `prev_module_hook`.
        for model_str in self.model_cpu_offload_seq.split("->"):
            model = all_model_components.pop(model_str, None)

            # Names in the sequence that are not `torch.nn.Module` components of this pipeline are skipped.
            if model is None:
                continue

            # This is because the model would already be placed on a CUDA device.
            _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
            if is_loaded_in_8bit_bnb:
                logger.info(
                    f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                )
                continue

            _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
            self._all_hooks.append(hook)

        # CPU offload models that are not in the seq chain unless they are explicitly excluded
        # these models will stay on CPU until maybe_free_model_hooks is called
        # some models cannot be in the seq chain because they are iteratively called, such as controlnet
        for name, model in all_model_components.items():
            if name in self._exclude_from_cpu_offload:
                model.to(device)
            else:
                _, hook = cpu_offload_with_hook(model, device)
                self._all_hooks.append(hook)

    def maybe_free_model_hooks(self):
        r"""
        Function that offloads all components, removes all model hooks that were added when using
        `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
        is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
        functions correctly when applying enable_model_cpu_offload.
        """
        if not getattr(self, "_all_hooks", None):
            # `enable_model_cpu_offload` has not been called, so silently do nothing
            return

        # make sure the model is in the same state as before calling it
        self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
1120

1121
    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
        r"""
        Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
        dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
        and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
        method called. Offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.

        Arguments:
            gpu_id (`int`, *optional*):
                The ID of the accelerator that shall be used in inference. If not specified, the index contained in
                `device` is used; if that is missing too, the previously used ID is kept, falling back to 0.
            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                default to "cuda".

        Raises:
            ValueError: If a device map with more than one entry is active on the pipeline, or if both `gpu_id` and
                an indexed `device` are passed.
            ImportError: If `accelerate` is not available in version 0.14.0 or higher.
        """
        # Validate the pipeline state and the arguments first so that a rejected call leaves existing hooks in place.
        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
            )

        torch_device = torch.device(device)
        device_index = torch_device.index

        if gpu_id is not None and device_index is not None:
            raise ValueError(
                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}. "
                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
            )

        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        self.remove_all_hooks()

        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or
        # default to 0. Explicit `is not None` checks are required: index 0 is falsy and would otherwise be
        # overridden by a previously stored non-zero id.
        if gpu_id is not None:
            self._offload_gpu_id = gpu_id
        elif device_index is not None:
            self._offload_gpu_id = device_index
        else:
            self._offload_gpu_id = getattr(self, "_offload_gpu_id", 0)

        device_type = torch_device.type
        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
        self._offload_device = device

        if self.device.type != "cpu":
            # Capture the accelerator type BEFORE moving: once the pipeline has been moved, `self.device` is
            # expected to report "cpu", and the cache of the device that was just vacated would never be emptied.
            orig_device_type = self.device.type
            self.to("cpu", silence_dtype_warnings=True)
            device_mod = getattr(torch, orig_device_type, None)
            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for name, model in self.components.items():
            if not isinstance(model, torch.nn.Module):
                continue

            if name in self._exclude_from_cpu_offload:
                model.to(device)
            else:
                # make sure to offload buffers if not all high level weights
                # are of type nn.Module
                offload_buffers = len(model._parameters) > 0
                cpu_offload(model, device, offload_buffers=offload_buffers)

1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
    def reset_device_map(self):
        r"""
        Drop the pipeline's device map, if one is set.

        When `self.hf_device_map` is not `None`, all hooks are removed, every `torch.nn.Module` component is moved
        to CPU and `self.hf_device_map` is set to `None`. Otherwise nothing happens.
        """
        if self.hf_device_map is not None:
            self.remove_all_hooks()
            for component in self.components.values():
                if isinstance(component, torch.nn.Module):
                    component.to("cpu")
            self.hf_device_map = None

1195
    @classmethod
    @validate_hf_hub_args
    def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
        r"""
        Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights.

        Parameters:
            pretrained_model_name (`str` or `os.PathLike`, *optional*):
                A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
                hosted on the Hub.
            custom_pipeline (`str`, *optional*):
                Can be either:

                    - A string, the *repository id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained
                      pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines
                      the custom pipeline.

                    - A string, the *file name* of a community pipeline hosted on GitHub under
                      [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file
                      names must match the file name and not the pipeline script (`clip_guided_stable_diffusion`
                      instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the
                      current `main` branch of GitHub.

                    - A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory
                      must contain a file called `pipeline.py` that defines the custom pipeline.

                <Tip warning={true}>

                🧪 This is an experimental feature and may change in the future.

                </Tip>

                For more information on how to load and create custom pipelines, take a look at [How to contribute a
                community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline).

            cache_dir (`str` or `os.PathLike`, *optional*):
                Cache directory forwarded to the Hub download calls.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            custom_revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                information.
            variant (`str`, *optional*):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            use_onnx (`bool`, *optional*, defaults to `False`):
                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
                `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
                with `.onnx` and `.pb`.
            from_flax (`bool`, *optional*, defaults to `False`):
                Forwarded to the ignore-pattern computation that decides which weight files are skipped.
            load_connected_pipeline (`bool`, *optional*, defaults to `False`):
                Forwarded to the pipeline class resolution. When the resolved pipeline class has
                `_load_connected_pipes` set, the pipelines listed in the repository's model card are downloaded too.
            trust_remote_code (`bool`, *optional*, defaults to `False`):
                Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
                option should only be set to `True` for repositories you trust and in which you have read the code, as
                it will execute code present on the Hub on your local machine.

        Returns:
            `os.PathLike`:
                A path to the downloaded pipeline.

        Raises:
            `ValueError`:
                If the repository ships custom pipeline or component code and `trust_remote_code` is not set.
            `FileNotFoundError`:
                If `local_files_only=True` was passed and the files are not in the local cache.
            `EnvironmentError`:
                If the Hub metadata request failed and the files are not in the local cache either.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`.

        </Tip>

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        from_flax = kwargs.pop("from_flax", False)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
        trust_remote_code = kwargs.pop("trust_remote_code", False)

        # `use_safetensors=None` means "prefer safetensors, but pickle weights are acceptable too"
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        allow_patterns = None
        ignore_patterns = None

        model_info_call_error: Optional[Exception] = None
        if not local_files_only:
            try:
                info = model_info(pretrained_model_name, token=token, revision=revision)
            except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
                logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
                local_files_only = True
                model_info_call_error = e  # save error to reraise it if model is not cached locally

        if not local_files_only:
            filenames = {sibling.rfilename for sibling in info.siblings}
            if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
                warn_msg = (
                    f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
                    "Please check your files carefully:\n\n"
                    "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
                    "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
                    "If you find any files in the deprecated format:\n"
                    "1. Remove all existing checkpoint files for this variant.\n"
                    "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
                    "This will ensure you're using the most up-to-date and compatible checkpoint format."
                )
                logger.warning(warn_msg)

            model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

            config_file = hf_hub_download(
                pretrained_model_name,
                cls.config_name,
                cache_dir=cache_dir,
                revision=revision,
                proxies=proxies,
                force_download=force_download,
                token=token,
            )

            config_dict = cls._dict_from_json_file(config_file)
            ignore_filenames = config_dict.pop("_ignore_files", [])

            # remove ignored filenames
            model_filenames = set(model_filenames) - set(ignore_filenames)
            variant_filenames = set(variant_filenames) - set(ignore_filenames)

            if revision in DEPRECATED_REVISION_ARGS and version.parse(
                version.parse(__version__).base_version
            ) >= version.parse("0.22.0"):
                warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)

            custom_components, folder_names = _get_custom_components_and_folders(
                pretrained_model_name, config_dict, filenames, variant_filenames, variant
            )
            model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}

            # a list/tuple `_class_name` encodes (custom pipeline module, class name)
            custom_class_name = None
            if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
                custom_pipeline = config_dict["_class_name"][0]
                custom_class_name = config_dict["_class_name"][1]

            # all filenames compatible with variant will be added
            allow_patterns = list(model_filenames)

            # allow all patterns from non-model folders
            # this enables downloading schedulers, tokenizers, ...
            allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
            # add custom component files
            allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
            # add custom pipeline file
            allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
            # also allow downloading config.json files with the model
            # (Hub file names always use "/", so do not build these patterns with the OS path separator)
            allow_patterns += [f"{k}/config.json" for k in model_folder_names]

            allow_patterns += [
                SCHEDULER_CONFIG_NAME,
                CONFIG_NAME,
                cls.config_name,
                CUSTOM_PIPELINE_FILE_NAME,
            ]

            load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
            load_components_from_hub = len(custom_components) > 0

            if load_pipe_from_hub and not trust_remote_code:
                raise ValueError(
                    f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
                    f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )

            if load_components_from_hub and not trust_remote_code:
                # name every offending file with its `.py` suffix
                custom_files = ", ".join(f"{k}/{v}.py" for k, v in custom_components.items())
                custom_urls = ", ".join(
                    f"https://hf.co/{pretrained_model_name}/{k}/{v}.py" for k, v in custom_components.items()
                )
                raise ValueError(
                    f"The repository for {pretrained_model_name} contains custom code in {custom_files} which must be executed to correctly "
                    f"load the model. You can inspect the repository content at {custom_urls}.\n"
                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
                )

            # retrieve passed components that should not be downloaded
            pipeline_class = _get_pipeline_class(
                cls,
                config_dict,
                load_connected_pipeline=load_connected_pipeline,
                custom_pipeline=custom_pipeline,
                repo_id=pretrained_model_name if load_pipe_from_hub else None,
                hub_revision=revision,
                class_name=custom_class_name,
                cache_dir=cache_dir,
                revision=custom_revision,
            )
            expected_components, _ = cls._get_signature_keys(pipeline_class)
            passed_components = [k for k in expected_components if k in kwargs]

            # retrieve all patterns that should not be downloaded and error out when needed
            ignore_patterns = _get_ignore_patterns(
                passed_components,
                model_folder_names,
                model_filenames,
                variant_filenames,
                use_safetensors,
                from_flax,
                allow_pickle,
                use_onnx,
                pipeline_class._is_onnx,
                variant,
            )

            # Don't download any objects that are passed
            allow_patterns = [
                p for p in allow_patterns if not (len(p.split("/")) == 2 and p.split("/")[0] in passed_components)
            ]

            if pipeline_class._load_connected_pipes:
                allow_patterns.append("README.md")

            # Don't download index files of forbidden patterns either
            ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
            re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
            re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]

            expected_files = [f for f in filenames if not any(p.match(f) for p in re_ignore_pattern)]
            expected_files = [f for f in expected_files if any(p.match(f) for p in re_allow_pattern)]

            snapshot_folder = Path(config_file).parent
            pipeline_is_cached = all((snapshot_folder / f).is_file() for f in expected_files)

            if pipeline_is_cached and not force_download:
                # if the pipeline is cached, we can directly return it
                # else call snapshot_download
                return snapshot_folder

        user_agent = {"pipeline_class": cls.__name__}
        if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
            user_agent["custom_pipeline"] = custom_pipeline

        # download all allow_patterns - ignore_patterns
        try:
            cached_folder = snapshot_download(
                pretrained_model_name,
                cache_dir=cache_dir,
                proxies=proxies,
                # forward `force_download`, otherwise only the pipeline config fetched above is refreshed and
                # already cached weight files are silently reused
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
                user_agent=user_agent,
            )

            # retrieve pipeline class from local file
            cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
            cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name

            diffusers_module = importlib.import_module(__name__.split(".")[0])
            pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None

            if pipeline_class is not None and pipeline_class._load_connected_pipes:
                modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
                connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
                for connected_pipe_repo_id in connected_pipes:
                    download_kwargs = {
                        "cache_dir": cache_dir,
                        "force_download": force_download,
                        "proxies": proxies,
                        "local_files_only": local_files_only,
                        "token": token,
                        "variant": variant,
                        "use_safetensors": use_safetensors,
                    }
                    DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)

            return cached_folder

        except FileNotFoundError:
            # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
            # This can happen in two cases:
            # 1. If the user passed `local_files_only=True`                    => we raise the error directly
            # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
            if model_info_call_error is None:
                # 1. user passed `local_files_only=True`
                raise
            else:
                # 2. we forced `local_files_only=True` when `model_info` failed
                raise EnvironmentError(
                    f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
                    " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
                    " above."
                ) from model_info_call_error
1513

1514
1515
    @classmethod
    def _get_signature_keys(cls, obj):
1516
1517
1518
        parameters = inspect.signature(obj.__init__).parameters
        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
1519
        expected_modules = set(required_parameters.keys()) - {"self"}
1520
1521
1522
1523
1524
1525
1526

        optional_names = list(optional_parameters)
        for name in optional_names:
            if name in cls._optional_components:
                expected_modules.add(name)
                optional_parameters.remove(name)

1527
1528
        return expected_modules, optional_parameters

1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
    @classmethod
    def _get_signature_types(cls):
        signature_types = {}
        for k, v in inspect.signature(cls.__init__).parameters.items():
            if inspect.isclass(v.annotation):
                signature_types[k] = (v.annotation,)
            elif get_origin(v.annotation) == Union:
                signature_types[k] = get_args(v.annotation)
            else:
                logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
        return signature_types

1541
1542
1543
1544
    @property
    def components(self) -> Dict[str, Any]:
        r"""
        The `self.components` property can be useful to run different pipelines with the same weights and
Steven Liu's avatar
Steven Liu committed
1545
1546
1547
1548
        configurations without reallocating additional memory.

        Returns (`dict`):
            A dictionary containing all the modules needed to initialize the pipeline.
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558

        Examples:

        ```py
        >>> from diffusers import (
        ...     StableDiffusionPipeline,
        ...     StableDiffusionImg2ImgPipeline,
        ...     StableDiffusionInpaintPipeline,
        ... )

1559
        >>> text2img = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
        >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
        >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
        ```
        """
        expected_modules, optional_parameters = self._get_signature_keys(self)
        components = {
            k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
        }

        if set(components.keys()) != expected_modules:
            raise ValueError(
                f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
1572
                f" {expected_modules} to be defined, but {components.keys()} are defined."
1573
1574
1575
1576
1577
1578
1579
            )

        return components

    @staticmethod
    def numpy_to_pil(images):
        """
        Turn a NumPy image, or a batch of NumPy images, into PIL images.

        Thin wrapper that delegates to the module-level `numpy_to_pil` helper.
        """
        pil_images = numpy_to_pil(images)
        return pil_images
1583

lsb's avatar
lsb committed
1584
    @torch.compiler.disable
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
    def progress_bar(self, iterable=None, total=None):
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        elif not isinstance(self._progress_bar_config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

        if iterable is not None:
            return tqdm(iterable, **self._progress_bar_config)
        elif total is not None:
            return tqdm(total=total, **self._progress_bar_config)
        else:
            raise ValueError("Either `total` or `iterable` has to be defined.")

    def set_progress_bar_config(self, **kwargs):
        """Replace the stored progress bar configuration with the given keyword arguments."""
        setattr(self, "_progress_bar_config", kwargs)

1603
    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        r"""
        Turn on memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). With this
        option enabled you should observe lower GPU memory usage and a potential speed up during inference. A speed
        up during training is not guaranteed.

        <Tip warning={true}>

        ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
        precedent.

        </Tip>

        Parameters:
            attention_op (`Callable`, *optional*):
                Override the default `None` operator for use as `op` argument to the
                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                function of xFormers.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import DiffusionPipeline
        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

        >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")
        >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
        >>> # Workaround for not accepting attention shape using VAE for Flash Attention
        >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
        ```
        """
        use_xformers = True
        self.set_use_memory_efficient_attention_xformers(use_xformers, attention_op)
1637
1638
1639

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Turn off memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
        """
        use_xformers = False
        self.set_use_memory_efficient_attention_xformers(use_xformers)

1644
1645
1646
    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        """
        Propagate the xFormers attention setting to every module of the pipeline.

        Each `torch.nn.Module` among the pipeline's expected modules is walked depth-first (parents before
        children); every visited module that exposes `set_use_memory_efficient_attention_xformers` is called with
        `valid` and `attention_op`.
        """
        component_names, _ = self._get_signature_keys(self)
        candidates = [getattr(self, component_name, None) for component_name in component_names]
        roots = [candidate for candidate in candidates if isinstance(candidate, torch.nn.Module)]

        for root in roots:
            # explicit stack instead of recursion; children are pushed reversed to keep pre-order traversal
            stack = [root]
            while stack:
                current = stack.pop()
                if hasattr(current, "set_use_memory_efficient_attention_xformers"):
                    current.set_use_memory_efficient_attention_xformers(valid, attention_op)
                stack.extend(reversed(list(current.children())))
1663
1664
1665

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Turn on sliced attention computation. With this option enabled, the attention module splits the input tensor
        in slices to compute attention in several steps. For more than one attention head, the computation is performed
        sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.

        <Tip warning={true}>

        ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch
        2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable
        this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs!

        </Tip>

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import StableDiffusionPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained(
        ...     "stable-diffusion-v1-5/stable-diffusion-v1-5",
        ...     torch_dtype=torch.float16,
        ...     use_safetensors=True,
        ... )

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> pipe.enable_attention_slicing()
        >>> image = pipe(prompt).images[0]
        ```
        """
        requested_slice_size = slice_size
        self.set_attention_slice(requested_slice_size)

    def disable_attention_slicing(self):
        r"""
        Turn off sliced attention computation. After a previous call to `enable_attention_slicing`, attention
        goes back to being computed in a single step.
        """
        # a `slice_size` of `None` is the "slicing off" value
        self.enable_attention_slicing(None)

    def set_attention_slice(self, slice_size: Optional[Union[str, int]]):
        r"""
        Forward `slice_size` to every pipeline component that supports attention slicing.

        Component names come from `_get_signature_keys`. Only attributes that are `torch.nn.Module` instances and
        expose a `set_attention_slice` method are touched; missing attributes and other components are skipped.

        Args:
            slice_size (`str`, `int` or `None`):
                Value handed unchanged to each component's `set_attention_slice`. `enable_attention_slicing`
                forwards its own `slice_size` here (`"auto"` by default) and `disable_attention_slicing` ends up
                passing `None`.
        """
        module_names, _ = self._get_signature_keys(self)
        # `getattr(..., None)` tolerates signature keys that were never registered on this instance
        candidates = [getattr(self, name, None) for name in module_names]
        sliceable = [
            module
            for module in candidates
            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice")
        ]

        for module in sliceable:
            module.set_attention_slice(slice_size)
1719

1720
1721
1722
    @classmethod
    def from_pipe(cls, pipeline, **kwargs):
        r"""
        Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing
        pipeline components without reallocating additional memory.

        Arguments:
            pipeline (`DiffusionPipeline`):
                The pipeline from which to create a new pipeline.
            torch_dtype (*optional*):
                When given, the new pipeline is converted with `new_pipeline.to(dtype=torch_dtype)` before it is
                returned.
            custom_pipeline (*optional*):
                When given, it is resolved through `_get_custom_pipeline_class` and the resulting class is
                instantiated instead of `cls`.
            custom_revision (*optional*):
                Passed as `revision` when resolving `custom_pipeline`.
            kwargs:
                Components (keyed by the new pipeline's expected module names) or optional `__init__` arguments
                that override the ones taken from `pipeline`. Any remaining entries are forwarded to the new
                pipeline's `__init__` as-is.

        Returns:
            `DiffusionPipeline`:
                A new pipeline with the same weights and configurations as `pipeline`.

        Raises:
            `ValueError`:
                If a module expected by the new pipeline is neither optional nor available from `pipeline` or
                `kwargs`.

        Examples:

        ```py
        >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline

        >>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
        >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
        ```
        """

        original_config = dict(pipeline.config)
        torch_dtype = kwargs.pop("torch_dtype", None)

        # derive the pipeline class to instantiate
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)

        if custom_pipeline is not None:
            pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
        else:
            pipeline_class = cls

        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__`
        # e.g. `image_encoder` for StableDiffusionPipeline
        # The defaults must be read from the class that is actually instantiated: with `custom_pipeline`,
        # `pipeline_class` differs from `cls`.
        parameters = inspect.signature(pipeline_class.__init__).parameters
        true_optional_modules = {
            k for k, v in parameters.items() if v.default is not inspect.Parameter.empty and k in expected_modules
        }

        # get the class of each component based on its type hint
        # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextMode}
        component_types = pipeline_class._get_signature_types()

        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
        # allow users pass modules in `kwargs` to override the original pipeline's components
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

        original_class_obj = {}
        for name, component in pipeline.components.items():
            if name not in expected_modules or name in passed_class_obj:
                continue
            # Model components are only switched over when their class matches the type hint in the new
            # pipeline's signature. Everything that is not a `ModelMixin` (including `None` placeholders of
            # optional components) is carried over unconditionally.
            if not isinstance(component, ModelMixin) or type(component) in component_types[name]:
                original_class_obj[name] = component
            else:
                logger.warning(
                    f"component {name} is not switched over to new pipeline because type does not match the expected."
                    f" {name} is {type(component)} while the new pipeline expect {component_types[name]}."
                    f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
                )

        # allow users pass optional kwargs to override the original pipelines config attribute
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
        original_pipe_kwargs = {
            k: original_config[k]
            for k in original_config.keys()
            if k in optional_kwargs and k not in passed_pipe_kwargs
        }

        # config attribute that were not expected by pipeline is stored as its private attribute
        # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config)
        # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline
        additional_pipe_kwargs = [
            k[1:]
            for k in original_config.keys()
            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
        ]
        for k in additional_pipe_kwargs:
            original_pipe_kwargs[k] = original_config.pop(f"_{k}")

        # later entries win on key collisions, so leftover `kwargs` have the final say
        pipeline_kwargs = {
            **passed_class_obj,
            **original_class_obj,
            **passed_pipe_kwargs,
            **original_pipe_kwargs,
            **kwargs,
        }

        # store unused config as private attribute in the new pipeline
        unused_original_config = {
            f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
        }

        # NOTE(review): the optional components subtracted here are those of the *source* pipeline, not of
        # `pipeline_class` - kept as-is; confirm this is intended.
        missing_modules = (
            set(expected_modules)
            - set(pipeline._optional_components)
            - set(pipeline_kwargs.keys())
            - set(true_optional_modules)
        )

        if len(missing_modules) > 0:
            raise ValueError(
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
            )

        new_pipeline = pipeline_class(**pipeline_kwargs)
        if pretrained_model_name_or_path is not None:
            new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path)
        new_pipeline.register_to_config(**unused_original_config)

        if torch_dtype is not None:
            new_pipeline.to(dtype=torch_dtype)

        return new_pipeline

1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905

class StableDiffusionMixin:
    r"""
    Helper for DiffusionPipeline with vae and unet.(mainly for LDM such as stable diffusion)

    Every method delegates to `self.vae` and/or `self.unet`, which the host pipeline is expected to provide.
    """

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stages where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.

        Raises:
            `ValueError`: If the pipeline has no `unet` attribute.
        """
        if not hasattr(self, "unet"):
            raise ValueError("The pipeline must have `unet` for using FreeU.")
        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

    def disable_freeu(self):
        """Disables the FreeU mechanism if enabled."""
        self.unet.disable_freeu()

    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
            vae (`bool`, defaults to `True`): To apply fusion on the VAE.

        Raises:
            `ValueError`: If `vae` is `True` and `self.vae` is not an `AutoencoderKL`.
        """
        # Create the bookkeeping flags on first use but never reset them: a later call that targets only one
        # component (e.g. `fuse_qkv_projections(unet=False)`) must not forget that the other one is still
        # fused, otherwise `unfuse_qkv_projections` would skip it.
        self.fusing_unet = getattr(self, "fusing_unet", False)
        self.fusing_vae = getattr(self, "fusing_vae", False)

        if unet:
            self.fusing_unet = True
            self.unet.fuse_qkv_projections()
            self.unet.set_attn_processor(FusedAttnProcessor2_0())

        if vae:
            if not isinstance(self.vae, AutoencoderKL):
                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")

            self.fusing_vae = True
            self.vae.fuse_qkv_projections()
            self.vae.set_attn_processor(FusedAttnProcessor2_0())

    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
        """Disable QKV projection fusion if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            unet (`bool`, defaults to `True`): To undo fusion on the UNet.
            vae (`bool`, defaults to `True`): To undo fusion on the VAE.

        """
        # `getattr` with a default keeps this safe when `fuse_qkv_projections` was never called, in which case
        # the flags do not exist yet and the component is simply reported as not fused.
        if unet:
            if not getattr(self, "fusing_unet", False):
                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
            else:
                self.unet.unfuse_qkv_projections()
                self.fusing_unet = False

        if vae:
            if not getattr(self, "fusing_vae", False):
                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
            else:
                self.vae.unfuse_qkv_projections()
                self.fusing_vae = False