# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
18
from typing import Union
Patrick von Platen's avatar
Patrick von Platen committed
19

Patrick von Platen's avatar
Patrick von Platen committed
20
import numpy as np
21
import torch
Patrick von Platen's avatar
improve  
Patrick von Platen committed
22

Patrick von Platen's avatar
Patrick von Platen committed
23
from ..configuration_utils import ConfigMixin
24
25
26
27
28
from .scheduling_utils import SchedulerMixin


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule by discretizing a squared-cosine ``alpha_bar`` function, which gives the cumulative product
    of (1 - beta) over normalized diffusion time t in [0, 1].

    With N = ``num_diffusion_timesteps``, step i gets
    ``beta_i = min(1 - alpha_bar((i + 1) / N) / alpha_bar(i / N), max_beta)``.

    :param num_diffusion_timesteps: the number of betas to produce (N).
    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities (alpha_bar reaches
        ~0 at t = 1, which would otherwise push the last beta to 1).
    :return: a ``np.float32`` array of shape ``(num_diffusion_timesteps,)`` holding the betas.
    """

    def alpha_bar(time_step):
        # Squared-cosine curve over t in [0, 1]; 0.008 / 1.008 are fixed offset/scale constants of this schedule.
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        # beta_i is one minus the ratio of consecutive cumulative alphas, capped at max_beta.
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas, dtype=np.float32)


class DDPMScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler for denoising diffusion probabilistic models (DDPM, https://arxiv.org/pdf/2006.11239.pdf).

    Holds the noise schedule (betas, alphas and their cumulative products). `step` reverses one diffusion step.
    `add_noise` applies the forward (noising) process. The helpers `set_format`, `clip`, `log` and `match_shape` are
    not defined here; they come from a base class (presumably `SchedulerMixin` — confirm).
    """

    def __init__(
        self,
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        variance_type="fixed_small",
        clip_sample=True,
        tensor_format="np",
    ):
        """
        Args:
            num_train_timesteps: number of diffusion steps in the training schedule.
            beta_start: first beta of the "linear" schedule.
            beta_end: last beta of the "linear" schedule.
            beta_schedule: "linear" or "squaredcos_cap_v2"; ignored when `trained_betas` is given.
            trained_betas: explicit betas; takes precedence over `beta_schedule`.
            timestep_values: only recorded in the config; not used by this class.
            variance_type: default mode for `get_variance` ("fixed_small", "fixed_small_log", "fixed_large",
                "fixed_large_log").
            clip_sample: whether `step` clips the predicted x_0 to [-1, 1].
            tensor_format: array format passed to `set_format`, e.g. "np". Not written to the config.

        Raises:
            NotImplementedError: if `trained_betas` is None and `beta_schedule` is not a known schedule.
        """
        super().__init__()
        self.register_to_config(
            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            variance_type=variance_type,
            clip_sample=clip_sample,
        )

        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        # Stands in for alpha_prod_{t-1} at t == 0 in `get_variance` and `step`.
        self.one = np.array(1.0)

        # setable values
        self.num_inference_steps = None
        # Until `set_timesteps` is called, iterate over every training timestep in descending order.
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps):
        """
        Select the descending timesteps used at inference.

        Produces exactly `num_inference_steps` timesteps (capped at `num_train_timesteps`). They start at 0 and are
        spaced `num_train_timesteps // num_inference_steps` apart.

        Raises:
            ValueError: if fewer than one inference step is requested.
        """
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        if num_inference_steps < 1:
            raise ValueError(f"`num_inference_steps` must be at least 1, got {num_inference_steps}.")

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # Scale an index range rather than using `np.arange(0, T, step_ratio)`: when T is not divisible by the step
        # count, the latter yields extra timesteps (e.g. T=1000 with 3 steps gives 4 timesteps).
        self.timesteps = (np.arange(0, self.num_inference_steps) * step_ratio)[::-1].copy()
        self.set_format(tensor_format=self.tensor_format)

    def get_variance(self, t, variance_type=None):
        """
        Return the variance of reverse step `t` in the representation selected by `variance_type`.

        Args:
            t: timestep used to index the schedule arrays.
            variance_type: overrides `config.variance_type` when given.
                - "fixed_small": posterior variance clipped to >= 1e-20.
                - "fixed_small_log": the log of that clipped variance.
                - "fixed_large": beta_t.
                - "fixed_large_log": log(beta_t).
                Any other value returns the unclipped posterior variance.
        """
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probs added for training stability
        if variance_type == "fixed_small":
            # Clip away exact zeros (the posterior variance is 0 at t == 0).
            variance = self.clip(variance, min_value=1e-20)
        # for rl-diffuser https://arxiv.org/abs/2205.09991
        elif variance_type == "fixed_small_log":
            variance = self.log(self.clip(variance, min_value=1e-20))
        elif variance_type == "fixed_large":
            variance = self.betas[t]
        elif variance_type == "fixed_large_log":
            # Glide max_log
            variance = self.log(self.betas[t])

        return variance

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        predict_epsilon=True,
    ):
        """
        Predict the sample at the previous timestep by reversing one diffusion step.

        Args:
            model_output: model output at `timestep`. This is the predicted noise when `predict_epsilon` is True,
                otherwise the predicted x_0.
            timestep: current timestep t.
            sample: current noisy sample x_t.
            predict_epsilon: whether `model_output` is a noise prediction (True) or an x_0 prediction (False).

        Returns:
            dict with key "prev_sample": the posterior mean of formula (7). No noise is added here; the caller is
            responsible for that (presumably using `get_variance` — confirm against the pipeline).
        """
        t = timestep
        # 1. compute alphas, betas
        # NOTE(review): the previous cumulative alpha is always taken at t - 1, i.e. this assumes consecutive
        # timesteps even when `set_timesteps` selected a strided schedule.
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if predict_epsilon:
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        else:
            pred_original_sample = model_output

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = self.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        return {"prev_sample": pred_prev_sample}

    def add_noise(self, original_samples, noise, timesteps):
        """
        Noise `original_samples` to `timesteps` with the forward process:
        sqrt(alpha_prod_t) * original_samples + sqrt(1 - alpha_prod_t) * noise.

        The per-timestep coefficients go through `match_shape` so they combine with `original_samples`'s shape.
        """
        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Return the number of training timesteps."""
        return self.config.num_train_timesteps