# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
19
from typing import List, Literal, Optional, Tuple, Union
20
21
22
23
24
25

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
26
from diffusers.utils import BaseOutput, deprecate
27
28
29
30
31
32


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Result container returned by a scheduler's `step` method.

    Args:
        prev_sample (`torch.Tensor`, e.g. of shape `(batch_size, num_channels, height, width)` for images):
            The sample computed by this step (`x_{t-1}` in the denoising direction). Pass it back to the model as the
            input of the next loop iteration.
        pred_original_sample (`torch.Tensor`, *optional*):
            The model-implied estimate of the clean sample `x_0` at the current timestep. Useful for previewing
            progress or for guidance; defaults to `None`.
    """

    prev_sample: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps: int,
    max_beta: float = 0.999,
    alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
    """
    Discretize a continuous cumulative-alpha curve into a beta schedule.

    `alpha_bar(t)` gives the cumulative product of `(1 - beta)` at normalized time `t` in `[0, 1]`. The interval is cut
    into `num_diffusion_timesteps` equal pieces, and each beta is chosen so that
    `1 - beta_i = alpha_bar(t_{i+1}) / alpha_bar(t_i)`, then capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`):
            How many betas to generate.
        max_beta (`float`, defaults to `0.999`):
            Upper bound applied to every beta; keep it below 1 to avoid numerical instability.
        alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
            Which `alpha_bar` curve to discretize.

    Returns:
        `torch.Tensor`:
            A 1-D `float32` tensor holding `num_diffusion_timesteps` betas.

    Raises:
        ValueError: If `alpha_transform_type` is neither `"cosine"` nor `"exp"`.
    """
    curves = (
        ("cosine", lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2),
        ("exp", lambda t: math.exp(t * -12.0)),
    )
    alpha_bar = next((curve for name, curve in curves if alpha_transform_type == name), None)
    if alpha_bar is None:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    steps = num_diffusion_timesteps
    betas = [min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta) for i in range(steps)]
    return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is zero, following Algorithm 1 of
    https://huggingface.co/papers/2305.08891.

    The square roots of the cumulative alphas are shifted so the last one becomes 0. They are then scaled so the first
    one keeps its original value, and the curve is converted back into per-step betas.

    Args:
        betas (`torch.Tensor`):
            The betas the scheduler was initialized with. Not modified.

    Returns:
        `torch.Tensor`:
            The rescaled betas. The final cumulative alpha is driven to zero.
    """
    sqrt_cumprod = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Endpoints of the original curve.
    first = sqrt_cumprod[0].clone()
    last = sqrt_cumprod[-1].clone()

    # Shift the curve down so it ends at zero, then stretch it so it starts where it used to.
    sqrt_cumprod = (sqrt_cumprod - last) * (first / (first - last))

    # Undo the square root, then the cumulative product, to recover per-step alphas; betas are their complement.
    cumprod = sqrt_cumprod**2
    step_alphas = torch.cat([cumprod[0:1], cumprod[1:] / cumprod[:-1]])
    return 1 - step_alphas


class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDIMInverseScheduler` is the reverse scheduler of [`DDIMScheduler`].

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each inversion step needs the alphas product of the noise level it starts from. For the very first step
            that level lies before the start of the training schedule, so no such value exists. When this option is
            `True`, the alphas product is fixed to 1 for that step (the starting sample is treated as noise-free).
            Otherwise, the alphas product of training step 0 (`alphas_cumprod[0]`) is used.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied when
            `timestep_spacing="leading"`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function. One of `epsilon` (predicts the noise of the diffusion
            process), `sample` (directly predicts the original, denoised sample), or `v_prediction` (see section 2.4
            of the [Imagen Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled; one of `"leading"` or `"trailing"`. Refer to Table 2 of the
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)
            for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    order = 1
    ignore_for_config = ["kwargs"]
    _deprecated_kwargs = ["set_alpha_to_zero"]

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        clip_sample_range: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        **kwargs,
    ):
        # `set_alpha_to_zero` is the legacy spelling of `set_alpha_to_one`; its value is forwarded unchanged.
        if kwargs.get("set_alpha_to_zero", None) is not None:
            deprecation_message = (
                "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead."
            )
            deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False)
            set_alpha_to_one = kwargs["set_alpha_to_zero"]

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Every inversion step starts from the noise level one inference step *before* the requested timestep. For
        # the first step that index is negative (out of bounds), so a substitute alphas product is needed:
        # - `set_alpha_to_one=True`: use 1.0, i.e. treat the incoming sample as noise-free (for epsilon prediction
        #   the predicted x_0 is then the sample itself);
        # - otherwise: use the alphas product of the first training step.
        self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. Must be between 1
                and `num_train_timesteps` (inclusive).
            device (`str` or `torch.device`, *optional*):
                The device the resulting `self.timesteps` tensor is moved to.

        Raises:
            ValueError: If `num_inference_steps` is smaller than 1 or larger than `num_train_timesteps`, or if
                `timestep_spacing` is not `"leading"` or `"trailing"`.
        """
        # Guard against 0 (ZeroDivisionError below) and negative values (silently empty timesteps).
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be a positive integer, but got {num_inference_steps}.")

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "leading" and "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Run one deterministic DDIM inversion step: move `sample` one inference step *up* the noise schedule.

        The incoming `sample` is taken to be at training timestep
        `timestep - num_train_timesteps // num_inference_steps`, clamped to `num_train_timesteps - 1`. When that index
        is negative, the `initial_alpha_cumprod` noise level is used instead. The returned `prev_sample` is at the
        noise level of `timestep`. No random noise is added.

        Args:
            model_output (`torch.Tensor`):
                The direct output of the learned diffusion model, interpreted according to
                `self.config.prediction_type`.
            timestep (`int`):
                The timestep whose noise level the returned sample should have.
            sample (`torch.Tensor`):
                The current sample, one inference step earlier in the schedule than `timestep`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DDIMSchedulerOutput`] or a plain `tuple`.

        Returns:
            [`DDIMSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, a [`DDIMSchedulerOutput`] is returned; otherwise the tuple
                `(prev_sample, pred_original_sample)`.

        Raises:
            ValueError: If `set_timesteps` has not been called yet, or if `self.config.prediction_type` is not one of
                `epsilon`, `sample` or `v_prediction`.
        """
        # Without this check the step size below fails with an opaque `int // None` TypeError.
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # 1. get step values. `prev_timestep` is the target (noisier) timestep passed in by the caller; `timestep` is
        # rebound to the noise level the incoming sample is at, one inference step earlier.
        prev_timestep = timestep
        timestep = min(
            timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
        )

        # 2. compute alphas, betas
        # change original implementation to exactly match noise levels for analogous forward process
        alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502,
        # evaluated at the target noise level
        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon

        # 6. re-noise x_0 to the target noise level without "random noise" (formula (12) of the same paper)
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if not return_dict:
            return (prev_sample, pred_original_sample)
        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def __len__(self):
        """Return the number of training timesteps (`self.config.num_train_timesteps`)."""
        return self.config.num_train_timesteps