"src/vscode:/vscode.git/clone" did not exist on "21e61eb3a9d16a46245bd284fea3aa19e66772f5"
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from ..utils import BaseOutput, PushToHubMixin

# Default filename used when saving/loading a scheduler's JSON configuration.
SCHEDULER_CONFIG_NAME = "scheduler_config.json"


# NOTE: We make this type an enum because it simplifies usage in docs and prevents
# circular imports when used for `_compatibles` within the schedulers module.
# When it's used as a type in pipelines, it really is a Union because the actual
# scheduler instance is passed in.
class KarrasDiffusionSchedulers(Enum):
    """
    Enumeration of the scheduler classes that belong to the Karras-style diffusion family.

    Each member's name matches the scheduler class name; the integer values are stable
    identifiers (do not renumber — they are part of the public serialization surface).
    """

    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
    UniPCMultistepScheduler = 13
    DPMSolverSDEScheduler = 14
    EDMEulerScheduler = 15


@dataclass
class SchedulerOutput(BaseOutput):
    """
    Base class for the output of a scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class SchedulerMixin(PushToHubMixin):
    """
    Base class for all schedulers.

    [`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving
    functionalities.

    [`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to
    the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent scheduler
          class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class (should be overridden
          by parent class).
    """

    # Filename of the serialized configuration; consumed by ConfigMixin's load/save helpers.
    config_name = SCHEDULER_CONFIG_NAME
    # Subclasses override this with the names of interchangeable scheduler classes.
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the scheduler
                      configuration saved with [`~SchedulerMixin.save_pretrained`].
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        """
        # `load_config` (from ConfigMixin) handles Hub download / local resolution; remaining
        # kwargs are forwarded to `from_config` so scheduler-specific overrides still apply.
        config, kwargs, commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to a directory so that it can be reloaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        # Include the class itself; de-duplicate in case `_compatibles` repeats it.
        compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
        # Resolve names against the top-level package (e.g. `diffusers`) so we don't
        # import scheduler modules directly (avoids circular imports).
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes