# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

Patrick von Platen's avatar
Patrick von Platen committed
18
import math
19
from typing import Union
Patrick von Platen's avatar
Patrick von Platen committed
20

Patrick von Platen's avatar
Patrick von Platen committed
21
import numpy as np
22
import torch
Patrick von Platen's avatar
Patrick von Platen committed
23
24

from ..configuration_utils import ConfigMixin
25
26
27
28
29
from .scheduling_utils import SchedulerMixin


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes a fixed squared-cosine alpha_t_bar function, which defines the cumulative
    product of (1-beta) over time from t = [0,1].

    The alpha_t_bar function is hard-coded as ``cos((t + 0.008) / 1.008 * pi / 2) ** 2`` (the Glide cosine
    schedule). Beta ``i`` is ``1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)``, capped at ``max_beta``.

    :param num_diffusion_timesteps: the number of betas to produce (N).
    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities (the ratio approaches
        zero at t = 1, which would otherwise give a beta of 1).
    :return: a float32 numpy array of shape (num_diffusion_timesteps,) holding the betas.
    """

    def alpha_bar(time_step):
        # cumulative product of (1 - beta) at normalized time `time_step` in [0, 1]
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = [
        min(1 - alpha_bar((i + 1) / num_diffusion_timesteps) / alpha_bar(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ]
    return np.array(betas, dtype=np.float32)
Patrick von Platen's avatar
Patrick von Platen committed
49
50


Patrick von Platen's avatar
Patrick von Platen committed
51
class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler implementing the DDIM sampling update (formulas (12) and (16) of https://arxiv.org/pdf/2010.02502.pdf)
    together with the closed-form forward noising step used during training (`add_noise`).

    :param num_train_timesteps: number of diffusion steps the model was trained with.
    :param beta_start: first beta of the "linear" / "scaled_linear" schedules.
    :param beta_end: last beta of the "linear" / "scaled_linear" schedules.
    :param beta_schedule: one of "linear", "scaled_linear" or "squaredcos_cap_v2"; ignored when `trained_betas`
        is given.
    :param trained_betas: optional explicit sequence of betas that overrides `beta_schedule`; it is indexed by
        training timestep, so it is expected to hold (at least) `num_train_timesteps` entries.
    :param timestep_values: stored in the config only; not read by this class.
    :param clip_sample: whether `step` clips the predicted x_0 to [-1, 1].
    :param tensor_format: format handed to `set_format` (default "np"); see `SchedulerMixin` for accepted values.
    """

    def __init__(
        self,
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        clip_sample=True,
        tensor_format="np",
    ):
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            clip_sample=clip_sample,
        )

        if trained_betas is not None:
            # Explicit betas take precedence over any named schedule (they used to be registered but ignored).
            self.betas = np.asarray(trained_betas, dtype=np.float32)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        # alpha_prod_t_prev used for the final step, when the previous timestep would be negative
        self.one = np.array(1.0)

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def _get_variance(self, timestep, prev_timestep):
        """
        Return the eta = 1 DDIM variance sigma_t^2 between `timestep` and `prev_timestep` (formula (16)).

        A negative `prev_timestep` denotes the last denoising step, for which alpha_prod_t_prev is taken as 1.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps):
        """
        Select exactly `num_inference_steps` evenly spaced training timesteps, in descending order, for sampling.

        :param num_inference_steps: number of denoising steps; must satisfy 1 <= n <= `num_train_timesteps`.
        :raises ValueError: if `num_inference_steps` is outside that range (the step ratio would be zero, negative,
            or a division by zero would occur).
        """
        num_train_timesteps = self.config.num_train_timesteps
        if not 0 < num_inference_steps <= num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps` must be between 1 and {num_train_timesteps}, got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps

        step_ratio = num_train_timesteps // num_inference_steps
        # Multiplying a range (instead of np.arange with a step) guarantees exactly `num_inference_steps` entries
        # even when `num_train_timesteps` is not divisible by it; for divisible cases the result is unchanged.
        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].copy()
        self.set_format(tensor_format=self.tensor_format)

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        eta=0.0,
        use_clipped_model_output=False,
        generator=None,
    ):
        """
        Predict the sample at the previous timestep by reversing the diffusion process with the DDIM update.

        :param model_output: the model's noise prediction e_theta(x_t, t) for `sample`.
        :param timestep: current training timestep t (an entry of `self.timesteps`).
        :param sample: current noisy sample x_t.
        :param eta: weight of the stochastic term (η in the paper); 0.0 gives deterministic DDIM sampling.
        :param use_clipped_model_output: if True, re-derive the noise prediction from the (possibly clipped)
            predicted x_0 before computing the direction term.
        :param generator: optional `torch.Generator` forwarded to `torch.randn` for the noise when eta > 0.
        :return: dict with key "prev_sample" holding x_{t-1}.
        :raises ValueError: if `set_timesteps` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand the following.

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            # Noise is drawn on the CPU (so a CPU generator works) and then moved to the model output's device.
            device = model_output.device if torch.is_tensor(model_output) else "cpu"
            noise = torch.randn(model_output.shape, generator=generator).to(device)
            # std_dev_t already equals eta * sqrt(variance); reuse it instead of recomputing the variance.
            noise_term = std_dev_t * noise

            if not torch.is_tensor(model_output):
                noise_term = noise_term.numpy()

            prev_sample = prev_sample + noise_term

        return {"prev_sample": prev_sample}

    def add_noise(self, original_samples, noise, timesteps):
        """
        Apply the forward diffusion in closed form:
        sqrt(alpha_bar_t) * original_samples + sqrt(1 - alpha_bar_t) * noise.

        The per-timestep coefficients go through `match_shape` (from `SchedulerMixin`), presumably to make them
        broadcast against `original_samples` — confirm there.
        """
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps."""
        return self.config.num_train_timesteps