modeling_ddpm.py 4.1 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

# limitations under the License.


Patrick von Platen's avatar
up  
Patrick von Platen committed
17
import torch
Patrick von Platen's avatar
Patrick von Platen committed
18

Patrick von Platen's avatar
Patrick von Platen committed
19
20
21
import tqdm
from diffusers import DiffusionPipeline

Patrick von Platen's avatar
Patrick von Platen committed
22
23

class DDPM(DiffusionPipeline):
Patrick von Platen's avatar
Patrick von Platen committed
24
    def __init__(self, unet, noise_scheduler):
Patrick von Platen's avatar
up  
Patrick von Platen committed
25
26
27
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

Patrick von Platen's avatar
Patrick von Platen committed
28
29
30
    def __call__(self, batch_size=1, generator=None, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
Patrick von Platen's avatar
up  
Patrick von Platen committed
31
32

        self.unet.to(torch_device)
Patrick von Platen's avatar
Patrick von Platen committed
33
34

        # Sample gaussian noise to begin loop
Patrick von Platen's avatar
Patrick von Platen committed
35
36
37
38
39
        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )
Patrick von Platen's avatar
up  
Patrick von Platen committed
40

Patrick von Platen's avatar
Patrick von Platen committed
41
42
        for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
            # 1. predict noise residual
Patrick von Platen's avatar
up  
Patrick von Platen committed
43
            with torch.no_grad():
Patrick von Platen's avatar
Patrick von Platen committed
44
                pred_noise_t = self.unet(image, t)
Patrick von Platen's avatar
up  
Patrick von Platen committed
45

Patrick von Platen's avatar
Patrick von Platen committed
46
47
48
            # 2. compute alphas, betas
            alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
            alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
Patrick von Platen's avatar
Patrick von Platen committed
49
50
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev
Patrick von Platen's avatar
Patrick von Platen committed
51
52

            # 3. compute predicted image from residual
Patrick von Platen's avatar
Patrick von Platen committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
            # First: compute predicted original image from predicted noise also called
            # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
            pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()

            # Second: Clip "predicted x_0"
            pred_original_image = torch.clamp(pred_original_image, -1, 1)

            # Third: Compute coefficients for pred_original_image x_0 and current image x_t
            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
            current_image_coeff = self.noise_scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
            # Fourth: Compute predicted previous image µ_t
            # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image

            # 5. For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
            # and sample from it to get previous image
            # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image
            if t > 0:
Patrick von Platen's avatar
Patrick von Platen committed
72
73
74
75
76
77
                variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt()
                # TODO(PVP):
                # This variance seems to be good enough for inference - check if those `fix_small`, `fix_large`
                # are really only needed for training or also for inference
                # Also note LDM only uses "fixed_small";
                # glide seems to use a weird mix of the two: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
Patrick von Platen's avatar
Patrick von Platen committed
78
79
80
81
82
83
84
85
                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
                sampled_variance = variance * noise
                prev_image = pred_prev_image + sampled_variance
            else:
                prev_image = pred_prev_image

            # 6. Set current image to prev_image: x_t -> x_t-1
            image = prev_image
Patrick von Platen's avatar
Patrick von Platen committed
86
87

        return image