modeling_ddim.py 2.94 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from diffusers import DiffusionPipeline
import tqdm
import torch


def compute_alpha(beta, t):
    """Return the cumulative product of ``(1 - beta)`` evaluated at timesteps ``t``.

    A zero is prepended to ``beta`` before taking the cumulative product, so
    index ``t + 1`` selects ``prod_{s <= t} (1 - beta[s])`` and ``t = -1`` maps
    to exactly 1.

    Args:
        beta: 1-D tensor of per-timestep betas.
        t: integer (long) tensor of timesteps in ``[-1, len(beta) - 1]``, on the
            same device as ``beta``.

    Returns:
        Tensor of shape ``(len(t), 1, 1, 1)`` with the same dtype/device as
        ``beta``, reshaped so it broadcasts against a 4-D batch.
    """
    # new_zeros keeps beta's dtype and device, so cat never mixes dtypes.
    beta = torch.cat([beta.new_zeros(1), beta], dim=0)
    return (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)


class DDIM(DiffusionPipeline):
    """DDIM sampling pipeline.

    Runs the generalized DDIM update over a strided subset of the timesteps the
    noise scheduler defines, starting from noise drawn by the scheduler.
    """

    def __init__(self, unet, noise_scheduler):
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
        """Generate a batch of samples with DDIM.

        Args:
            batch_size: number of samples to generate.
            generator: optional random generator forwarded to
                ``noise_scheduler.sample_noise``. It is used for the initial
                noise and, when ``eta != 0``, for the per-step noise, so seeded
                runs are reproducible.
            torch_device: device to run on; defaults to "cuda" when available,
                otherwise "cpu".
            eta: the η parameter of the DDIM paper. It scales the fresh noise
                injected at each step; 0.0 gives a fully deterministic update.
            num_inference_steps: number of denoising steps. Timesteps are taken
                with stride ``num_timesteps // num_inference_steps``.

        Returns:
            The final sample, with the same shape as the initial noise:
            ``(batch_size, unet.in_channels, unet.resolution, unet.resolution)``.

        Raises:
            ValueError: if ``num_inference_steps`` is not in
                ``[1, noise_scheduler.num_timesteps]``.
        """
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        num_trained_timesteps = self.noise_scheduler.num_timesteps
        if not 0 < num_inference_steps <= num_trained_timesteps:
            raise ValueError(
                f"num_inference_steps must be between 1 and {num_trained_timesteps}, got {num_inference_steps}"
            )
        stride = num_trained_timesteps // num_inference_steps
        timesteps = list(range(0, num_trained_timesteps, stride))
        # Each step jumps from timesteps[k] to timesteps[k - 1]. The final jump
        # targets -1, which compute_alpha maps to a cumulative alpha of 1.
        prev_timesteps = [-1] + timesteps[:-1]

        self.unet.to(torch_device)
        image_shape = (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution)
        x = self.noise_scheduler.sample_noise(image_shape, device=torch_device, generator=generator)

        betas = self.noise_scheduler.betas.to(torch_device)

        with torch.no_grad():
            for t_cur, t_prev in zip(reversed(timesteps), reversed(prev_timesteps)):
                # Float timesteps for the unet (as before); long copies for indexing.
                t = torch.full((batch_size,), float(t_cur), device=torch_device)
                t_next = torch.full((batch_size,), float(t_prev), device=torch_device)
                alpha_bar = compute_alpha(betas, t.long())
                alpha_bar_next = compute_alpha(betas, t_next.long())

                eps = self.unet(x, t)

                # Estimate of the clean sample implied by the predicted noise.
                x0_pred = (x - eps * (1 - alpha_bar).sqrt()) / alpha_bar.sqrt()

                # c1 scales fresh noise (zero when eta == 0); c2 scales the
                # predicted-noise direction so the total variance matches
                # 1 - alpha_bar_next.
                c1 = eta * ((1 - alpha_bar / alpha_bar_next) * (1 - alpha_bar_next) / (1 - alpha_bar)).sqrt()
                c2 = ((1 - alpha_bar_next) - c1 ** 2).sqrt()

                x_next = alpha_bar_next.sqrt() * x0_pred + c2 * eps
                if eta != 0:
                    # NOTE(review): assumes sample_noise draws standard Gaussian
                    # noise, as it does for the initial sample above.
                    noise = self.noise_scheduler.sample_noise(x.shape, device=torch_device, generator=generator)
                    x_next = x_next + c1 * noise

                # Carry the updated sample into the next step.
                x = x_next

        return x