modeling_ddpm.py 2.96 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


Patrick von Platen's avatar
up  
Patrick von Platen committed
17
import torch
Patrick von Platen's avatar
Patrick von Platen committed
18

Patrick von Platen's avatar
Patrick von Platen committed
19
20
21
import tqdm
from diffusers import DiffusionPipeline

Patrick von Platen's avatar
Patrick von Platen committed
22
23

class DDPM(DiffusionPipeline):
    """Ancestral sampling pipeline for a denoising diffusion probabilistic model.

    Combines a noise-prediction ``unet`` with a ``noise_scheduler`` that supplies
    the per-timestep schedule (``get_alpha``, ``get_beta``, ``get_alpha_prod``),
    its number of timesteps (``len(noise_scheduler)``), and noise-sampling helpers
    (``sample_noise``, ``sample_variance``).
    """

    def __init__(self, unet, noise_scheduler):
        super().__init__()
        # Hand both components to the base pipeline; afterwards they are reachable
        # as ``self.unet`` and ``self.noise_scheduler``.
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None):
        """Generate images by iteratively denoising Gaussian noise.

        Args:
            batch_size: Number of samples to generate.
            generator: Optional random generator forwarded to the scheduler's
                ``sample_noise`` and ``sample_variance`` for reproducible runs.
            torch_device: Device to run on. Defaults to ``"cuda"`` when CUDA is
                available, otherwise ``"cpu"``. The unet is moved to this device.

        Returns:
            The sample after the final (t == 0) denoising step. The initial noise
            has shape ``(batch_size, unet.in_channels, unet.resolution,
            unet.resolution)``; every update is elementwise on that tensor.
        """
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        scheduler = self.noise_scheduler

        # 1. Sample gaussian noise x_T
        image = scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )

        num_timesteps = len(scheduler)
        for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            # i) define coefficients for time step t.
            # Each schedule value is fetched once per step instead of repeatedly.
            # NOTE(review): at t == 0 this requests get_alpha_prod(-1); assumes the
            # scheduler handles that index (e.g. returns 1) — confirm.
            alpha_prod_t = scheduler.get_alpha_prod(t)
            alpha_prod_t_prev = scheduler.get_alpha_prod(t - 1)
            beta_prod_t = 1 - alpha_prod_t
            beta_prod_t_prev = 1 - alpha_prod_t_prev

            # Coefficients that recover x_0 from x_t and the predicted noise eps:
            #   x_0 = x_t / sqrt(alpha_prod_t) - sqrt(1 / alpha_prod_t - 1) * eps
            clipped_image_coeff = 1 / torch.sqrt(alpha_prod_t)
            clipped_noise_coeff = torch.sqrt(1 / alpha_prod_t - 1)
            # Posterior-mean weights on x_t and on the (clamped) predicted x_0.
            image_coeff = beta_prod_t_prev * torch.sqrt(scheduler.get_alpha(t)) / beta_prod_t
            clipped_coeff = torch.sqrt(alpha_prod_t_prev) * scheduler.get_beta(t) / beta_prod_t

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = self.unet(image, t)

            # iii) compute predicted original image x_0 from the residual, clamp it,
            # and form the mean of x_{t-1}.
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_original_image = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            pred_original_image = torch.clamp(pred_original_image, -1, 1)
            prev_image = clipped_coeff * pred_original_image + image_coeff * image

            # iv) sample variance term. It is added directly to the mean below, so
            # presumably it is noise already scaled by the posterior std — confirm
            # against the scheduler's sample_variance.
            prev_variance = scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            image = prev_image + prev_variance

        return image