# modeling_ddpm.py
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from diffusers import DiffusionPipeline
import tqdm
import torch


class DDPM(DiffusionPipeline):
    """Sampling pipeline that pairs a noise-predicting ``unet`` with a ``noise_scheduler``.

    Calling the pipeline starts from noise drawn by the scheduler and walks the
    scheduler's time steps in reverse order, refining the image at every step.
    """

    modeling_file = "modeling_ddpm.py"

    def __init__(self, unet, noise_scheduler):
        """Register the two components on the pipeline.

        Args:
            unet: Model called as ``unet(image, t)`` to predict the noise residual.
                ``__call__`` also reads ``unet.in_channels`` and ``unet.resolution``.
            noise_scheduler: Object supporting ``len()`` and providing
                ``sample_noise``, ``sample_variance``, ``get_alpha``,
                ``get_alpha_prod`` and ``get_beta``.
        """
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None):
        """Run the reverse process and return the final image tensor.

        Args:
            batch_size: First dimension of the sampled noise.
            generator: Forwarded unchanged to ``noise_scheduler.sample_noise`` and
                ``noise_scheduler.sample_variance``.
            torch_device: Device the ``unet`` is moved to and on which noise is
                sampled. When ``None``, ``"cuda"`` is used if available, else ``"cpu"``.

        Returns:
            The image obtained after the last (``t == 0``) step, with the shape
            ``(batch_size, unet.in_channels, unet.resolution, unet.resolution)``
            requested from ``sample_noise``.
        """
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)

        scheduler = self.noise_scheduler
        num_steps = len(scheduler)

        # 1. Sample gaussian noise
        image = scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )

        for t in tqdm.tqdm(reversed(range(num_steps)), total=num_steps):
            # i) define coefficients for time step t
            # Each cumulative product is looked up once per step and reused below.
            # NOTE(review): at t == 0 this queries index -1; assumes the scheduler
            # defines get_alpha_prod for that case -- confirm.
            alpha_prod_t = scheduler.get_alpha_prod(t)
            alpha_prod_prev = scheduler.get_alpha_prod(t - 1)

            clipped_image_coeff = 1 / torch.sqrt(alpha_prod_t)
            clipped_noise_coeff = torch.sqrt(1 / alpha_prod_t - 1)
            image_coeff = (1 - alpha_prod_prev) * torch.sqrt(scheduler.get_alpha(t)) / (1 - alpha_prod_t)
            clipped_coeff = torch.sqrt(alpha_prod_prev) * scheduler.get_beta(t) / (1 - alpha_prod_t)

            # ii) predict noise residual
            with torch.no_grad():
                noise_residual = self.unet(image, t)

            # iii) compute predicted image from residual
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            # Clamp the intermediate estimate to [-1, 1] before mixing it with the current image.
            pred_mean = torch.clamp(pred_mean, -1, 1)
            prev_image = clipped_coeff * pred_mean + image_coeff * image

            # iv) sample variance
            prev_variance = scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )

            # v) sample  x_{t-1} ~ N(prev_image, prev_variance)
            image = prev_image + prev_variance

        return image