pipeline_stable_diffusion.py 14.5 KB
Newer Older
Suraj Patil's avatar
Suraj Patil committed
1
import inspect
Pedro Cuenca's avatar
Pedro Cuenca committed
2
import warnings
Suraj Patil's avatar
Suraj Patil committed
3
4
5
6
from typing import List, Optional, Union

import torch

Suraj Patil's avatar
Suraj Patil committed
7
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
Suraj Patil's avatar
Suraj Patil committed
8

9
from ...configuration_utils import FrozenDict
Suraj Patil's avatar
Suraj Patil committed
10
11
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
12
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
13
from . import StableDiffusionPipelineOutput
Suraj Patil's avatar
Suraj Patil committed
14
from .safety_checker import StableDiffusionSafetyChecker
Suraj Patil's avatar
Suraj Patil committed
15
16
17


class StableDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")

        # Older scheduler configs carry a `steps_offset` other than 1; patch the (frozen)
        # config in place so the scheduler behaves as expected, and warn the user.
        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file",
                DeprecationWarning,
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`. The latents are cast to
                the pipeline's working dtype (the dtype of the text embeddings).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """

        if "torch_device" in kwargs:
            device = kwargs.pop("torch_device")
            warnings.warn(
                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
                " Consider using `pipe.to(torch_device)` instead."
            )

            # Set device as before (to be removed in 0.3.0)
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # The VAE downsamples by a factor of 8, so the latent grid is (height // 8, width // 8).
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # The text-encoder output dtype is the pipeline's working precision. Tensors created
        # by this method (latents, safety-checker inputs) are cast to it so that a pipeline
        # moved to half precision does not feed float32 tensors into half-precision modules.
        working_dtype = text_embeddings.dtype

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_device = "cpu" if self.device.type == "mps" else self.device
        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
        if latents is None:
            # Sample in torch's default dtype (as before) so the draw for a given generator
            # does not depend on the pipeline precision; the cast happens below.
            latents = torch.randn(
                latents_shape,
                generator=generator,
                device=latents_device,
            )
        elif latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        latents = latents.to(device=self.device, dtype=working_dtype)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            # (LMSDiscreteScheduler is stepped by index, the other schedulers by timestep value)
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        # Map from [-1, 1] to [0, 1] and move to NHWC numpy. Upcast to float32 first:
        # numpy has no bfloat16, so `.numpy()` on a bf16 tensor would fail.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # run safety checker; the feature extractor yields float32 pixel values, which are
        # cast to the working dtype so a half-precision safety checker accepts them.
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values.to(working_dtype)
        )

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)