# Copyright 2025 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
17
from typing import Optional, Tuple, Union
18
19
20
21

import numpy as np
import torch

22
23
24
25
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..scheduling_utils import SchedulerMixin
26
27


28
29
30
31
32
33
@dataclass
class KarrasVeOutput(BaseOutput):
    """
    Output class for the scheduler's `step` and `step_correct` functions.

    Args:
        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        derivative (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Direction of the sampling update, `(sample_hat - pred_original_sample) / sigma_hat`, along which the
            sample is moved from one noise level to the next.
        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    # Field order matters: it defines the tuple order of the output.
    prev_sample: torch.Tensor
    derivative: torch.Tensor
    pred_original_sample: Optional[torch.Tensor] = None


class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """
    A stochastic scheduler tailored to variance-exploding (VE) models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    > [!TIP]
    > For more details on the parameters, see [Appendix E](https://huggingface.co/papers/2206.00364). The grid search
    > values used to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5
    > of the paper.

    Args:
        sigma_min (`float`, defaults to 0.02):
            The minimum noise magnitude.
        sigma_max (`float`, defaults to 100):
            The maximum noise magnitude.
        s_noise (`float`, defaults to 1.007):
            The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
            1.011].
        s_churn (`float`, defaults to 80):
            The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100].
        s_min (`float`, defaults to 0.05):
            The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10].
        s_max (`float`, defaults to 50):
            The end value of the sigma range to add noise. A reasonable range is [0.2, 80].
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.02,
        sigma_max: float = 100,
        s_noise: float = 1.007,
        s_churn: float = 80,
        s_min: float = 0.05,
        s_max: float = 50,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # setable values, populated by `set_timesteps`
        self.num_inference_steps: Optional[int] = None
        self.timesteps: Optional[torch.Tensor] = None
        # Noise levels indexed by timestep value; see `set_timesteps` for the formula.
        self.schedule: Optional[torch.Tensor] = None

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.Tensor`:
                The input sample, unchanged (this scheduler applies no input scaling).
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps and the schedule should be moved. If `None`, they are not moved.
        """
        self.num_inference_steps = num_inference_steps

        # Timesteps run from `num_inference_steps - 1` down to 0.
        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps).to(device)

        # Geometric interpolation between sigma_max**2 (i = 0) and sigma_min**2 (i = N - 1); note the squared
        # bounds. Entry k corresponds to i = timesteps[k], so `schedule[t]` is the value for timestep `t`.
        # The denominator is clamped to 1 so a single-step schedule yields sigma_max**2 instead of 0 / 0 = NaN.
        denominator = max(num_inference_steps - 1, 1)
        ratio = self.config.sigma_min**2 / self.config.sigma_max**2
        schedule = self.config.sigma_max**2 * ratio ** (timesteps / denominator)
        self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)

    def add_noise_to_input(
        self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None
    ) -> Tuple[torch.Tensor, float]:
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a
        higher noise level `sigma_hat = sigma_i + gamma_i * sigma_i`.

        Args:
            sample (`torch.Tensor`):
                The input sample.
            sigma (`float`):
                The current noise level `sigma_i` of `sample`. Extra noise is only added when `s_min <= sigma <=
                s_max`; otherwise `gamma_i = 0` and `sigma_hat == sigma`.
            generator (`torch.Generator`, *optional*):
                A random number generator.

        Returns:
            `Tuple[torch.Tensor, float]`:
                The noised sample `sample_hat` and its noise level `sigma_hat`.
        """
        if self.config.s_min <= sigma <= self.config.s_max:
            gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
        else:
            gamma = 0

        # sample eps ~ N(0, S_noise^2 * I). The noise is drawn with `randn_tensor` unchanged and only afterwards
        # moved/cast to the sample's device and dtype, so the random stream is unaffected while a half-precision
        # sample is no longer silently promoted to float32.
        noise = randn_tensor(sample.shape, generator=generator).to(device=sample.device, dtype=sample.dtype)
        eps = self.config.s_noise * noise

        sigma_hat = sigma + gamma * sigma
        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

        return sample_hat, sigma_hat

    def step(
        self,
        model_output: torch.Tensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            sigma_hat (`float`):
                The noise level of `sample_hat` (e.g. as returned by `add_noise_to_input`).
            sigma_prev (`float`):
                The noise level to step to.
            sample_hat (`torch.Tensor`):
                The noisy sample at noise level `sigma_hat`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] is returned,
                otherwise a tuple `(prev_sample, derivative)` is returned.
        """
        pred_original_sample = sample_hat + sigma_hat * model_output
        derivative = (sample_hat - pred_original_sample) / sigma_hat
        # First-order (Euler) move from `sigma_hat` to `sigma_prev` along `derivative`.
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

        if not return_dict:
            return (sample_prev, derivative)

        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
        )

    def step_correct(
        self,
        model_output: torch.Tensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.Tensor,
        sample_prev: torch.Tensor,
        derivative: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Corrects the predicted sample based on the `model_output` of the network.

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model, evaluated on `sample_prev`.
            sigma_hat (`float`):
                The noise level of `sample_hat`.
            sigma_prev (`float`):
                The noise level of `sample_prev`. Used as a divisor, so it must be non-zero.
            sample_hat (`torch.Tensor`):
                The noisy sample the preceding `step` started from.
            sample_prev (`torch.Tensor`):
                The sample produced by the preceding `step`.
            derivative (`torch.Tensor`):
                The derivative produced by the preceding `step`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
                The corrected sample and the (unchanged) input `derivative`, either wrapped in a
                [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or as a tuple `(prev_sample, derivative)`.
        """
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        # Second-order correction: redo the step from `sample_hat` using the average of both derivatives.
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)

        # The input `derivative` (not `derivative_corr`) is returned.
        if not return_dict:
            return (sample_prev, derivative)

        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
        )

    def add_noise(self, original_samples, noise, timesteps):
        """Not supported by this scheduler; always raises `NotImplementedError`."""
        raise NotImplementedError()