modeling_ddpm.py 2.76 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from diffusers import DiffusionPipeline
Patrick von Platen's avatar
up  
Patrick von Platen committed
18
19
import tqdm
import torch
Patrick von Platen's avatar
Patrick von Platen committed
20
21
22


class DDPM(DiffusionPipeline):
    """DDPM sampling pipeline built from a noise-predicting UNet and a noise scheduler.

    Calling the pipeline starts from scheduler-sampled noise and walks the
    scheduler's timesteps from last to first, denoising one step at a time.
    """

    def __init__(self, unet, noise_scheduler):
        """Register ``unet`` and ``noise_scheduler`` as modules of this pipeline."""
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None):
        """Generate a batch of samples with the DDPM reverse process.

        Args:
            batch_size: Number of samples to generate.
            generator: Optional random generator, forwarded to the scheduler's
                ``sample_noise`` and ``sample_variance`` calls.
            torch_device: Device to run on. When ``None``, uses ``"cuda"`` if
                CUDA is available and ``"cpu"`` otherwise.

        Returns:
            The sample after the final (t == 0) step. It starts with shape
            ``(batch_size, unet.in_channels, unet.resolution, unet.resolution)``;
            presumably the UNet and scheduler preserve that shape — confirm.
        """
        device = torch_device if torch_device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        scheduler = self.noise_scheduler

        self.unet.to(device)

        # 1. Start from Gaussian noise shaped like a batch of UNet inputs.
        sample_shape = (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution)
        image = scheduler.sample_noise(sample_shape, device=device, generator=generator)

        num_steps = len(scheduler)
        for t in tqdm.tqdm(reversed(range(num_steps)), total=num_steps):
            # i) Scheduler quantities for step t.
            # NOTE(review): on the last iteration t - 1 == -1; this assumes the
            # scheduler's get_alpha_prod accepts -1 (presumably returning 1) — confirm.
            alpha_prod_t = scheduler.get_alpha_prod(t)
            alpha_prod_t_prev = scheduler.get_alpha_prod(t - 1)
            one_minus_alpha_prod_t = 1 - alpha_prod_t

            # Weights that recover an x_0 estimate from x_t and the predicted noise:
            #   x_0 = x_t / sqrt(alpha_prod_t) - sqrt(1 / alpha_prod_t - 1) * eps
            x0_from_image = 1 / torch.sqrt(alpha_prod_t)
            x0_from_noise = torch.sqrt(1 / alpha_prod_t - 1)
            # Posterior-mean weights on x_0 and x_t (cf. Ho et al. 2020, eq. 7).
            x0_weight = torch.sqrt(alpha_prod_t_prev) * scheduler.get_beta(t) / one_minus_alpha_prod_t
            image_weight = (1 - alpha_prod_t_prev) * torch.sqrt(scheduler.get_alpha(t)) / one_minus_alpha_prod_t

            # ii) Predict the noise residual; torch.no_grad() keeps autograd from recording the UNet call.
            with torch.no_grad():
                noise_residual = self.unet(image, t)

            # iii) Estimate x_0 from the residual, clamp it to [-1, 1], and combine it
            # with x_t into the mean of x_{t-1}.
            # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison
            pred_original = torch.clamp(x0_from_image * image - x0_from_noise * noise_residual, -1, 1)
            prev_mean = x0_weight * pred_original + image_weight * image

            # iv) + v) Draw x_{t-1} by adding the scheduler's variance sample to the mean.
            # NOTE(review): the returned value is added directly to the mean, so it is
            # presumably already a scaled noise draw rather than a raw variance — confirm.
            variance_sample = scheduler.sample_variance(t, prev_mean.shape, device=device, generator=generator)
            image = prev_mean + variance_sample

        return image