# scheduling_karras_ve.py (9.58 KB), last committed by Patrick von Platen
# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


16
from dataclasses import dataclass
17
from typing import Optional, Tuple, Union
18
19
20
21
22

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
23
from ..utils import BaseOutput, randn_tensor
24
25
26
from .scheduling_utils import SchedulerMixin


27
28
29
30
31
32
33
34
35
36
@dataclass
class KarrasVeOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The first-order derivative `(sample_hat - pred_original_sample) / sigma_hat` computed by `step` and used
            to move the sample from `sigma_hat` to `sigma_prev`. `step_correct` passes through the derivative it was
            given.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    derivative: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
46
47


48
49
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
    """
    A stochastic scheduler tailored to variance-expanding models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    <Tip>

    For more details on the parameters, see [Appendix E](https://arxiv.org/abs/2206.00364). The grid search values used
    to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of the paper.

    </Tip>

    Args:
        sigma_min (`float`, defaults to 0.02):
            The minimum noise magnitude.
        sigma_max (`float`, defaults to 100):
            The maximum noise magnitude.
        s_noise (`float`, defaults to 1.007):
            The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
            1.011].
        s_churn (`float`, defaults to 80):
            The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100].
        s_min (`float`, defaults to 0.05):
            The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10].
        s_max (`float`, defaults to 50):
            The end value of the sigma range to add noise. A reasonable range is [0.2, 80].
    """

    order = 2

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.02,
        sigma_max: float = 100,
        s_noise: float = 1.007,
        s_churn: float = 80,
        s_min: float = 0.05,
        s_max: float = 50,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        # settable values, populated by `set_timesteps`
        self.num_inference_steps: Optional[int] = None
        self.timesteps: Optional[torch.Tensor] = None  # integer step indices, descending
        self.schedule: Optional[torch.FloatTensor] = None  # one noise value per step index, see `set_timesteps`

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. This scheduler applies no scaling.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain (unused).

        Returns:
            `torch.FloatTensor`:
                The input sample, unchanged.
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        `self.timesteps` becomes `[num_inference_steps - 1, ..., 1, 0]`. `self.schedule[t]` holds
        `sigma_max**2 * (sigma_min**2 / sigma_max**2) ** (i / (num_inference_steps - 1))` with `i = self.timesteps[t]`.
        The schedule therefore runs geometrically from `sigma_min**2` at index 0 up to `sigma_max**2` at the last
        index. With a single step, the lone entry is `sigma_max**2`.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps).to(device)

        # The exponent denominator is `num_inference_steps - 1`. It is 0 for a single step, which used to yield
        # 0/0 = NaN. Clamping it to 1 maps that lone step (i = 0) to the schedule start, sigma_max**2.
        denominator = max(num_inference_steps - 1, 1)
        ratio = self.config.sigma_min**2 / self.config.sigma_max**2
        # Vectorized over the host-side index array instead of iterating device tensor elements one by one.
        schedule = self.config.sigma_max**2 * ratio ** (timesteps / denominator)
        self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)

    def add_noise_to_input(
        self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
    ) -> Tuple[torch.FloatTensor, float]:
        """
        Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a
        higher noise level `sigma_hat = sigma_i + gamma_i*sigma_i`.

        The churn is enabled when `s_min <= sigma <= s_max`. In that case
        `gamma_i = min(s_churn / num_inference_steps, sqrt(2) - 1)`; otherwise `gamma_i = 0`. When `gamma_i = 0`,
        noise is still drawn from `generator` but scaled by zero, so `sigma_hat == sigma`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            sigma (`float`):
                The current noise level `sigma_i` of `sample`.
            generator (`torch.Generator`, *optional*):
                A random number generator.

        Returns:
            `Tuple[torch.FloatTensor, float]`:
                `(sample_hat, sigma_hat)`: the noised sample and its increased noise level.

        Raises:
            ValueError: If the churn branch is reached before `set_timesteps` has been called.
        """
        if self.config.s_min <= sigma <= self.config.s_max:
            if self.num_inference_steps is None:
                raise ValueError(
                    "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
                )
            gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
        else:
            gamma = 0

        # sample eps ~ N(0, S_noise^2 * I). The noise is drawn by `randn_tensor` without a device argument, then
        # moved to the sample's device. It is also cast to the sample's dtype so that low-precision inputs (e.g.
        # fp16) are not promoted by the addition below.
        eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(
            device=sample.device, dtype=sample.dtype
        )
        sigma_hat = sigma + gamma * sigma
        sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)

        return sample_hat, sigma_hat

    def step(
        self,
        model_output: torch.FloatTensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        This is a first-order (Euler) step from noise level `sigma_hat` to `sigma_prev`:
        - `pred_original_sample = sample_hat + sigma_hat * model_output`
        - `derivative = (sample_hat - pred_original_sample) / sigma_hat`
        - `prev_sample = sample_hat + (sigma_prev - sigma_hat) * derivative`

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            sigma_hat (`float`):
                The noise level of `sample_hat`, e.g. as returned by `add_noise_to_input`. It is used as a divisor,
                so it must be non-zero.
            sigma_prev (`float`):
                The noise level to step to.
            sample_hat (`torch.FloatTensor`):
                The current noisy sample at noise level `sigma_hat`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
                If return_dict is `True`, a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] is returned,
                otherwise a tuple `(prev_sample, derivative)` is returned.
        """
        pred_original_sample = sample_hat + sigma_hat * model_output
        derivative = (sample_hat - pred_original_sample) / sigma_hat
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative

        if not return_dict:
            return (sample_prev, derivative)

        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
        )

    def step_correct(
        self,
        model_output: torch.FloatTensor,
        sigma_hat: float,
        sigma_prev: float,
        sample_hat: torch.FloatTensor,
        sample_prev: torch.FloatTensor,
        derivative: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[KarrasVeOutput, Tuple]:
        """
        Corrects the predicted sample based on the `model_output` of the network.

        The network output evaluated at `(sample_prev, sigma_prev)` gives a second derivative estimate
        `derivative_corr`. The step from `sample_hat` is then redone with the average
        `0.5 * derivative + 0.5 * derivative_corr` (a Heun-style correction).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model, evaluated on `sample_prev` at `sigma_prev`.
            sigma_hat (`float`):
                The noise level of `sample_hat` (the starting point of the step).
            sigma_prev (`float`):
                The noise level of `sample_prev`. It is used as a divisor, so it must be non-zero.
            sample_hat (`torch.FloatTensor`):
                The noisy sample at `sigma_hat` that the step starts from.
            sample_prev (`torch.FloatTensor`):
                The first-order prediction at `sigma_prev`, e.g. `prev_sample` from `step`.
            derivative (`torch.FloatTensor`):
                The first-order derivative at `sigma_hat`, e.g. `derivative` from `step`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
                If return_dict is `True`, a [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] is returned,
                otherwise a tuple `(prev_sample, derivative)` is returned. In both cases `prev_sample` is the
                corrected sample. `derivative` is the input `derivative`, passed through unchanged.
        """
        pred_original_sample = sample_prev + sigma_prev * model_output
        derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
        sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)

        if not return_dict:
            return (sample_prev, derivative)

        return KarrasVeOutput(
            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
        )

    def add_noise(self, original_samples, noise, timesteps):
        """Not supported by this scheduler; use `add_noise_to_input` instead."""
        raise NotImplementedError()