scheduling_karras_ve.py 9.5 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 NVIDIA and The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


16
from dataclasses import dataclass
17
from typing import Optional, Tuple, Union
18
19
20
21

import numpy as np
import torch

22
23
24
25
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..scheduling_utils import SchedulerMixin
26
27


28
29
30
31
32
33
@dataclass
class KarrasVeOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        derivative (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Derivative of predicted original image sample (x_0).
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images, *optional*):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance. Defaults to `None` when it is not
            provided.
    """

    # Field order matters: the dataclass-generated constructor takes the fields positionally in this order.
    prev_sample: torch.Tensor
    derivative: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None
47
48


49
50
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """
    A stochastic scheduler tailored to variance-expanding models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    <Tip>

    For more details on the parameters, see [Appendix E](https://huggingface.co/papers/2206.00364). The grid search
    values used to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of
    the paper.

    </Tip>

    Args:
        sigma_min (`float`, defaults to 0.02):
            The minimum noise magnitude.
        sigma_max (`float`, defaults to 100):
            The maximum noise magnitude.
        s_noise (`float`, defaults to 1.007):
            The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
            1.011].
        s_churn (`float`, defaults to 80):
            The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100].
        s_min (`float`, defaults to 0.05):
            The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10].
        s_max (`float`, defaults to 50):
            The end value of the sigma range to add noise. A reasonable range is [0.2, 80].
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.02,
        sigma_max: float = 100,
        s_noise: float = 1.007,
        s_churn: float = 80,
        s_min: float = 0.05,
        s_max: float = 50,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # settable values; all three stay `None` until `set_timesteps` is called
        self.num_inference_steps: Optional[int] = None
        self.timesteps: Optional[torch.Tensor] = None
        self.schedule: Optional[torch.Tensor] = None  # sigma(t_i)

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain. Unused by this scheduler.

        Returns:
            `torch.Tensor`:
                The input sample, returned unchanged (this scheduler applies no scaling).
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Populates `self.timesteps` with the descending indices `num_inference_steps - 1, ..., 0` and `self.schedule`
        with one float32 value per entry of `self.timesteps`, computed as
        `sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (i / (num_inference_steps - 1))`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps
        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps).to(device)

        # With a single inference step the exponent i / (N - 1) would be 0 / 0 and fill the schedule with NaN.
        # Clamping the denominator makes the lone entry evaluate to sigma_max**2 and leaves N >= 2 untouched.
        denominator = max(num_inference_steps - 1, 1)

        sigma_max_sq = self.config.sigma_max**2
        sigma_ratio_sq = self.config.sigma_min**2 / self.config.sigma_max**2

        # Evaluate the whole schedule in one tensor expression instead of one 0-d tensor per timestep.
        # `self.timesteps` already lives on `device`, so the result does too.
        exponents = self.timesteps / denominator
        self.schedule = (sigma_max_sq * sigma_ratio_sq**exponents).to(torch.float32)

    def add_noise_to_input(
        self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None
    ) -> Tuple[torch.Tensor, float]:
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a
        higher noise level `sigma_hat = sigma_i + gamma_i*sigma_i`.

        `gamma_i` is `min(s_churn / num_inference_steps, sqrt(2) - 1)` when `s_min <= sigma <= s_max` and `0`
        otherwise, in which case the sample is returned with no noise added and `sigma_hat == sigma`.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            sigma (`float`):
                The current noise level of `sample`.
            generator (`torch.Generator`, *optional*):
                A random number generator.

        Returns:
            `tuple`:
                `(sample_hat, sigma_hat)`: the noised sample and the increased noise level it now corresponds to.
        """
        if self.config.s_min <= sigma <= self.config.s_max:
            gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
        else:
            gamma = 0

        # sample eps ~ N(0, S_noise^2 * I)
        noise = randn_tensor(sample.shape, generator=generator).to(sample.device)
        if sample.is_floating_point():
            # Match the sample's precision so a half-precision sample is not silently promoted to float32
            # when the noise is added below.
            noise = noise.to(sample.dtype)
        eps = self.config.s_noise * noise

        sigma_hat = sigma + gamma * sigma
        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

        return sample_hat, sigma_hat

    def step(
        self,
        model_output: torch.Tensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            sigma_hat (`float`):
                The noise level of `sample_hat` (e.g. the value returned by `add_noise_to_input`). It is used as a
                divisor, so it must be non-zero.
            sigma_prev (`float`):
                The noise level to step to.
            sample_hat (`torch.Tensor`):
                The sample at noise level `sigma_hat` (e.g. the sample returned by `add_noise_to_input`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`KarrasVeOutput`] or `tuple`.

        Returns:
            [`KarrasVeOutput`] or `tuple`:
                If return_dict is `True`, [`KarrasVeOutput`] is returned, otherwise a tuple `(sample_prev,
                derivative)` is returned where the first element is the sample tensor.
        """
        pred_original_sample = sample_hat + sigma_hat * model_output
        derivative = (sample_hat - pred_original_sample) / sigma_hat
        # single first-order step from sigma_hat to sigma_prev along the derivative
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

        if not return_dict:
            return (sample_prev, derivative)

        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
        )

    def step_correct(
        self,
        model_output: torch.Tensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.Tensor,
        sample_prev: torch.Tensor,
        derivative: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Corrects the predicted sample based on the `model_output` of the network.

        A second derivative is computed at `sample_prev` and the step from `sample_hat` is redone using the average
        of `derivative` and that second derivative.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model (presumably evaluated on `sample_prev` at noise level
                `sigma_prev`; confirm against the calling pipeline).
            sigma_hat (`float`):
                The noise level of `sample_hat`.
            sigma_prev (`float`):
                The noise level of `sample_prev`. It is used as a divisor, so it must be non-zero.
            sample_hat (`torch.Tensor`):
                The sample the step starts from, at noise level `sigma_hat`.
            sample_prev (`torch.Tensor`):
                The uncorrected sample at noise level `sigma_prev`, e.g. `prev_sample` from `step`.
            derivative (`torch.Tensor`):
                The derivative computed for the uncorrected step, e.g. `derivative` from `step`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`KarrasVeOutput`] or `tuple`.

        Returns:
            [`KarrasVeOutput`] or `tuple`:
                If return_dict is `True`, [`KarrasVeOutput`] is returned, otherwise a tuple `(sample_prev,
                derivative)` is returned where the first element is the corrected sample. In both cases the
                derivative that is returned is the `derivative` argument, not the corrected one.
        """
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        # redo the step from sample_hat with the mean of the two derivatives
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)

        if not return_dict:
            return (sample_prev, derivative)

        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
        )

    def add_noise(self, original_samples, noise, timesteps):
        """Not supported by this scheduler; always raises `NotImplementedError`."""
        raise NotImplementedError()