# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Union
import torch

from ..utils import BaseOutput, PushToHubMixin

SCHEDULER_CONFIG_NAME = "scheduler_config.json"


# NOTE: We make this type an enum because it simplifies usage in docs and prevents
# circular imports when used for `_compatibles` within the schedulers module.
# When it's used as a type in pipelines, it really is a Union because the actual
# scheduler instance is passed in.
class KarrasDiffusionSchedulers(Enum):
    """
    Enumeration of scheduler class names.

    Each member is named after a scheduler class, so code can refer to a scheduler by name (e.g. in a scheduler's
    `_compatibles` list) without importing the scheduler class itself. The integer values only distinguish the
    members; nothing in this file reads them.
    """

    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
    UniPCMultistepScheduler = 13
    DPMSolverSDEScheduler = 14


@dataclass
class SchedulerOutput(BaseOutput):
    """
    Base class for the output of a scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class SchedulerMixin(PushToHubMixin):
    """
    Base class for all schedulers.

    [`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving
    functionalities.

    [`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to
    the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent scheduler
          class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class (should be overridden
          by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    # Overridden by subclasses with names of compatible scheduler classes. This class only reads it and never
    # mutates it, so the shared mutable default is safe.
    _compatibles = []
    has_compatibles = True

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the scheduler
                      configuration saved with [`~SchedulerMixin.save_pretrained`].
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `huggingface-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        Returns:
            Whatever `cls.from_config` returns for the loaded configuration. Presumably this is the scheduler
            instance, together with the unused kwargs when `return_unused_kwargs=True` (confirm in
            `ConfigMixin.from_config`).

        <Tip>

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        </Tip>

        """
        # `load_config` and `from_config` are not defined on this class. They are expected to come from `ConfigMixin`,
        # which concrete schedulers inherit alongside this mixin (NOTE(review): confirm via the subclass MRO).
        # Loader options (cache_dir, revision, ...) travel in **kwargs. Whatever `load_config` does not consume is
        # handed back (return_unused_kwargs=True) and forwarded to `from_config`. The commit hash is requested but
        # not needed here.
        config, kwargs, _commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to a directory so that it can be reloaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        # All of the work, including the optional Hub push, is delegated to `save_config`.
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        """
        Resolve this class's own name plus `_compatibles` to class objects of the top-level package.

        Names that the top-level package does not expose are silently skipped. Duplicates are removed, and the
        result keeps a stable order: this scheduler first, then `_compatibles` in declaration order.
        """
        # `dict.fromkeys` deduplicates while preserving insertion order. The previous `set`-based dedup made the order
        # depend on string hash randomization, so it could differ between interpreter runs.
        compatible_classes_str = list(dict.fromkeys([cls.__name__, *cls._compatibles]))
        # Look the names up on the top-level package this module belongs to (first component of `__name__`).
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes