scheduling_ddim.py 7.29 KB
Newer Older
1
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

Patrick von Platen's avatar
Patrick von Platen committed
18
import math
19
from typing import Union
Patrick von Platen's avatar
Patrick von Platen committed
20

Patrick von Platen's avatar
Patrick von Platen committed
21
import numpy as np
22
import torch
Patrick von Platen's avatar
Patrick von Platen committed
23

24
from ..configuration_utils import ConfigMixin, register_to_config
25
26
27
28
29
from .scheduling_utils import SchedulerMixin


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Discretize a squared-cosine alpha_bar curve into one beta per diffusion step.

    The curve is alpha_bar(t) = cos(((t + 0.008) / 1.008) * pi / 2) ** 2 for normalized time t in [0, 1]; it plays
    the role of the cumulative product of (1 - beta) over time. With N = `num_diffusion_timesteps`, step i covers
    [i / N, (i + 1) / N] and gets beta_i = 1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), capped at `max_beta`.

    :param num_diffusion_timesteps: the number of betas to produce (N). N <= 0 yields an empty array.
    :param max_beta: upper bound applied to every beta; use values lower than 1 to prevent singularities near t = 1.
    :return: float32 numpy array with N entries.
    """

    def _cumulative_alpha(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    n = num_diffusion_timesteps
    capped_betas = [
        # ratio of consecutive cumulative alphas gives the per-step (1 - beta)
        min(1 - _cumulative_alpha((step + 1) / n) / _cumulative_alpha(step / n), max_beta)
        for step in range(n)
    ]
    return np.array(capped_betas, dtype=np.float32)
Patrick von Platen's avatar
Patrick von Platen committed
49
50


Patrick von Platen's avatar
Patrick von Platen committed
51
class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising Diffusion Implicit Models (DDIM) scheduler, https://arxiv.org/abs/2010.02502.

    Precomputes a beta schedule over `num_train_timesteps` training steps and exposes a descending list of timesteps
    in `self.timesteps`. `set_timesteps` sub-samples that list for inference. `step` performs one reverse update.
    `add_noise` performs the forward (noising) process.

    Args:
        num_train_timesteps: number of diffusion steps used during training.
        beta_start: first beta of the "linear" / "scaled_linear" schedules.
        beta_end: last beta of the "linear" / "scaled_linear" schedules.
        beta_schedule: one of "linear", "scaled_linear" or "squaredcos_cap_v2"; ignored when `trained_betas` is given.
        trained_betas: optional explicit betas that take precedence over `beta_schedule`
            (assumed to hold `num_train_timesteps` values; not checked here).
        timestep_values: accepted for config compatibility; not used by this class.
        clip_sample: if True, `step` clips the predicted x_0 to [-1, 1].
        tensor_format: forwarded to `set_format` (provided by `SchedulerMixin`); "np" by default.
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        clip_sample=True,
        tensor_format="np",
    ):
        if trained_betas is not None:
            # Explicit betas win over any named schedule; float32 matches the schedules below.
            self.betas = np.asarray(trained_betas, dtype=np.float32)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        # Cumulative alpha used when the previous timestep is negative (the final step).
        # It is kept float32 to match `alphas_cumprod`, so NumPy type promotion cannot
        # upcast the last step's output.
        self.one = np.array(1.0, dtype=np.float32)

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def _get_variance(self, timestep, prev_timestep):
        """
        Return the variance between `timestep` and `prev_timestep` (sigma_t ** 2 of formula (16) when eta = 1).

        A negative `prev_timestep` means "before the first training step", whose cumulative alpha is taken as 1.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps):
        """
        Select the `num_inference_steps` timesteps (descending) used by `step`.

        Timesteps are `0, r, 2r, ..., (num_inference_steps - 1) * r` in reverse order, where
        `r = num_train_timesteps // num_inference_steps`. This is the same stride `step` uses to locate the previous
        timestep.

        Raises:
            ValueError: if `num_inference_steps` is not in [1, num_train_timesteps].
        """
        num_train_timesteps = self.config.num_train_timesteps
        if not 1 <= num_inference_steps <= num_train_timesteps:
            # A zero stride would otherwise surface as an opaque ZeroDivisionError.
            raise ValueError(
                f"`num_inference_steps` must be between 1 and {num_train_timesteps}, got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        step_ratio = num_train_timesteps // num_inference_steps
        # Exactly `num_inference_steps` entries. np.arange(0, T, step_ratio) overshoots when T
        # is not a multiple of num_inference_steps.
        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].copy()
        self.set_format(tensor_format=self.tensor_format)

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        eta: float = 0.0,
        use_clipped_model_output=False,
        generator=None,
    ):
        """
        Compute the sample at the previous timestep with one DDIM update.

        See formulas (12) and (16) of the DDIM paper, https://arxiv.org/pdf/2010.02502.pdf.

        Args:
            model_output: predicted noise e_theta(x_t, t) for `sample`.
            timestep: current timestep t (expected to be an entry of `self.timesteps`).
            sample: current noisy sample x_t.
            eta: weight of the stochastic term; noise is only added when `eta > 0` (0.0 gives deterministic DDIM).
            use_clipped_model_output: if True, re-derive the noise prediction from the (possibly clipped) x_0,
                as done in Glide.
            generator: optional torch generator passed to `torch.randn` when `eta > 0`.

        Returns:
            dict with key "prev_sample" holding x_{t-1}.

        Raises:
            ValueError: if `set_timesteps` has not been called yet.
        """
        # Notation (<variable name> -> <name in paper>)
        # - model_output -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - prev_sample -> "x_t-1"
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # 1. get previous step value (=t-1)
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        prev_timestep = timestep - step_ratio

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12)
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12)
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 7. compute x_t-1 without "random noise" of formula (12)
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            device = model_output.device if torch.is_tensor(model_output) else "cpu"
            noise = torch.randn(model_output.shape, generator=generator).to(device)
            # sigma_t * z; reuses std_dev_t instead of recomputing the variance.
            noise_term = std_dev_t * noise

            if not torch.is_tensor(model_output):
                noise_term = noise_term.numpy()

            prev_sample = prev_sample + noise_term

        return {"prev_sample": prev_sample}

    def add_noise(self, original_samples, noise, timesteps):
        """
        Forward-diffuse `original_samples` to `timesteps`.

        Returns sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. The per-timestep coefficients are passed
        through `match_shape` (from `SchedulerMixin`), presumably to broadcast against `original_samples`.
        """
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps."""
        return self.config.num_train_timesteps