# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

anton-l's avatar
anton-l committed
17
import math
18
from typing import Optional, Tuple, Union
Patrick von Platen's avatar
Patrick von Platen committed
19

Patrick von Platen's avatar
Patrick von Platen committed
20
import numpy as np
21
import torch
Patrick von Platen's avatar
improve  
Patrick von Platen committed
22

23
from ..configuration_utils import ConfigMixin, register_to_config
24
from .scheduling_utils import SchedulerMixin, SchedulerOutput
25
26
27
28


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    The alpha_t_bar function is the inner `alpha_bar`. It is a squared-cosine curve with offset 0.008. It maps
    t in [0, 1] to the cumulative product of (1-beta) up to that part of the diffusion process. Each beta is
    `1 - alpha_bar(t_{i+1}) / alpha_bar(t_i)` with `t_i = i / num_diffusion_timesteps`, capped at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities
            (`alpha_bar(1)` is ~0, so the last ratio would otherwise give a beta of ~1).

    Returns:
        `np.ndarray`: float32 array of length `num_diffusion_timesteps` holding the betas.
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float32)
Patrick von Platen's avatar
improve  
Patrick von Platen committed
48
49


Patrick von Platen's avatar
Patrick von Platen committed
50
class DDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion probabilistic models (DDPM) scheduler, see https://arxiv.org/abs/2006.11239.

    Holds a beta schedule over `num_train_timesteps` steps. Provides the forward noising process (`add_noise`) and
    one step of the reverse process (`step`).
    """

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        variance_type: str = "fixed_small",
        clip_sample: bool = True,
        tensor_format: str = "pt",
    ):
        """
        Args:
            num_train_timesteps: number of diffusion steps of the schedule.
            beta_start: first beta of the "linear" and "scaled_linear" schedules.
            beta_end: last beta of the "linear" and "scaled_linear" schedules.
            beta_schedule: one of "linear", "scaled_linear", "squaredcos_cap_v2". Ignored when `trained_betas`
                is given.
            trained_betas: explicit betas. Takes precedence over `beta_schedule`.
            variance_type: selects how `_get_variance` computes the noise added in `step`. One of "fixed_small",
                "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range".
            clip_sample: whether `step` clips the predicted x_0 to [-1, 1].
            tensor_format: backend name forwarded to `set_format` (default "pt").

        Raises:
            NotImplementedError: if `trained_betas` is None and `beta_schedule` is not a known schedule.
        """
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        # stands in for alphas_cumprod[t - 1] at t == 0
        self.one = np.array(1.0)

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

        self.variance_type = variance_type

    def set_timesteps(self, num_inference_steps: int):
        """
        Select the training timesteps used at inference and store them in descending order in `self.timesteps`.

        Args:
            num_inference_steps: number of denoising steps; values above `num_train_timesteps` are capped.

        Raises:
            ValueError: if `num_inference_steps` is smaller than 1.
        """
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be a positive integer, got {num_inference_steps}.")
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # Multiply a 0..N-1 range instead of `np.arange(0, T, step_ratio)`. This yields exactly N timesteps even when
        # T is not divisible by N (e.g. T=1000, N=3 used to give 4 timesteps: 999, 666, 333, 0).
        self.timesteps = (np.arange(0, self.num_inference_steps) * step_ratio)[::-1].copy()
        self.set_format(tensor_format=self.tensor_format)

    def _get_variance(self, t, predicted_variance=None, variance_type=None):
        """
        Compute the variance of the noise added when stepping from timestep `t` to `t - 1`.

        Args:
            t: timestep index into the schedule.
            predicted_variance: model-predicted variance values. Only used by "learned" and "learned_range".
            variance_type: overrides `self.config.variance_type` when given.

        Returns:
            A variance for "fixed_small", "fixed_large" and "learned_range". `predicted_variance` unchanged for
            "learned". A log-variance for "fixed_small_log" and "fixed_large_log".
        """
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

        # posterior variance β̃_t (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf);
        # it is 0 at t == 0 because alpha_prod_t_prev is 1 there
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probs added for training stability
        if variance_type == "fixed_small":
            variance = self.clip(variance, min_value=1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = self.log(self.clip(variance, min_value=1e-20))
        elif variance_type == "fixed_large":
            variance = self.betas[t]
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = self.log(self.betas[t])
        elif variance_type == "learned":
            return predicted_variance
        elif variance_type == "learned_range":
            # Improved DDPM (https://arxiv.org/abs/2102.09672, eq. 15) maps the model output v in [-1, 1] to
            # frac in [0, 1] and interpolates in *log* space: exp(frac * log β_t + (1 - frac) * log β̃_t).
            # This equals β_t ** frac * β̃_t ** (1 - frac). The power form works for numpy and torch values alike.
            # β̃_t is clipped because it is 0 at t == 0.
            min_var = self.clip(variance, min_value=1e-20)
            max_var = self.betas[t]
            frac = (predicted_variance + 1) / 2
            variance = max_var**frac * min_var ** (1 - frac)

        return variance

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        predict_epsilon=True,
        generator=None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by running the reverse diffusion process for one step.

        Args:
            model_output: diffusion model output at `timestep`. For the "learned"/"learned_range" variance types it
                may have twice the size of `sample` along dim 1 (presumably channels); the second half is then
                taken as the predicted variance values.
            timestep: current timestep t.
            sample: current noisy sample x_t.
            predict_epsilon: if True, `model_output` is the predicted noise; otherwise it is the predicted x_0.
            generator: random generator forwarded to `randn_like` when drawing the added noise.
            return_dict: whether to return a `SchedulerOutput` instead of a tuple.

        Returns:
            `SchedulerOutput(prev_sample=...)` if `return_dict` is True, otherwise the 1-tuple `(prev_sample,)`.
        """
        t = timestep

        num_channels = sample.shape[1]
        if model_output.shape[1] == num_channels * 2 and self.variance_type in ["learned", "learned_range"]:
            # plain slicing (instead of `torch.split`) so numpy inputs are supported as well
            predicted_variance = model_output[:, num_channels:]
            model_output = model_output[:, :num_channels]
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if predict_epsilon:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        else:
            pred_original_sample = model_output

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise (none at the final step t == 0)
        variance = 0
        if t > 0:
            noise = self.randn_like(model_output, generator=generator)
            variance = self._get_variance(t, predicted_variance=predicted_variance)
            if self.config.variance_type in ("fixed_small_log", "fixed_large_log"):
                # `_get_variance` returned log σ² for these types, so σ = exp(0.5 * log σ²). Taking `** 0.5` of
                # the (negative) log-variance would produce NaN.
                std = np.e ** (0.5 * variance)
            else:
                std = variance**0.5
            variance = std * noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return SchedulerOutput(prev_sample=pred_prev_sample)

    def add_noise(
        self,
        original_samples: Union[torch.FloatTensor, np.ndarray],
        noise: Union[torch.FloatTensor, np.ndarray],
        timesteps: Union[torch.IntTensor, np.ndarray],
    ) -> Union[torch.FloatTensor, np.ndarray]:
        """
        Noise `original_samples` to `timesteps`: sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise.

        The per-timestep coefficients are broadcast to `original_samples` via `match_shape`.
        """
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps."""
        return self.config.num_train_timesteps