"vscode:/vscode.git/clone" did not exist on "757eeacc1b34c825f5927d2db86d4e73e8fdf52a"
scheduling_utils.py 8.64 KB
Newer Older
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15
import importlib
import os
16
from dataclasses import dataclass
Kashif Rasul's avatar
Kashif Rasul committed
17
from enum import Enum
Patrick von Platen's avatar
Patrick von Platen committed
18
from typing import Optional, Union
Patrick von Platen's avatar
Patrick von Platen committed
19

Patrick von Platen's avatar
up  
Patrick von Platen committed
20
import torch
21
from huggingface_hub.utils import validate_hf_hub_args
Patrick von Platen's avatar
up  
Patrick von Platen committed
22

23
from ..utils import BaseOutput, PushToHubMixin
24

Patrick von Platen's avatar
up  
Patrick von Platen committed
25

Patrick von Platen's avatar
Patrick von Platen committed
26
27
28
SCHEDULER_CONFIG_NAME = "scheduler_config.json"


29
30
31
32
# NOTE: We make this type an enum because it simplifies usage in docs and prevents
# circular imports when used for `_compatibles` within the schedulers module.
# When it's used as a type in pipelines, it really is a Union because the actual
# scheduler instance is passed in.
Kashif Rasul's avatar
Kashif Rasul committed
33
34
35
36
37
38
39
40
41
42
43
44
45
class KarrasDiffusionSchedulers(Enum):
    """
    Enumeration of scheduler class names.

    Each member is named after a scheduler class and carries a unique integer value (1 through 15, in
    declaration order). Members can be looked up by name (`KarrasDiffusionSchedulers["DDIMScheduler"]`) or by
    value (`KarrasDiffusionSchedulers(1)`).
    """

    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
    UniPCMultistepScheduler = 13
    DPMSolverSDEScheduler = 14
    EDMEulerScheduler = 15
Kashif Rasul's avatar
Kashif Rasul committed
49
50


51
52
53
54
55
56
57
58
59
# Precomputed sampling schedules, looked up by a "<model family><Timesteps|Sigmas>" key.
# Each timestep list has 10 entries; each sigma list has 11 entries, the last one being 0.0.
# NOTE(review): "Ays" presumably stands for "Align Your Steps" -- confirm against the scheduler docs.
AysSchedules = {
    "StableDiffusionTimesteps": [999, 850, 736, 645, 545, 455, 343, 233, 124, 24],
    "StableDiffusionSigmas": [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.0],
    "StableDiffusionXLTimesteps": [999, 845, 730, 587, 443, 310, 193, 116, 53, 13],
    "StableDiffusionXLSigmas": [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0],
    "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0],
}


60
61
62
@dataclass
class SchedulerOutput(BaseOutput):
    """
    Base class for the output of a scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


74
class SchedulerMixin(PushToHubMixin):
    """
    Base class for all schedulers.

    [`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving
    functionalities.

    [`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to
    the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent scheduler
          class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class (should be overridden
          by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the scheduler
                      configuration saved with [`~SchedulerMixin.save_pretrained`].
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
                of Diffusers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info (`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        Returns:
            Whatever `cls.from_config` returns for the loaded configuration, called with the caller's
            `return_unused_kwargs` flag and the kwargs that `cls.load_config` did not consume.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        """
        # The commit hash is requested from `load_config` (hence the three-element result) but is not
        # needed here, so it is discarded. Unused kwargs are always requested so they can be forwarded.
        config, kwargs, _ = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to a directory so that it can be reloaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        """
        Resolve this class's own name plus the names in `_compatibles` to attributes of the top-level package.

        Returns:
            `list`: The resolved objects, own class first and then in `_compatibles` order, without duplicates.
            Names that the top-level package does not expose are skipped.
        """
        # De-duplicate while keeping declaration order; going through a `set` would leave the order of
        # the result unspecified.
        candidate_names = list(dict.fromkeys([cls.__name__, *cls._compatibles]))
        # The top-level package is the first component of this module's dotted name.
        package = importlib.import_module(__name__.split(".")[0])
        return [getattr(package, name) for name in candidate_names if hasattr(package, name)]