import math
from typing import Dict, Optional

import torch
import torchvision.transforms.functional as FF
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND


try:
    from compel import Compel
except ImportError:
    Compel = None

KCOMM = "ADDCOMM"  # text before this keyword is prepended to every regional prompt
KBRK = "BREAK"  # separator between the per-region prompts


class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
    r"""
    Args for Regional Prompting Pipeline:
        rp_args (dict):
        Required
            rp_args["mode"]: cols, rows, prompt, prompt-ex
        for cols, rows mode
            rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
        for prompt, prompt-ex mode
            rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)

        Optional
            rp_args["save_mask"]: True/False (save masks in prompt mode)

    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
    ):
        super().__init__(
            vae,
            text_encoder,
            tokenizer,
            unet,
            scheduler,
            safety_checker,
            feature_extractor,
            requires_safety_checker,
        )
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[str] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        rp_args: Optional[Dict[str, str]] = None,
    ):
        rp_args = rp_args or {}
        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
        if negative_prompt is None:
            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

        device = self._execution_device
        regions = 0

        self.power = int(rp_args["power"]) if "power" in rp_args else 1

        prompts = prompt if isinstance(prompt, list) else [prompt]
        n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
        self.batch = batch = num_images_per_prompt * len(prompts)
        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

        equal = len(all_prompts_cn) == len(all_n_prompts_cn)

        if Compel:
            compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)

            def getcompelembs(prps):
                embl = []
                for prp in prps:
                    embl.append(compel.build_conditioning_tensor(prp))
                return torch.cat(embl)

            conds = getcompelembs(all_prompts_cn)
            unconds = getcompelembs(all_n_prompts_cn)
            embs = getcompelembs(prompts)
            n_embs = getcompelembs(n_prompts)
            prompt = negative_prompt = None
        else:
            conds = self.encode_prompt(all_prompts_cn, device, 1, True)[0]
            unconds = (
                self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
                if equal
                else self.encode_prompt(n_prompts, device, 1, True)[0]
            )
            embs = n_embs = None

        if not active:
            pcallback = None
            mode = None
        else:
            if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
                mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
                ocells, icells, regions = make_cells(rp_args["div"])

            elif "PRO" in rp_args["mode"].upper():
                regions = len(all_prompts_p[0])
                mode = "PROMPT"
                reset_attnmaps(self)
                self.ex = "EX" in rp_args["mode"].upper()  # prompt-ex: region masks are made mutually exclusive
                self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
                thresholds = [float(x) for x in rp_args["th"].split(",")]

            orig_hw = (height, width)
            revers = True  # diffusers orders the classifier-free-guidance batch as [uncond, cond]

            def pcallback(s_self, step: int, timestep: int, latents: torch.Tensor, selfs=None):
                if "PRO" in mode:  # in Prompt mode, make masks from sum of attension maps
                    self.step = step

                    # once attention maps exist for enough U-Net resolutions, build
                    # region masks for every (h, w) size collected so far
                    if len(self.attnmaps_sizes) > 3:
                        self.history[step] = self.attnmaps.copy()
                        for hw in self.attnmaps_sizes:
                            allmasks = []
                            basemasks = [None] * batch
                            for tt, th in zip(target_tokens, thresholds):
                                for b in range(batch):
                                    key = f"{tt}-{b}"
                                    _, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
                                    mask = mask.unsqueeze(0).unsqueeze(-1)
                                    if self.ex:
                                        allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
                                        allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
                                    allmasks.append(mask)
                                    basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
                            basemasks = [1 - mask for mask in basemasks]
                            basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
                            allmasks = basemasks + allmasks

                            self.attnmasks[hw] = torch.cat(allmasks)
                        self.maskready = True
                return latents

            def hook_forward(module):
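                # wrap this Attention module's forward so that regional conditioning
                # and attention-map capture happen inside cross-attention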
                # diffusers==0.23.2
                def forward(
                    hidden_states: torch.Tensor,
                    encoder_hidden_states: Optional[torch.Tensor] = None,
                    attention_mask: Optional[torch.Tensor] = None,
                    temb: Optional[torch.Tensor] = None,
                    scale: float = 1.0,
                ) -> torch.Tensor:
                    attn = module
                    xshape = hidden_states.shape
                    self.hw = (h, w) = split_dims(xshape[1], *orig_hw)

                    if revers:
                        nx, px = hidden_states.chunk(2)
                    else:
                        px, nx = hidden_states.chunk(2)

                    if equal:
                        hidden_states = torch.cat([px] * regions + [nx] * regions, 0)
                        encoder_hidden_states = torch.cat([conds, unconds])
                    else:
                        hidden_states = torch.cat([px] * regions + [nx], 0)
                        encoder_hidden_states = torch.cat([conds, unconds])

                    residual = hidden_states

                    if attn.spatial_norm is not None:
                        hidden_states = attn.spatial_norm(hidden_states, temb)

                    input_ndim = hidden_states.ndim

                    if input_ndim == 4:
                        batch_size, channel, height, width = hidden_states.shape
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                    batch_size, sequence_length, _ = (
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                    )

                    if attention_mask is not None:
                        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                    if attn.group_norm is not None:
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                    args = () if USE_PEFT_BACKEND else (scale,)
                    query = attn.to_q(hidden_states, *args)

                    if encoder_hidden_states is None:
                        encoder_hidden_states = hidden_states
                    elif attn.norm_cross:
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                    key = attn.to_k(encoder_hidden_states, *args)
                    value = attn.to_v(encoder_hidden_states, *args)

                    inner_dim = key.shape[-1]
                    head_dim = inner_dim // attn.heads

                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
                    # TODO: add support for attn.scale when we move to Torch 2.1
                    hidden_states = scaled_dot_product_attention(
                        self,
                        query,
                        key,
                        value,
                        attn_mask=attention_mask,
                        dropout_p=0.0,
                        is_causal=False,
                        getattn="PRO" in mode,
                    )

                    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                    hidden_states = hidden_states.to(query.dtype)

                    # linear proj
                    hidden_states = attn.to_out[0](hidden_states, *args)
                    # dropout
                    hidden_states = attn.to_out[1](hidden_states)

                    if input_ndim == 4:
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                    if attn.residual_connection:
                        hidden_states = hidden_states + residual

                    hidden_states = hidden_states / attn.rescale_output_factor

                    #### Regional Prompting Col/Row mode
                    if any(x in mode for x in ["COL", "ROW"]):
                        reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
                        center = reshaped.shape[0] // 2
                        px = reshaped[0:center] if equal else reshaped[0:-batch]
                        nx = reshaped[center:] if equal else reshaped[-batch:]
                        outs = [px, nx] if equal else [px]
                        for out in outs:
                            c = 0
                            for i, ocell in enumerate(ocells):
                                for icell in icells[i]:
                                    if "ROW" in mode:
                                        out[
                                            0:batch,
                                            int(h * ocell[0]) : int(h * ocell[1]),
                                            int(w * icell[0]) : int(w * icell[1]),
                                            :,
                                        ] = out[
                                            c * batch : (c + 1) * batch,
                                            int(h * ocell[0]) : int(h * ocell[1]),
                                            int(w * icell[0]) : int(w * icell[1]),
                                            :,
                                        ]
                                    else:
                                        out[
                                            0:batch,
                                            int(h * icell[0]) : int(h * icell[1]),
                                            int(w * ocell[0]) : int(w * ocell[1]),
                                            :,
                                        ] = out[
                                            c * batch : (c + 1) * batch,
                                            int(h * icell[0]) : int(h * icell[1]),
                                            int(w * ocell[0]) : int(w * ocell[1]),
                                            :,
                                        ]
                                    c += 1
                        px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                        hidden_states = hidden_states.reshape(xshape)

                    #### Regional Prompting Prompt mode
                    elif "PRO" in mode:
340
341
342
343
                        px, nx = (
                            torch.chunk(hidden_states) if equal else hidden_states[0:-batch],
                            hidden_states[-batch:],
                        )

                        if (h, w) in self.attnmasks and self.maskready:

                            def mask(input):
                                out = torch.multiply(input, self.attnmasks[(h, w)])
                                for b in range(batch):
                                    for r in range(1, regions):
                                        out[b] = out[b] + out[r * batch + b]
                                return out

                            px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
                        px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                    return hidden_states

                return forward

            def hook_forwards(root_module: torch.nn.Module):
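                # patch every cross-attention ("attn2") module in the U-Net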
                for name, module in root_module.named_modules():
                    if "attn2" in name and module.__class__.__name__ == "Attention":
                        module.forward = hook_forward(module)

            hook_forwards(self.unet)
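            # from here on, the base StableDiffusionPipeline call below runs with the
            # patched cross-attention layers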

        output = StableDiffusionPipeline(**self.components)(
            prompt=prompt,
            prompt_embeds=embs,
            negative_prompt=negative_prompt,
            negative_prompt_embeds=n_embs,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback_on_step_end=pcallback,
        )

        if "save_mask" in rp_args:
            save_mask = rp_args["save_mask"]
        else:
            save_mask = False

        if mode == "PROMPT" and save_mask:
            saveattnmaps(
                self,
                output,
                height,
                width,
                thresholds,
                num_inference_steps // 2,
                regions,
            )

        return output


### Make a prompt list for each region
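### e.g. promptsmaker(["red BREAK blue"], 2) returns out = ["red ", "red ", " blue", " blue"]
### (ordered P1R1B1, P1R1B2, P1R2B1, P1R2B2) and out_p = [["red ", " blue"]]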
def promptsmaker(prompts, batch):
    out_p = []
    plen = len(prompts)
    for prompt in prompts:
        add = ""
        if KCOMM in prompt:
            add, prompt = prompt.split(KCOMM)
            add = add + " "
        region_prompts = prompt.split(KBRK)
        out_p.append([add + p for p in region_prompts])
    out = [None] * batch * len(out_p[0]) * len(out_p)
    for p, prs in enumerate(out_p):  # inputs prompts
        for r, pr in enumerate(prs):  # prompts for regions
            start = (p + r * plen) * batch
            out[start : start + batch] = [pr] * batch  # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
    return out, out_p


### make regions from ratios
### ";" makes outercells, "," makes inner cells
def make_cells(ratios):
    if ";" not in ratios and "," in ratios:
        ratios = ratios.replace(",", ";")
    ratios = ratios.split(";")
    ratios = [inratios.split(",") for inratios in ratios]

    icells = []
    ocells = []

    def startend(cells, array):
        current_start = 0
        array = [float(x) for x in array]
        for value in array:
            end = current_start + (value / sum(array))
            cells.append([current_start, end])
            current_start = end

    startend(ocells, [r[0] for r in ratios])

    for inratios in ratios:
        if len(inratios) < 2:
            icells.append([[0, 1]])
        else:
            add = []
            startend(add, inratios[1:])
            icells.append(add)

    return ocells, icells, sum(len(cell) for cell in icells)


def make_emblist(self, prompts):
    with torch.no_grad():
        tokens = self.tokenizer(
            prompts,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(self.device)
        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
    return embs


def split_dims(xs, height, width):
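    # infer the latent (h, w) at this attention resolution from the token count xs;
    # e.g. xs=4096 for a 512x512 image gives scale=3 and returns (64, 64)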
    def repeat_div(x, y):
        while y > 0:
            x = math.ceil(x / 2)
            y = y - 1
        return x

    scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
    dsh = repeat_div(height, scale)
    dsw = repeat_div(width, scale)
    return dsh, dsw


##### for prompt mode
def get_attn_maps(self, attn):
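    # accumulate per-token cross-attention weights for each target token and batch element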
    height, width = self.hw
    target_tokens = self.target_tokens
    if (height, width) not in self.attnmaps_sizes:
        self.attnmaps_sizes.append((height, width))

    for b in range(self.batch):
        for t in target_tokens:
            power = self.power
            add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
            add = torch.sum(add, dim=2)
            key = f"{t}-{b}"
            if key not in self.attnmaps:
                self.attnmaps[key] = add
            else:
                if self.attnmaps[key].shape[1] != add.shape[1]:
                    add = add.view(8, height, width)  # 8 = cross-attention heads in SD1.x
                    add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
                    add = add.reshape_as(self.attnmaps[key])

                self.attnmaps[key] = self.attnmaps[key] + add


def reset_attnmaps(self):  # init parameters in every batch
    self.step = 0
    self.attnmaps = {}  # made from attention maps
    self.attnmaps_sizes = []  # (height, width) set of U-Net blocks
    self.attnmasks = {}  # made from attnmaps for regions
    self.maskready = False
    self.history = {}


def saveattnmaps(self, output, h, w, th, step, regions):
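    # append PIL previews of the region masks (taken at `step`) to output.images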
    masks = []
    for i, mask in enumerate(self.history[step].values()):
        img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
        if self.ex:
            masks = [x - mask for x in masks]
            masks.append(mask)
            if len(masks) == regions - 1:
                output.images.extend([FF.to_pil_image(mask) for mask in masks])
                masks = []
        else:
            output.images.append(img)


def makepmask(
    self, mask, h, w, th, step
):  # make masks from the attention cache; returns [preview image, attention mask, latent mask]
    th = max(th - step * 0.005, 0.05)  # threshold decays with step, floored at 0.05
    mask = torch.mean(mask, dim=0)
    mask = mask / mask.max().item()
    mask = torch.where(mask > th, 1, 0)
    mask = mask.float()
    mask = mask.view(1, *self.attnmaps_sizes[0])
    img = FF.to_pil_image(mask)
    img = img.resize((w, h))
    mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
    lmask = mask
    mask = mask.reshape(h * w)
    mask = torch.where(mask > 0.1, 1, 0)
    return img, mask, lmask


def tokendealer(self, all_prompts):
    for prompts in all_prompts:
        targets = [p.split(",")[-1] for p in prompts[1:]]  # last comma-separated chunk of each regional prompt is the target word
        tt = []

        for target in targets:
            ptokens = self.tokenizer(
                prompts,
                max_length=self.tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).input_ids[0]
            ttokens = self.tokenizer(
                target,
                max_length=self.tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).input_ids[0]

            tlist = []

            for t in range(ttokens.shape[0] - 2):  # skip the BOS and EOS tokens
                for p in range(ptokens.shape[0]):
                    if ttokens[t + 1] == ptokens[p]:
                        tlist.append(p)
            if tlist:
                tt.append(tlist)

    return tt


def scaled_dot_product_attention(
    self,
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    getattn=False,
) -> torch.Tensor:
    # Mirrors torch.nn.functional.scaled_dot_product_attention, with an optional
    # hook (getattn) that records the attention weights for mask building.
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias = attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if getattn:
        get_attn_maps(self, attn_weight)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value