import tqdm
import torch

from diffusers import DiffusionPipeline

# add these relative imports here, so we can load from hub
from .modeling_vae import AutoencoderKL  # NOQA
from .configuration_ldmbert import LDMBertConfig  # NOQA
from .modeling_ldmbert import LDMBertModel  # NOQA

class LatentDiffusion(DiffusionPipeline):
    def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
        super().__init__()
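        # register_modules stores these components on the pipeline and records their
        # classes in the config, so the full pipeline can be saved and reloaded from the hub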
        self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, prompt, batch_size=1, generator=None, torch_device=None, eta=0.0, guidance_scale=1.0, num_inference_steps=50):
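        """Generate images for `prompt` with classifier-free guided DDIM-style sampling.

        `eta` controls the amount of stochastic noise added at each step (0.0 gives
        deterministic sampling) and `guidance_scale` > 1.0 enables classifier-free
        guidance; the result is a float image tensor with values in [0, 1].
        """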
        # eta corresponds to η in the DDIM paper and should be in [0, 1]

        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.unet.to(torch_device)
        self.vqvae.to(torch_device)
        self.bert.to(torch_device)
        
        # get unconditional embeddings for classifier-free guidance
        if guidance_scale != 1.0:
            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
            uncond_embeddings = self.bert(uncond_input.input_ids)[0]
        
        # get text embedding
        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors='pt').to(torch_device)
        text_embedding = self.bert(text_input.input_ids)[0]
        
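        # sample `num_inference_steps` evenly spaced timesteps out of the timesteps the
        # scheduler was trained with (e.g. 50 out of 1000)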
        num_trained_timesteps = self.noise_scheduler.num_timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
            device=torch_device,
            generator=generator,
        )
        
        # See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to understand the sampling loop below.

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"
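        # denoise from x_T down to x_0 by walking the inference timesteps in reverse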
        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # guidance_scale of 1 means no guidance
            if guidance_scale == 1.0:
                image_in = image
                context = text_embedding
                timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
            else:
                # for classifier-free guidance we would need two forward passes;
                # here we concatenate the unconditional and text embeddings into a single batch
                # to avoid doing two forward passes
                image_in = torch.cat([image] * 2)
                context = torch.cat([uncond_embeddings, text_embedding])
                timesteps = torch.tensor([inference_step_times[t]] * image_in.shape[0], device=torch_device)

            # 1. predict noise residual
            pred_noise_t = self.unet(image_in, timesteps, context=context)
            
            # perform guidance
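            # classifier-free guidance: e = e_uncond + guidance_scale * (e_cond - e_uncond)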
            if guidance_scale != 1.0:
                pred_noise_t_uncond, pred_noise_t = pred_noise_t.chunk(2)
                pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
                    
            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.noise_scheduler.compute_prev_image_step(pred_noise_t, image, t, num_inference_steps, eta)

            # 3. optionally sample variance
            variance = 0
            if eta > 0:
                noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
                variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise

            # 4. set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance

        # scale and decode image with vae
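        # 0.18215 is the scale factor the latents were multiplied by during training,
        # so divide by it before decoding back to pixel space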
        image = 1 / 0.18215 * image
        image = self.vqvae.decode(image)
        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)

        return image
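

# Example usage (a minimal sketch; the checkpoint path below is a placeholder, and a
# text-to-image latent-diffusion checkpoint containing these five components is assumed):
#
#   pipeline = LatentDiffusion.from_pretrained("<path-or-hub-id-of-ldm-text2im-checkpoint>")
#   image = pipeline(["a painting of a squirrel eating a burger"], guidance_scale=6.0)
#   # `image` is a float tensor of shape (batch_size, 3, H, W) with values in [0, 1]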