scheduling_sde_ve.py 12.2 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

import warnings
18
19
from dataclasses import dataclass
from typing import Optional, Tuple, Union
20
21
22
23

import numpy as np
import torch

24
from ..configuration_utils import ConfigMixin, register_to_config
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput


@dataclass
class SdeVeOutput(BaseOutput):
    """
    Output class for the ScoreSdeVeScheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The same update as `prev_sample` but *without* the random noise term added, i.e. the deterministic
            (mean) part of the reverse-SDE step: `prev_sample = prev_sample_mean + diffusion * noise`.
    """

    prev_sample: torch.FloatTensor
    prev_sample_mean: torch.FloatTensor


class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
    """
    The variance exploding stochastic differential equation (SDE) scheduler.

    For more information, see the original paper: https://arxiv.org/abs/2011.13456

    Args:
        num_train_timesteps (`int`):
            number of discrete noise scales (and timesteps) built at construction time; also the value reported by
            `len(scheduler)`.
        snr (`float`):
            coefficient weighting the step from the model_output sample (from the network) to the random noise.
        sigma_min (`float`):
            initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
            distribution of the data.
        sigma_max (`float`): maximum noise scale for the sigma sequence in the sampling procedure.
        sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
            epsilon.
        correct_steps (`int`): number of correction steps performed on a produced sample.
        tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 2000,
        snr: float = 0.15,
        sigma_min: float = 0.01,
        sigma_max: float = 1348.0,
        sampling_eps: float = 1e-5,
        correct_steps: int = 1,
        tensor_format: str = "pt",
    ):
        # setable values
        self.timesteps = None

        # `self.tensor_format` is not assigned yet at this point, so the helpers below fall back to "pt" through
        # `getattr(self, "tensor_format", "pt")`; `set_format` (from SchedulerMixin) is presumably what converts
        # the resulting arrays to the requested format afterwards.
        self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps: int, sampling_eps: Optional[float] = None):
        """
        Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

        The timesteps decrease linearly from 1 to `sampling_eps` (inclusive).

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).

        Raises:
            ValueError: if `self.tensor_format` is neither "np" nor "pt".
        """
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
        elif tensor_format == "pt":
            self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
        else:
            raise ValueError(f"`self.tensor_format`: {tensor_format} is not valid.")

    def set_sigmas(
        self,
        num_inference_steps: int,
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
        sampling_eps: Optional[float] = None,
    ):
        """
        Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.

        The sigmas control the weight of the `drift` and `diffusion` components of sample update. Two sequences
        are built:

        - `discrete_sigmas`: `num_inference_steps` values spaced geometrically from `sigma_min` to `sigma_max`.
        - `sigmas`: `sigma_min * (sigma_max / sigma_min) ** t` for every continuous timestep `t` in
          `self.timesteps`.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            sigma_min (`float`, optional):
                initial noise scale value (overrides value given at Scheduler instantiation).
            sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
            sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).

        Raises:
            ValueError: if `self.tensor_format` is neither "np" nor "pt".
        """
        sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
        sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
        sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
        # NOTE(review): timesteps are only created here when missing; if `set_timesteps` was already called with a
        # different step count, `self.sigmas` follows that existing schedule.
        if self.timesteps is None:
            self.set_timesteps(num_inference_steps, sampling_eps)

        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
            self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
        elif tensor_format == "pt":
            self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
            self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
        else:
            raise ValueError(f"`self.tensor_format`: {tensor_format} is not valid.")

    def get_adjacent_sigma(self, timesteps, t):
        """
        Return the noise scale of the preceding discrete step for each entry of `timesteps`.

        Args:
            timesteps: integer indices into `self.discrete_sigmas`.
            t: tensor/array used only to shape the zero result where `timesteps == 0`.

        Returns:
            `self.discrete_sigmas[timesteps - 1]`, with zero wherever `timesteps == 0`.

        Raises:
            ValueError: if `self.tensor_format` is neither "np" nor "pt".
        """
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
        elif tensor_format == "pt":
            return torch.where(
                timesteps == 0,
                torch.zeros_like(t.to(timesteps.device)),
                self.discrete_sigmas[timesteps - 1].to(timesteps.device),
            )
        raise ValueError(f"`self.tensor_format`: {tensor_format} is not valid.")

    @staticmethod
    def _expand_to_sample_dims(values, sample):
        """Reshape a per-batch vector `(batch,)` to `(batch, 1, ..., 1)` so it broadcasts against `sample`."""
        # Equivalent to `values[:, None, None, None]` for 4-D image samples, but also correct for samples of any
        # other rank (e.g. 3-D audio-like `(batch, channels, length)` inputs).
        return values.reshape((-1,) + (1,) * (len(sample.shape) - 1))

    def set_seed(self, seed):
        """Deprecated: seed the global RNG of the current tensor format. Pass a `generator` instead."""
        warnings.warn(
            "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
            " generator instead.",
            DeprecationWarning,
        )
        tensor_format = getattr(self, "tensor_format", "pt")
        if tensor_format == "np":
            np.random.seed(seed)
        elif tensor_format == "pt":
            torch.manual_seed(seed)
        else:
            raise ValueError(f"`self.tensor_format`: {tensor_format} is not valid.")

    def step_pred(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: float,
        sample: Union[torch.FloatTensor, np.ndarray],
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SdeVeOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            timestep (`float`):
                current continuous timestep, expected to be one of the values in `self.timesteps` (between
                `sampling_eps` and 1). It is scaled by `len(self.timesteps) - 1` and truncated to an integer index
                into `self.discrete_sigmas`.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process. The first dimension is the batch
                dimension; any number of trailing dimensions is supported.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SdeVeOutput class

        Returns:
            [`SdeVeOutput`] if `return_dict` is True, otherwise a tuple `(prev_sample, prev_sample_mean)`.

        Raises:
            ValueError: if `set_timesteps` has not been run.
        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])

        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # one copy of the (continuous) timestep per batch element
        timestep = timestep * torch.ones(sample.shape[0], device=sample.device)
        # map the continuous timestep onto an integer index into `self.discrete_sigmas` (`.long()` truncates)
        timesteps = (timestep * (len(self.timesteps) - 1)).long()

        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
        timesteps = timesteps.to(self.discrete_sigmas.device)

        sigma = self.discrete_sigmas[timesteps].to(sample.device)
        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
        drift = self.zeros_like(sample)
        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
        # per-sample coefficient reshaped to broadcast against `sample` of any rank
        diffusion = self._expand_to_sample_dims(diffusion, sample)

        # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
        # also equation 47 shows the analog from SDE models to ancestral sampling methods
        drift = drift - diffusion**2 * model_output

        # sample noise for the diffusion term (equation 6)
        noise = self.randn_like(sample, generator=generator)
        prev_sample_mean = sample - drift  # subtract because `dt` is a small negative timestep
        # TODO is the variable diffusion the correct scaling term for the noise?
        prev_sample = prev_sample_mean + diffusion * noise  # add impact of diffusion field g

        if not return_dict:
            return (prev_sample, prev_sample_mean)

        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)

    def step_correct(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        sample: Union[torch.FloatTensor, np.ndarray],
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
        after making the prediction for the previous timestep.

        Args:
            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
            sample (`torch.FloatTensor` or `np.ndarray`):
                current instance of sample being created by diffusion process. The first dimension is the batch
                dimension; any number of trailing dimensions is supported.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class

        Returns:
            [`SchedulerOutput`] if `return_dict` is True, otherwise a one-element tuple `(prev_sample,)`.

        Raises:
            ValueError: if `set_timesteps` has not been run.
        """
        if "seed" in kwargs and kwargs["seed"] is not None:
            self.set_seed(kwargs["seed"])

        if self.timesteps is None:
            raise ValueError(
                "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
            )

        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
        # sample noise for correction
        noise = self.randn_like(sample, generator=generator)

        # compute step size from the model_output, the noise, and the snr
        grad_norm = self.norm(model_output)
        noise_norm = self.norm(noise)
        step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
        # one step size per batch element, reshaped to broadcast against `sample` of any rank
        step_size = step_size * torch.ones(sample.shape[0], device=sample.device)
        step_size = self._expand_to_sample_dims(step_size, sample)

        # compute corrected sample: model_output term and noise term
        prev_sample_mean = sample + step_size * model_output
        prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps