# Copyright (c) 2022 Pablo Pernías MIT License
# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin


@dataclass
class DDPMWuerstchenSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of the previous timestep. `prev_sample` should be used as the next model input
            in the denoising loop.
    """

    prev_sample: torch.Tensor


def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`.

    Returns:
        betas (`torch.Tensor`): 1-D `float32` tensor of length `num_diffusion_timesteps` holding the betas used by
        the scheduler to step the model outputs.

    Raises:
        ValueError: if `alpha_transform_type` is neither `cosine` nor `exp`.
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i) on a uniform grid over [0, 1]. The cap at max_beta keeps the
    # final betas (where alpha_bar approaches 0 for the cosine schedule) strictly below 1.
    betas = [
        min(1 - alpha_bar_fn((i + 1) / num_diffusion_timesteps) / alpha_bar_fn(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return torch.tensor(betas, dtype=torch.float32)


class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
    Langevin dynamics sampling.

    This variant works with continuous timesteps in `[0, 1]` and a cosine `alpha_cumprod` schedule rather than a
    fixed set of discrete training timesteps.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `scaler` and `s`. They can be accessed via `scheduler.config.scaler`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2006.11239

    Args:
        scaler (`float`, defaults to 1.0):
            Warps the time axis before the cosine schedule is evaluated. Values `> 1` map `t -> 1 - (1 - t) ** scaler`,
            values `< 1` map `t -> t ** scaler`, and `1.0` leaves `t` unchanged.
        s (`float`, defaults to 0.008):
            Offset of the cosine schedule: `alpha_cumprod(t)` is proportional to `cos((t + s) / (1 + s) * pi / 2) ** 2`,
            normalized so that it equals 1 at `t = 0` (before clamping).
    """

    @register_to_config
    def __init__(
        self,
        scaler: float = 1.0,
        s: float = 0.008,
    ):
        self.scaler = scaler
        self.s = torch.tensor([s])
        # Value of the un-normalized cosine schedule at t = 0; dividing by it makes alpha_cumprod(0) == 1.
        self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

    def _alpha_cumprod(self, t, device):
        """
        Evaluate the (time-warped) cosine `alpha_cumprod` schedule at continuous time(s) `t`.

        Args:
            t (`torch.Tensor`): timestep(s); presumably in `[0, 1]` (the range produced by `set_timesteps`).
            device (`torch.device`): device on which the schedule constants are placed.

        Returns:
            `torch.Tensor`: `alpha_cumprod(t)` clamped to `[0.0001, 0.9999]`.
        """
        if self.scaler > 1:
            t = 1 - (1 - t) ** self.scaler
        elif self.scaler < 1:
            t = t**self.scaler
        s = self.s.to(device)
        alpha_cumprod = torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod.to(device)
        # Clamp away from 0 and 1 so the divisions by sqrt(1 - alpha_cumprod) and alpha_cumprod_prev in `step` are safe.
        return alpha_cumprod.clamp(0.0001, 0.9999)

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler does not scale, so `sample` is returned unchanged.

        Args:
            sample (`torch.Tensor`): input sample
            timestep (`int`, optional): current timestep (unused)

        Returns:
            `torch.Tensor`: scaled input sample
        """
        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        timesteps: Optional[Union[List[float], torch.Tensor]] = None,
        device: Union[str, torch.device] = None,
    ):
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. The timesteps are
                then `num_inference_steps + 1` evenly spaced values from `1.0` down to `0.0`. Ignored when `timesteps`
                is passed.
            timesteps (`List[float]` or `torch.Tensor`, *optional*):
                Custom timesteps to use instead of the evenly spaced default.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps are moved.

        Raises:
            ValueError: if neither `num_inference_steps` nor `timesteps` is given.
        """
        if timesteps is None:
            if num_inference_steps is None:
                raise ValueError("Either `num_inference_steps` or `timesteps` has to be passed to `set_timesteps`.")
            timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
        elif not isinstance(timesteps, torch.Tensor):
            # `torch.tensor` always copies the values; the legacy `torch.Tensor(n)` would allocate an uninitialized
            # tensor of size n when given a bare int.
            timesteps = torch.tensor(timesteps, dtype=torch.get_default_dtype(), device=device)
        elif device is not None:
            timesteps = timesteps.to(device)
        self.timesteps = timesteps

    def step(
        self,
        model_output: torch.Tensor,
        timestep: torch.Tensor,
        sample: torch.Tensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`): direct output from learned diffusion model.
            timestep (`torch.Tensor` of shape `(batch_size,)`):
                current continuous timestep(s). Only `timestep[0]` is used to look up the previous timestep, so all
                entries are expected to share the same value.
            sample (`torch.Tensor`):
                current instance of sample being created by diffusion process.
            generator (`torch.Generator`, *optional*): random number generator.
            return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class

        Returns:
            [`DDPMWuerstchenSchedulerOutput`] or `tuple`: [`DDPMWuerstchenSchedulerOutput`] if `return_dict` is True,
            otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.

        """
        dtype = model_output.dtype
        device = model_output.device
        t = timestep

        prev_t = self.previous_timestep(t)

        # Reshape per-sample schedule values to (batch, 1, 1, ...) so they broadcast over the sample's other dims.
        alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
        alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
        alpha = alpha_cumprod / alpha_cumprod_prev

        # DDPM mean of x_{t-1} for a noise-predicting model (Ho et al. 2020, Eq. 11).
        mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())

        std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
        std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
        # No noise is added on the step that lands on t == 0.
        pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])

        if not return_dict:
            return (pred.to(dtype),)

        return DDPMWuerstchenSchedulerOutput(prev_sample=pred.to(dtype))

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Diffuse `original_samples` forward to `timesteps`:
        `sqrt(alpha_cumprod(t)) * original_samples + sqrt(1 - alpha_cumprod(t)) * noise`.

        Args:
            original_samples (`torch.Tensor`): clean samples; the first dimension is the batch.
            noise (`torch.Tensor`): noise to mix in, broadcastable to `original_samples`.
            timesteps (`torch.Tensor` of shape `(batch_size,)`): continuous timestep per sample.

        Returns:
            `torch.Tensor`: noisy samples in the dtype of `original_samples`.
        """
        device = original_samples.device
        dtype = original_samples.dtype
        alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
            timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
        )
        noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
        return noisy_samples.to(dtype=dtype)

    def __len__(self):
        # NOTE(review): `num_train_timesteps` is not an `__init__` argument of this scheduler, so it is not registered
        # in `self.config`; this lookup likely raises AttributeError unless a loaded config supplies it — confirm.
        return self.config.num_train_timesteps

    def previous_timestep(self, timestep):
        """
        Return the entry of `self.timesteps` that follows the one nearest to `timestep[0]`, expanded to the batch
        size of `timestep`.

        Requires `set_timesteps` to have been called. Calling it with the last entry of `self.timesteps` raises
        `IndexError`, since there is no following entry.
        """
        index = (self.timesteps - timestep[0]).abs().argmin().item()
        prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
        return prev_t