import math
from typing import Dict, Optional

import torch
import torchvision.transforms.functional as FF
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers


try:
    from compel import Compel
except ImportError:
    Compel = None

KBASE = "ADDBASE"
KCOMM = "ADDCOMM"
KBRK = "BREAK"
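# Region-splitting keywords (sd-webui-regional-prompter syntax). Illustrative example:
# "masterpiece ADDCOMM green hair BREAK red dress" applies "masterpiece" to every region
# and splits the rest into two regional prompts at BREAK; an ADDBASE prefix instead marks
# a base prompt that is blended with each region via rp_args["base_ratio"].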


class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
    r"""
    Args for Regional Prompting Pipeline:
        rp_args: dict
        Required
            rp_args["mode"]: cols, rows, prompt, prompt-ex
        for cols, rows mode
            rp_args["div"]: e.g. 1;1;1 (divide into 3 regions)
        for prompt, prompt-ex mode
            rp_args["th"]: e.g. 0.5,0.5,0.6 (threshold for prompt mode)

        Optional
            rp_args["save_mask"]: True/False (save masks in prompt mode)
            rp_args["power"]: int (power for attention maps in prompt mode)
            rp_args["base_ratio"]:
                float (sets the ratio of the base prompt)
                e.g. 0.2 (20% * BASE_PROMPT + 80% * REGION_PROMPT)
                [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)

    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
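
    Example:
        A minimal usage sketch (illustrative: the checkpoint id can be any Stable Diffusion 1.x
        model, loaded with this file as the `custom_pipeline`):

        ```py
        from diffusers import DiffusionPipeline

        pipe = DiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            custom_pipeline="regional_prompting_stable_diffusion",
        )
        prompt = "green hair twintail BREAK red blouse BREAK blue skirt"
        images = pipe(
            prompt=prompt,
            negative_prompt="lowres, bad anatomy",
            rp_args={"mode": "rows", "div": "1;1;1"},  # three stacked regions
        ).images
        ```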
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            image_encoder,
            requires_safety_checker,
        )
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[str] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        rp_args: Optional[Dict[str, str]] = None,
    ):
        rp_args = rp_args or {}  # tolerate a missing rp_args dict
        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
        use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
        if negative_prompt is None:
            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

        device = self._execution_device
        regions = 0

        self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
        self.power = int(rp_args["power"]) if "power" in rp_args else 1

        prompts = prompt if isinstance(prompt, list) else [prompt]
        n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
        self.batch = batch = num_images_per_prompt * len(prompts)

        if use_base:
            bases = prompts.copy()
            n_bases = n_prompts.copy()

            for i, prompt in enumerate(prompts):
                parts = prompt.split(KBASE)
                if len(parts) == 2:
                    bases[i], prompts[i] = parts
                elif len(parts) > 2:
                    raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
            for i, prompt in enumerate(n_prompts):
                n_parts = prompt.split(KBASE)
                if len(n_parts) == 2:
                    n_bases[i], n_prompts[i] = n_parts
                elif len(n_parts) > 2:
                    raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")

            all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
            all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)

        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

        equal = len(all_prompts_cn) == len(all_n_prompts_cn)

        if Compel:
            compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)

            def getcompelembs(prps):
                embl = []
                for prp in prps:
                    embl.append(compel.build_conditioning_tensor(prp))
                return torch.cat(embl)

            conds = getcompelembs(all_prompts_cn)
            unconds = getcompelembs(all_n_prompts_cn)
            base_embs = getcompelembs(all_bases_cn) if use_base else None
            base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
            # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
            embs = getcompelembs(prompts) if not use_base else base_embs
            n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs

            if use_base and self.base_ratio > 0:
                conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
                unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

            prompt = negative_prompt = None
        else:
            conds = self.encode_prompt(prompts, device, 1, True)[0]
            unconds = (
                self.encode_prompt(n_prompts, device, 1, True)[0]
                if equal
                else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
            )

            if use_base and self.base_ratio > 0:
                base_embs = self.encode_prompt(bases, device, 1, True)[0]
                base_n_embs = (
                    self.encode_prompt(n_bases, device, 1, True)[0]
                    if equal
                    else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
                )

                conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
                unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

            embs = n_embs = None

        if not active:
            pcallback = None
            mode = None
        else:
            if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
                mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
                ocells, icells, regions = make_cells(rp_args["div"])

            elif "PRO" in rp_args["mode"].upper():
                regions = len(all_prompts_p[0])
                mode = "PROMPT"
                reset_attnmaps(self)
                self.ex = "EX" in rp_args["mode"].upper()
                self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
                thresholds = [float(x) for x in rp_args["th"].split(",")]

            orig_hw = (height, width)
            revers = True

            def pcallback(s_self, step: int, timestep: int, latents: torch.Tensor, selfs=None):
                if "PRO" in mode:  # in PROMPT mode, build masks from the sum of attention maps
                    self.step = step

                    if len(self.attnmaps_sizes) > 3:
                        self.history[step] = self.attnmaps.copy()
                        for hw in self.attnmaps_sizes:
                            allmasks = []
                            basemasks = [None] * batch
                            for tt, th in zip(target_tokens, thresholds):
                                for b in range(batch):
                                    key = f"{tt}-{b}"
                                    _, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
                                    mask = mask.unsqueeze(0).unsqueeze(-1)
                                    if self.ex:
                                        allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
                                        allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
                                    allmasks.append(mask)
                                    basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
                            basemasks = [1 - mask for mask in basemasks]
                            basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
                            allmasks = basemasks + allmasks

                            self.attnmasks[hw] = torch.cat(allmasks)
                        self.maskready = True
                return latents

            def hook_forward(module):
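                # Replace this cross-attention module's forward: every region's prompt attends
                # over its own copy of the latent, and the copies are merged back below (by grid
                # cells in COL/ROW mode, by attention-derived masks in PROMPT mode).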
                # attention forward based on diffusers==0.23.2, with regional logic added
                def forward(
                    hidden_states: torch.Tensor,
                    encoder_hidden_states: Optional[torch.Tensor] = None,
                    attention_mask: Optional[torch.Tensor] = None,
                    temb: Optional[torch.Tensor] = None,
                    scale: float = 1.0,
                ) -> torch.Tensor:
                    attn = module
                    xshape = hidden_states.shape
                    self.hw = (h, w) = split_dims(xshape[1], *orig_hw)

                    if revers:
                        nx, px = hidden_states.chunk(2)
                    else:
                        px, nx = hidden_states.chunk(2)

                    if equal:
                        hidden_states = torch.cat(
                            [px for i in range(regions)] + [nx for i in range(regions)],
                            0,
                        )
                        encoder_hidden_states = torch.cat([conds] + [unconds])
                    else:
                        hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
                        encoder_hidden_states = torch.cat([conds] + [unconds])

                    residual = hidden_states

                    if attn.spatial_norm is not None:
                        hidden_states = attn.spatial_norm(hidden_states, temb)

                    input_ndim = hidden_states.ndim

                    if input_ndim == 4:
                        batch_size, channel, height, width = hidden_states.shape
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                    batch_size, sequence_length, _ = (
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                    )

                    if attention_mask is not None:
                        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                    if attn.group_norm is not None:
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                    query = attn.to_q(hidden_states)

                    if encoder_hidden_states is None:
                        encoder_hidden_states = hidden_states
                    elif attn.norm_cross:
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                    key = attn.to_k(encoder_hidden_states)
                    value = attn.to_v(encoder_hidden_states)

                    inner_dim = key.shape[-1]
                    head_dim = inner_dim // attn.heads

                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
                    # TODO: add support for attn.scale when we move to Torch 2.1
                    hidden_states = scaled_dot_product_attention(
                        self,
                        query,
                        key,
                        value,
                        attn_mask=attention_mask,
                        dropout_p=0.0,
                        is_causal=False,
                        getattn="PRO" in mode,
                    )

                    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                    hidden_states = hidden_states.to(query.dtype)

                    # linear proj
                    hidden_states = attn.to_out[0](hidden_states)
                    # dropout
                    hidden_states = attn.to_out[1](hidden_states)

                    if input_ndim == 4:
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                    if attn.residual_connection:
                        hidden_states = hidden_states + residual

                    hidden_states = hidden_states / attn.rescale_output_factor

                    #### Regional Prompting Col/Row mode
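                    # Each region was denoised with its own text embedding above; paste the
                    # per-region activations back into a single canvas according to the cell grid.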
                    if any(x in mode for x in ["COL", "ROW"]):
                        reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
                        center = reshaped.shape[0] // 2
                        px = reshaped[0:center] if equal else reshaped[0:-batch]
                        nx = reshaped[center:] if equal else reshaped[-batch:]
                        outs = [px, nx] if equal else [px]
                        for out in outs:
                            c = 0
                            for i, ocell in enumerate(ocells):
                                for icell in icells[i]:
                                    if "ROW" in mode:
                                        out[
                                            0:batch,
                                            int(h * ocell[0]) : int(h * ocell[1]),
                                            int(w * icell[0]) : int(w * icell[1]),
                                            :,
                                        ] = out[
                                            c * batch : (c + 1) * batch,
                                            int(h * ocell[0]) : int(h * ocell[1]),
                                            int(w * icell[0]) : int(w * icell[1]),
                                            :,
                                        ]
                                    else:
                                        out[
                                            0:batch,
                                            int(h * icell[0]) : int(h * icell[1]),
                                            int(w * ocell[0]) : int(w * ocell[1]),
                                            :,
                                        ] = out[
                                            c * batch : (c + 1) * batch,
                                            int(h * icell[0]) : int(h * icell[1]),
                                            int(w * ocell[0]) : int(w * ocell[1]),
                                            :,
                                        ]
                                    c += 1
                        px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                        hidden_states = hidden_states.reshape(xshape)

                    #### Regional Prompting Prompt mode
                    elif "PRO" in mode:
                        # split the conditional / unconditional copies, mirroring the COL/ROW branch
                        if equal:
                            px, nx = hidden_states.chunk(2)
                        else:
                            px, nx = hidden_states[0:-batch], hidden_states[-batch:]

                        if (h, w) in self.attnmasks and self.maskready:

                            def mask(input):
                                out = torch.multiply(input, self.attnmasks[(h, w)])
                                for b in range(batch):
                                    for r in range(1, regions):
                                        out[b] = out[b] + out[r * batch + b]
                                return out

                            px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
                        px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                    return hidden_states

                return forward

            def hook_forwards(root_module: torch.nn.Module):
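                # Patch only the U-Net's cross-attention modules ("attn2"); self-attention
                # ("attn1") is left untouched.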
                for name, module in root_module.named_modules():
                    if "attn2" in name and module.__class__.__name__ == "Attention":
                        module.forward = hook_forward(module)

            hook_forwards(self.unet)

        output = StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            prompt_embeds=embs,
            negative_prompt=negative_prompt,
            negative_prompt_embeds=n_embs,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback_on_step_end=pcallback,
        )

        save_mask = rp_args.get("save_mask", False)

        if mode == "PROMPT" and save_mask:
            saveattnmaps(
                self,
                output,
                height,
                width,
                thresholds,
                num_inference_steps // 2,
                regions,
            )

        return output


### Make prompt list for each region
def promptsmaker(prompts, batch):
    out_p = []
    plen = len(prompts)
    for prompt in prompts:
        add = ""
        if KCOMM in prompt:
            add, prompt = prompt.split(KCOMM)
            add = add.strip() + " "
        prompts = [p.strip() for p in prompt.split(KBRK)]
        out_p.append([add + p for i, p in enumerate(prompts)])
    out = [None] * batch * len(out_p[0]) * len(out_p)
    for p, prs in enumerate(out_p):  # inputs prompts
        for r, pr in enumerate(prs):  # prompts for regions
            start = (p + r * plen) * batch
            out[start : start + batch] = [pr] * batch  # R1P1B1, R1P1B2, ..., R1P2B1, ..., R2P1B1, ...
    return out, out_p
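
# A worked example (illustrative): with batch=1 and prompts=["sky ADDCOMM sun BREAK moon"],
# promptsmaker returns
#   out   = ["sky sun", "sky moon"]    (flat list fed to the text encoder)
#   out_p = [["sky sun", "sky moon"]]  (regional prompts nested per input prompt)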


### make regions from ratios
### ";" makes outer cells, "," makes inner cells
def make_cells(ratios):
    if ";" not in ratios and "," in ratios:
        ratios = ratios.replace(",", ";")
    ratios = ratios.split(";")
    ratios = [inratios.split(",") for inratios in ratios]

    icells = []
    ocells = []

    def startend(cells, array):
        current_start = 0
        array = [float(x) for x in array]
        for value in array:
            end = current_start + (value / sum(array))
            cells.append([current_start, end])
            current_start = end

    startend(ocells, [r[0] for r in ratios])

    for inratios in ratios:
        if len(inratios) < 2:
            icells.append([[0, 1]])
        else:
            add = []
            startend(add, inratios[1:])
            icells.append(add)
    return ocells, icells, sum(len(cell) for cell in icells)
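
# A worked example (illustrative): make_cells("1,2,1;1") yields
#   ocells  = [[0.0, 0.5], [0.5, 1.0]]                  (outer split, first value of each group)
#   icells  = [[[0.0, 2/3], [2/3, 1.0]], [[0.0, 1.0]]]  (inner splits, remaining values)
#   regions = 3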


def make_emblist(self, prompts):
    with torch.no_grad():
        tokens = self.tokenizer(
            prompts,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.device)
        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
    return embs


def split_dims(xs, height, width):
    def repeat_div(x, y):
        while y > 0:
            x = math.ceil(x / 2)
            y = y - 1
        return x

    scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
    dsh = repeat_div(height, scale)
    dsw = repeat_div(width, scale)
    return dsh, dsw
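
# A worked example (illustrative): at 512x512, a cross-attention layer seeing 4096 spatial
# tokens gives scale = ceil(log2(sqrt(512 * 512 / 4096))) = 3, so split_dims(4096, 512, 512)
# returns (64, 64); 1024 tokens map to (32, 32), and so on down the U-Net.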


##### for prompt mode
def get_attn_maps(self, attn):
    height, width = self.hw
    target_tokens = self.target_tokens
    if (height, width) not in self.attnmaps_sizes:
        self.attnmaps_sizes.append((height, width))

    for b in range(self.batch):
        for t in target_tokens:
            power = self.power
            add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
            add = torch.sum(add, dim=2)
            key = f"{t}-{b}"
            if key not in self.attnmaps:
                self.attnmaps[key] = add
            else:
                if self.attnmaps[key].shape[1] != add.shape[1]:
                    add = add.view(8, height, width)
                    add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
                    add = add.reshape_as(self.attnmaps[key])

                self.attnmaps[key] = self.attnmaps[key] + add


def reset_attnmaps(self):  # reset cached state at the start of every generation
    self.step = 0
    self.attnmaps = {}  # built from attention maps
    self.attnmaps_sizes = []  # (height, width) sizes seen across U-Net blocks
    self.attnmasks = {}  # region masks built from attnmaps
    self.maskready = False
    self.history = {}


def saveattnmaps(self, output, h, w, th, step, regions):
    masks = []
    for i, mask in enumerate(self.history[step].values()):
        img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
        if self.ex:
            masks = [x - mask for x in masks]
            masks.append(mask)
            if len(masks) == regions - 1:
                output.images.extend([FF.to_pil_image(mask) for mask in masks])
                masks = []
        else:
            output.images.append(img)


def makepmask(
    self, mask, h, w, th, step
):  # make masks from the attention cache; returns (preview image, attention mask, latent mask)
    th = max(th - step * 0.005, 0.05)
    mask = torch.mean(mask, dim=0)
    mask = mask / mask.max().item()
    mask = torch.where(mask > th, 1, 0)
    mask = mask.float()
    mask = mask.view(1, *self.attnmaps_sizes[0])
    img = FF.to_pil_image(mask)
    img = img.resize((w, h))
    mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
    lmask = mask
    mask = mask.reshape(h * w)
    mask = torch.where(mask > 0.1, 1, 0)
    return img, mask, lmask
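
# Illustrative: the threshold decays by 0.005 per step (th=0.5 becomes an effective 0.4 by
# step 20) and is floored at 0.05, so later steps let each token claim a larger region.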


def tokendealer(self, all_prompts):
    for prompts in all_prompts:
        targets = [p.split(",")[-1] for p in prompts[1:]]
        tt = []

        for target in targets:
            ptokens = self.tokenizer(
                prompts,
                max_length=self.tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).input_ids[0]
            ttokens = self.tokenizer(
                target,
                max_length=self.tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).input_ids[0]

            tlist = []

            for t in range(ttokens.shape[0] - 2):
                for p in range(ptokens.shape[0]):
                    if ttokens[t + 1] == ptokens[p]:
                        tlist.append(p)
            if tlist:
                tt.append(tlist)

    return tt
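
# A worked example (illustrative): for the regional prompts ["a red ball, ball", "a blue cube, cube"],
# the targets are the texts after the last comma of the second region onward ("cube" here);
# tokendealer returns the positions of those token ids inside the tokenized prompt, and
# get_attn_maps later reads exactly those columns of the attention weights.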


def scaled_dot_product_attention(
    self,
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    getattn=False,
) -> torch.Tensor:
    # reference implementation of torch.nn.functional.scaled_dot_product_attention,
    # extended so the softmaxed attention weights can be captured via `getattn`
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if getattn:
        get_attn_maps(self, attn_weight)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
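
# Illustrative equivalence check (`obj` is any object with a `.device` matching the tensors):
# with the default dropout_p=0.0 and getattn=False this reproduces PyTorch's built-in:
#   q = k = v = torch.randn(2, 8, 16, 40)
#   torch.testing.assert_close(
#       scaled_dot_product_attention(obj, q, k, v),
#       torch.nn.functional.scaled_dot_product_attention(q, k, v),
#   )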