# scheduling_ddim_inverse.py (17.4 KB), last committed by Aryan.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
26
from diffusers.utils import BaseOutput, deprecate
27
28
29
30
31
32


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` method.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample produced by the step; it should be fed back to the model as the input of the next loop
            iteration. For `DDIMInverseScheduler.step` this is the sample at the next, noisier timestep.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The model's estimate of the clean sample `(x_{0})` at the current timestep. It can be used to preview
            progress or for guidance. Defaults to `None`.
    """

    # Required: the sample the scheduler step moved to.
    prev_sample: torch.Tensor
    # Optional: the x_0 estimate computed during the step.
    pred_original_sample: Optional[torch.Tensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): a 1-D `float32` tensor of length `num_diffusion_timesteps` holding the betas used by
        the scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta. The cap matters at the end of the
    # cosine schedule, where alpha_bar(1) is ~0 and the raw ratio would push beta to ~1.
    betas = [
        min(1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)


93
94
95
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal signal-to-noise ratio is zero, following Algorithm 1 of
    https://huggingface.co/papers/2305.08891.

    The square root of the cumulative alpha product is shifted so that its last entry becomes zero. It is then scaled
    so that its first entry keeps its original value, and the result is converted back into per-step betas.

    Args:
        betas (`torch.Tensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.Tensor`: rescaled betas with zero terminal SNR
    """
    sqrt_abar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Endpoints of the original curve, captured before it is modified.
    first = sqrt_abar[0].clone()
    last = sqrt_abar[-1].clone()

    # Shift so the final entry is exactly zero, then rescale so the first entry is unchanged.
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))

    abar = sqrt_abar**2
    # Undo the cumulative product: alpha_t = abar_t / abar_{t-1}, with alpha_0 = abar_0.
    per_step_alphas = torch.cat([abar[:1], abar[1:] / abar[:-1]])
    return 1 - per_step_alphas


129
130
class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
    """
    `DDIMInverseScheduler` is the reverse scheduler of [`DDIMScheduler`].

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of the noise schedule.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value of the noise schedule.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            The inverted process steps from one timestep to the next. On the very first step, the index of the
            "current" alphas product is negative, so no such product exists. When this option is `True`, that alphas
            product is fixed to 1. Otherwise, the alphas product at step 0 is used. Replaces the deprecated
            `set_alpha_to_zero` keyword.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps, as required by some model families. Only applied when
            `timestep_spacing="leading"`.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the denoised sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled; one of `"leading"` or `"trailing"`. Refer to Table 2 of the
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)
            for more information.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    order = 1
    ignore_for_config = ["kwargs"]
    _deprecated_kwargs = ["set_alpha_to_zero"]

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        clip_sample_range: float = 1.0,
        timestep_spacing: str = "leading",
        rescale_betas_zero_snr: bool = False,
        **kwargs,
    ):
        # Backward compatibility: the old `set_alpha_to_zero` keyword overrides `set_alpha_to_one` with its value.
        if kwargs.get("set_alpha_to_zero", None) is not None:
            deprecation_message = (
                "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead."
            )
            deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False)
            set_alpha_to_one = kwargs["set_alpha_to_zero"]

        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step of inverted DDIM, `step` reads the alphas_cumprod one step *before* the requested timestep
        # (the level the sample currently sits at). On the very first step, that index is negative.
        # `set_alpha_to_one` decides what to use instead:
        # - 1.0: `step` then treats the incoming sample as the clean x_0.
        # - the first training alphas_cumprod.
        self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

    # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                A scaled input sample.
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the inversion chain (to be run before `step`).

        The resulting `self.timesteps` are in ascending order, i.e. from least to most noisy.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which `self.timesteps` is moved. If `None`, the timesteps stay on the CPU.

        Raises:
            ValueError: If `num_inference_steps` exceeds `num_train_timesteps`, or if `timestep_spacing` is not
                `"leading"` or `"trailing"`.
        """

        if num_inference_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        self.num_inference_steps = num_inference_steps

        # "leading" and "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
        if self.config.timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
            timesteps += self.config.steps_offset
        elif self.config.timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: int,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DDIMSchedulerOutput, Tuple]:
        """
        Run one inverted, deterministic DDIM step that moves `sample` one step towards noise.

        `sample` is treated as the sample at noise level `timestep - num_train_timesteps // num_inference_steps`.
        The returned `prev_sample` is the sample at noise level `timestep`.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The target discrete timestep, expected to be one of the values produced by `set_timesteps`.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DDIMSchedulerOutput`] or `tuple`.

        Returns:
            [`DDIMSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, a [`DDIMSchedulerOutput`] is returned. Otherwise the tuple
                `(prev_sample, pred_original_sample)` is returned.

        Raises:
            ValueError: If `set_timesteps` has not been called yet, or if `prediction_type` is not one of `epsilon`,
                `sample` or `v_prediction`.
        """
        # The step size below divides by `num_inference_steps`. Without this check, a missing `set_timesteps` call
        # surfaces as an opaque TypeError.
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # 1. get previous step value (=t+1)
        # `prev_timestep` is the target (noisier) level.
        # `timestep` is rebound to the level `sample` currently sits at.
        prev_timestep = timestep
        timestep = min(
            timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
        )

        # 2. compute alphas, betas
        # change original implementation to exactly match noise levels for analogous forward process
        alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]

        beta_prod_t = 1 - alpha_prod_t
        # NOTE(review): on the first step with `set_alpha_to_one=True`, alpha_prod_t == 1 and therefore
        # beta_prod_t == 0. The `sample` branch below then divides by zero when computing `pred_epsilon`.

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
        pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon

        # 6. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if not return_dict:
            return (prev_sample, pred_original_sample)
        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    def __len__(self):
        # The length of the scheduler is the number of training timesteps, not the number of inference steps.
        return self.config.num_train_timesteps