import inspect
from typing import List, Optional, Union

import torch

from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from .safety_checker import StableDiffusionSafetyChecker


class StableDiffusionPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
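        # `set_format("pt")` makes the scheduler produce and expect PyTorch tensors (rather than NumPy arrays)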
        scheduler = scheduler.set_format("pt")
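        # register_modules exposes the components as attributes (self.vae, self.unet, ...) and
        # lets `save_pretrained` / `from_pretrained` serialize and reload the whole pipeline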
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        torch_device: Optional[Union[str, torch.device]] = None,
        output_type: Optional[str] = "pil",
    ):
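        r"""
        Run the full text-to-image sampling loop for `prompt` and return the decoded images.

        A minimal usage sketch (the model id below is an assumption; any checkpoint saved in
        this pipeline's format will work):

        ```py
        from diffusers import StableDiffusionPipeline

        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
        output = pipe("a photograph of an astronaut riding a horse")
        output["sample"][0].save("astronaut.png")
        ```
        """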
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        self.unet.to(torch_device)
        self.vae.to(torch_device)
        self.text_encoder.to(torch_device)
        self.safety_checker.to(torch_device)

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
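        # NOTE: prompts longer than the tokenizer's `model_max_length` (77 tokens for CLIP) are truncated above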
        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier-free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

            # For classifier-free guidance we need both an unconditional and a conditional
            # prediction. Here we concatenate the unconditional and text embeddings into a
            # single batch so that one forward pass of the UNet covers both.
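            # the unconditional embeddings come first so that `chunk(2)` below splits in the same order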
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise
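        # (the VAE downsamples by a factor of 8, hence the height // 8 x width // 8 latent resolution)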
        latents = torch.randn(
            (batch_size, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
            device=torch_device,
        )

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
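        # some schedulers shift the timestep schedule by one ("offset"); we pass it only to
        # schedulers whose `set_timesteps` signature accepts the argument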
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
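            # the LMS scheduler expects its inputs scaled down by sqrt(sigma**2 + 1)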
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            if do_classifier_free_guidance:
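                # split the batched prediction back into its unconditional and text-conditioned halves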
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
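            # (LMSDiscreteScheduler.step takes the step index i; the other schedulers take the timestep t)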
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

        # scale and decode the image latents with vae
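        # (the latents were scaled by 0.18215 during training, so we invert that factor before decoding)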
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
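        # (images flagged as NSFW are blacked out by the checker)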
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}