# pipeline_stable_diffusion.py
import inspect
from typing import List, Optional, Union

import torch

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler


class StableDiffusionPipeline(DiffusionPipeline):
    """Text-to-image diffusion pipeline.

    Encodes a text prompt with a CLIP text encoder and iteratively denoises random
    latents with a conditional UNet, stepped by a DDIM, PNDM or LMS-discrete
    scheduler. It then decodes the final latents to images with a VAE.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        """Register the pipeline components.

        Args:
            vae: Autoencoder whose ``decode`` turns latents into images.
            text_encoder: CLIP text model used to embed tokenized prompts.
            tokenizer: CLIP tokenizer for the prompts.
            unet: Conditional UNet that predicts the noise residual.
            scheduler: Noise scheduler. ``set_format("pt")`` is called on it
                before it is registered.
        """
        super().__init__()
        # The denoising loop hands torch tensors to the scheduler, so switch it to "pt" mode first.
        scheduler = scheduler.set_format("pt")
        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        torch_device: Optional[Union[str, torch.device]] = None,
        output_type: Optional[str] = "pil",
    ):
        """Generate images for ``prompt``.

        Args:
            prompt: A single prompt or a list of prompts. One image is produced
                per prompt.
            height: Output height in pixels. Must be divisible by 8.
            width: Output width in pixels. Must be divisible by 8.
            num_inference_steps: Number of denoising steps, passed to
                ``scheduler.set_timesteps``.
            guidance_scale: Classifier-free guidance weight. Values ``> 1``
                enable guidance; ``1`` disables it.
            eta: DDIM eta. It is forwarded to ``scheduler.step`` only if that
                method accepts an ``eta`` argument.
            generator: Optional torch generator used to sample the initial latents.
            torch_device: Device to run on. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``. The UNet, VAE and text encoder are moved to
                this device as a side effect.
            output_type: ``"pil"`` converts the result with ``numpy_to_pil``. Any
                other value returns a channel-last numpy array with values in
                ``[0, 1]``.

        Returns:
            dict: ``{"sample": image}``.

        Raises:
            ValueError: If ``prompt`` is neither a ``str`` nor a ``list``, or if
                ``height`` or ``width`` is not divisible by 8.
        """
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # Latents are sampled at height // 8 x width // 8, so both sizes must be multiples of 8.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        self.unet.to(torch_device)
        self.vae.to(torch_device)
        self.text_encoder.to(torch_device)

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

            # Classifier free guidance would normally need two forward passes.
            # Concatenating the unconditional and text embeddings into a single batch
            # lets one forward pass serve both.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise
        # NOTE(review): torch.randn requires `generator` to be on `torch_device`; a CPU
        # generator combined with a CUDA device raises. Confirm callers account for this.
        latents = torch.randn(
            (batch_size, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
            device=torch_device,
        )

        # set timesteps; only pass `offset` to schedulers whose set_timesteps accepts it
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # The scheduler type does not change during the loop, so check it once.
        is_lms = isinstance(self.scheduler, LMSDiscreteScheduler)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if is_lms:
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # Wrap the timesteps (not the enumerate object) in tqdm so the progress bar knows its total length.
        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if is_lms:
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1.
            # The LMS scheduler is stepped with the loop index; the others take the timestep value.
            step_arg = i if is_lms else t
            latents = self.scheduler.step(noise_pred, step_arg, latents, **extra_step_kwargs)["prev_sample"]

        # undo the fixed latent scaling constant, then decode the image latents with the vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        # map decoder output from [-1, 1] to [0, 1] and move to channel-last numpy
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image}