"docs/vscode:/vscode.git/clone" did not exist on "dacd3fd4304c727d1ad3625d8bb0fdcc9390c5d6"
stable_diffusion_mega.py 8.67 KB
Newer Older
1
2
from typing import Any, Callable, Dict, List, Optional, Union

import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipelineLegacy,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.utils import deprecate, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class StableDiffusionMegaPipeline(DiffusionPipeline, StableDiffusionMixin):
    r"""
    Pipeline for text-to-image, image-to-image, and inpainting generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    _optional_components = ["safety_checker", "feature_extractor"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
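
        # Older checkpoints may ship a scheduler config with `steps_offset != 1`; emit a
        # deprecation warning and patch the frozen config so the step offset stays consistent.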
        if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    @property
    def components(self) -> Dict[str, Any]:
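        # Collect every entry registered in the pipeline config (the sub-models plus
        # `requires_safety_checker`) so it can be forwarded to the task-specific
        # pipelines below via `**self.components`.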
        return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}

    @torch.no_grad()
    def inpaint(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image],
        mask_image: Union[torch.Tensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        # This method delegates to `StableDiffusionInpaintPipelineLegacy`; for more information on how it
        # works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion
        return StableDiffusionInpaintPipelineLegacy(**self.components)(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def img2img(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image],
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        **kwargs,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionImg2ImgPipeline
        return StableDiffusionImg2ImgPipeline(**self.components)(
            prompt=prompt,
            image=image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

    @torch.no_grad()
    def text2img(
        self,
        prompt: Union[str, List[str]],
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
    ):
        # For more information on how this function works, please see: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline
        return StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )
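

# A minimal usage sketch: it assumes this file is loadable as the "stable_diffusion_mega"
# community pipeline, that a CUDA GPU is available, and uses the checkpoint referenced in the
# class docstring above. The prompts, placeholder images, and output paths are illustrative only.
if __name__ == "__main__":
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="stable_diffusion_mega",
        torch_dtype=torch.float16,
    )
    pipe.to("cuda")

    # Text-to-image: delegates to `StableDiffusionPipeline`.
    text2img_image = pipe.text2img("An astronaut riding a horse on Mars").images[0]
    text2img_image.save("text2img.png")

    # Image-to-image: delegates to `StableDiffusionImg2ImgPipeline`. A flat gray canvas
    # stands in for a real init image here.
    init_image = PIL.Image.new("RGB", (512, 512), color=(128, 128, 128))
    img2img_image = pipe.img2img(prompt="A fantasy landscape", image=init_image, strength=0.75).images[0]
    img2img_image.save("img2img.png")

    # Inpainting: delegates to `StableDiffusionInpaintPipelineLegacy`. White pixels in the
    # mask mark the region to repaint; a fully white mask repaints the whole image.
    mask_image = PIL.Image.new("L", (512, 512), color=255)
    inpaint_image = pipe.inpaint(
        prompt="A medieval castle", image=init_image, mask_image=mask_image, strength=0.75
    ).images[0]
    inpaint_image.save("inpaint.png")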