# Copyright 2025 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin, SchedulerOutput


class IPNDMScheduler(SchedulerMixin, ConfigMixin):
    """
    A fourth-order Improved Pseudo Linear Multistep scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
    """

    order = 1

    @register_to_config
    def __init__(
        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
    ):
        # Populate `betas`, `alphas` and `timesteps` for the full training schedule.
        self.set_timesteps(num_train_timesteps)

        # Standard deviation of the initial noise distribution.
        self.init_noise_sigma = 1.0

        # Only F-PNDM, i.e. the fourth-order linear multistep variant, is supported.
        # See the paper https://huggingface.co/papers/2202.09778 — mainly formulas (9), (12), (13)
        # and Algorithm 2.
        self.pndm_order = 4

        # Running history of epsilon estimates consumed by the multistep formula in `step`.
        self.ets = []
        self._step_index = None
        self._begin_index = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # Ramp from 1 down to (but excluding) 0, then append an explicit final 0.
        ramp = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
        ramp = torch.cat([ramp, torch.tensor([0.0])])

        if self.config.trained_betas is None:
            self.betas = torch.sin(ramp * math.pi / 2) ** 2
        else:
            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)

        self.alphas = (1.0 - self.betas**2) ** 0.5

        # Convert each (beta, alpha) pair back to a normalized angle in [0, 1); the trailing
        # zero entry is dropped so `timesteps` has exactly `num_inference_steps` values.
        self.timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1].to(device)

        # Reset all running state so the scheduler can be reused.
        self.ets = []
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(
        self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index of a given timestep in the timestep schedule.

        Args:
            timestep (`float` or `torch.Tensor`):
                The timestep value to find in the schedule.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule. For the very first step, returns the second index if
                multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        matches = (schedule_timesteps == timestep).nonzero()

        # On the very first `step` the second match (or the only one, if unique) is taken, so no
        # sigma is accidentally skipped when denoising starts mid-schedule (e.g. image-to-image).
        pick = 1 if len(matches) > 1 else 0
        return matches[pick].item()

    def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
        """
        Initialize the step index for the scheduler based on the given timestep.

        Args:
            timestep (`float` or `torch.Tensor`):
                The current timestep to initialize the step index from.
        """
        if self.begin_index is not None:
            # A pipeline already pinned the starting position via `set_begin_index`.
            self._step_index = self._begin_index
        else:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
        the linear multistep method. It performs one forward pass multiple times to approximate the solution.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `set_timesteps` has not been called before stepping.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        cur_index = self.step_index
        next_index = cur_index + 1

        # Current epsilon estimate; push it onto the multistep history.
        self.ets.append(sample * self.betas[cur_index] + model_output * self.alphas[cur_index])

        # Linear-multistep (Adams-Bashforth style) combination of up to four past estimates.
        history = len(self.ets)
        if history == 1:
            ets = self.ets[-1]
        elif history == 2:
            ets = (3 * self.ets[-1] - self.ets[-2]) / 2
        elif history == 3:
            ets = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
        else:
            ets = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])

        prev_sample = self._get_prev_sample(sample, cur_index, next_index, ets)

        # Advance the internal counter once per completed denoising step.
        self._step_index += 1

        if return_dict:
            return SchedulerOutput(prev_sample=prev_sample)
        return (prev_sample,)

    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        # IPNDM requires no input scaling; the sample passes through unchanged.
        return sample

    def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
        # Current and next (alpha, sigma) pairs read off the precomputed schedule.
        alpha = self.alphas[timestep_index]
        sigma = self.betas[timestep_index]
        alpha_next = self.alphas[prev_timestep_index]
        sigma_next = self.betas[prev_timestep_index]

        # Reconstruct the denoised sample (guarding against a vanishing alpha), then
        # re-noise it to the next noise level.
        pred_original = (sample - sigma * ets) / max(alpha, 1e-8)
        return alpha_next * pred_original + ets * sigma_next

    def __len__(self):
        return self.config.num_train_timesteps