# regional_prompting_stable_diffusion.py
import math
2
from typing import Dict, Optional
Patrick von Platen's avatar
Patrick von Platen committed
3
4
5
6
7

import torch
import torchvision.transforms.functional as FF
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

8
9
10
11
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
Patrick von Platen's avatar
Patrick von Platen committed
12
13
from diffusers.utils import USE_PEFT_BACKEND

14
15
16

# `compel` is an optional dependency: when it cannot be imported, `Compel` is None
# and the pipeline encodes prompts with its own `encode_prompt` instead.
try:
    from compel import Compel
except ImportError:
    Compel = None

# Prompt keywords. Text before KCOMM is prepended to every region prompt;
# KBRK separates the prompts of the individual regions.
KCOMM = "ADDCOMM"
KBRK = "BREAK"

Patrick von Platen's avatar
Patrick von Platen committed
23

24
25
26
27
class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
    r"""
    Args for Regional Prompting Pipeline:
        rp_args:dict
        Required
            rp_args["mode"]: cols, rows, prompt, prompt-ex
        for cols, rows mode
            rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
        for prompt, prompt-ex mode
            rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)

        Optional
            rp_args["save_mask"]: True/False (save masks in prompt mode)
            rp_args["power"]: integer exponent applied to the attention weights when the
                attention maps are accumulated in prompt mode (defaults to 1)

    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        """Forward all components to `StableDiffusionPipeline` and register them on this pipeline."""
        super().__init__(
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
        )
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: str = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        rp_args: Dict[str, str] = None,
    ):
        """
        Generate images, applying regional prompting when the prompt contains the `BREAK` keyword.

        When the (first) prompt contains `BREAK`, the forward of every `attn2` `Attention` module of the
        UNet is replaced for the duration of the call so that each region is conditioned on its own
        sub-prompt. Without `BREAK` the call is delegated unchanged to a plain `StableDiffusionPipeline`
        built from this pipeline's components.

        Args:
            prompt: A prompt string or a list of prompt strings.
            negative_prompt: Negative prompt(s); defaults to empty prompt(s) matching `prompt`.
            rp_args: Regional prompting options, see the class docstring. `None` is treated as `{}`.
            height, width, num_inference_steps, guidance_scale, num_images_per_prompt, eta, generator,
            latents, output_type, return_dict: forwarded unchanged to `StableDiffusionPipeline.__call__`.

        Returns:
            Whatever `StableDiffusionPipeline.__call__` returns. In prompt mode with
            `rp_args["save_mask"]` set, the mask previews are appended to `output.images`.

        Raises:
            KeyError: if regional prompting is active and a required `rp_args` entry is missing.
            ValueError: if regional prompting is active and `rp_args["mode"]` is not a known mode.
        """
        # The signature defaults rp_args to None; normalise so the membership tests below work.
        if rp_args is None:
            rp_args = {}

        prompt_is_list = isinstance(prompt, list)
        active = KBRK in prompt[0] if prompt_is_list else KBRK in prompt
        if negative_prompt is None:
            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

        device = self._execution_device
        regions = 0

        self.power = int(rp_args["power"]) if "power" in rp_args else 1

        prompts = prompt if prompt_is_list else [prompt]
        n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
        self.batch = batch = num_images_per_prompt * len(prompts)
        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

        # True when the negative prompts expand to as many per-region entries as the positive ones.
        cn = len(all_prompts_cn) == len(all_n_prompts_cn)

        if Compel:
            compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)

            def getcompelembs(prps):
                return torch.cat([compel.build_conditioning_tensor(prp) for prp in prps])

            conds = getcompelembs(all_prompts_cn)
            unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
            embs = getcompelembs(prompts)
            n_embs = getcompelembs(n_prompts)
            # Embeddings are passed instead of text, so the text prompts must be cleared.
            prompt = negative_prompt = None
        else:
            # NOTE(review): unlike the Compel branch, this encodes the unsplit `prompts` rather than
            # `all_prompts_cn`, and the `cn` choice for `unconds` is mirrored - confirm this is intended.
            conds = self.encode_prompt(prompts, device, 1, True)[0]
            unconds = (
                self.encode_prompt(n_prompts, device, 1, True)[0]
                if cn
                else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
            )
            embs = n_embs = None

        # (module, previous instance-level forward or None) for every hooked attention module,
        # so the hooks can be removed again once generation has finished.
        hooked = []

        if not active:
            pcallback = None
            mode = None
        else:
            requested_mode = rp_args["mode"].upper()
            if "COL" in requested_mode or "ROW" in requested_mode:
                mode = "COL" if "COL" in requested_mode else "ROW"
                ocells, icells, regions = make_cells(rp_args["div"])
            elif "PRO" in requested_mode:
                regions = len(all_prompts_p[0])
                mode = "PROMPT"
                reset_attnmaps(self)
                self.ex = "EX" in requested_mode
                self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
                thresholds = [float(x) for x in rp_args["th"].split(",")]
            else:
                # Fail early instead of hitting an unbound `mode` inside the attention hook.
                raise ValueError(
                    f"Unknown rp_args['mode'] {rp_args['mode']!r}; expected cols, rows, prompt or prompt-ex."
                )

            orig_hw = (height, width)
            revers = True

            def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None):
                # Passed as `callback_on_step_end`; whatever arrives as `latents` is returned untouched.
                # NOTE(review): the pipeline presumably passes its callback kwargs here - confirm.
                if "PRO" in mode:  # in Prompt mode, make masks from sum of attention maps
                    self.step = step

                    if len(self.attnmaps_sizes) > 3:
                        self.history[step] = self.attnmaps.copy()
                        for hw in self.attnmaps_sizes:
                            allmasks = []
                            basemasks = [None] * batch
                            for tt, th in zip(target_tokens, thresholds):
                                for b in range(batch):
                                    key = f"{tt}-{b}"
                                    _, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
                                    mask = mask.unsqueeze(0).unsqueeze(-1)
                                    if self.ex:
                                        # Exclusive mode: later regions carve their area out of earlier ones.
                                        allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
                                        allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
                                    allmasks.append(mask)
                                    basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
                            # The base region is everything no other region covers.
                            basemasks = [1 - mask for mask in basemasks]
                            basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
                            allmasks = basemasks + allmasks

                            self.attnmasks[hw] = torch.cat(allmasks)
                        self.maskready = True
                return latents

            def hook_forward(module):
                # diffusers==0.23.2
                def forward(
                    hidden_states: torch.FloatTensor,
                    encoder_hidden_states: Optional[torch.FloatTensor] = None,
                    attention_mask: Optional[torch.FloatTensor] = None,
                    temb: Optional[torch.FloatTensor] = None,
                    scale: float = 1.0,
                ) -> torch.Tensor:
                    """Cross-attention forward that runs one attention pass per region and merges the results."""
                    attn = module
                    xshape = hidden_states.shape
                    self.hw = (h, w) = split_dims(xshape[1], *orig_hw)

                    if revers:
                        nx, px = hidden_states.chunk(2)
                    else:
                        px, nx = hidden_states.chunk(2)

                    # Repeat the positive half once per region (and the negative half too when the
                    # negative prompts are regional), then attend against the per-region embeddings.
                    if cn:
                        hidden_states = torch.cat([px] * regions + [nx] * regions, 0)
                    else:
                        hidden_states = torch.cat([px] * regions + [nx], 0)
                    encoder_hidden_states = torch.cat([conds] + [unconds])

                    residual = hidden_states

                    if attn.spatial_norm is not None:
                        hidden_states = attn.spatial_norm(hidden_states, temb)

                    input_ndim = hidden_states.ndim

                    if input_ndim == 4:
                        batch_size, channel, height, width = hidden_states.shape
                        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

                    batch_size, sequence_length, _ = (
                        hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
                    )

                    if attention_mask is not None:
                        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

                    if attn.group_norm is not None:
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                    args = () if USE_PEFT_BACKEND else (scale,)
                    query = attn.to_q(hidden_states, *args)

                    if encoder_hidden_states is None:
                        encoder_hidden_states = hidden_states
                    elif attn.norm_cross:
                        encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                    key = attn.to_k(encoder_hidden_states, *args)
                    value = attn.to_v(encoder_hidden_states, *args)

                    inner_dim = key.shape[-1]
                    head_dim = inner_dim // attn.heads

                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
                    # TODO: add support for attn.scale when we move to Torch 2.1
                    hidden_states = scaled_dot_product_attention(
                        self,
                        query,
                        key,
                        value,
                        attn_mask=attention_mask,
                        dropout_p=0.0,
                        is_causal=False,
                        getattn="PRO" in mode,
                    )

                    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                    hidden_states = hidden_states.to(query.dtype)

                    # linear proj
                    hidden_states = attn.to_out[0](hidden_states, *args)
                    # dropout
                    hidden_states = attn.to_out[1](hidden_states)

                    if input_ndim == 4:
                        hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

                    if attn.residual_connection:
                        hidden_states = hidden_states + residual

                    hidden_states = hidden_states / attn.rescale_output_factor

                    #### Regional Prompting Col/Row mode
                    if any(x in mode for x in ["COL", "ROW"]):
                        reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
                        center = reshaped.shape[0] // 2
                        px = reshaped[0:center] if cn else reshaped[0:-batch]
                        nx = reshaped[center:] if cn else reshaped[-batch:]
                        outs = [px, nx] if cn else [px]
                        for out in outs:
                            c = 0
                            for i, ocell in enumerate(ocells):
                                for icell in icells[i]:
                                    # Copy region c's rectangle into the first `batch` entries,
                                    # which become the merged result.
                                    if "ROW" in mode:
                                        out[
                                            0:batch,
                                            int(h * ocell[0]) : int(h * ocell[1]),
                                            int(w * icell[0]) : int(w * icell[1]),
                                            :,
                                        ] = out[
                                            c * batch : (c + 1) * batch,
                                            int(h * ocell[0]) : int(h * ocell[1]),
                                            int(w * icell[0]) : int(w * icell[1]),
                                            :,
                                        ]
                                    else:
                                        out[
                                            0:batch,
                                            int(h * icell[0]) : int(h * icell[1]),
                                            int(w * ocell[0]) : int(w * ocell[1]),
                                            :,
                                        ] = out[
                                            c * batch : (c + 1) * batch,
                                            int(h * icell[0]) : int(h * icell[1]),
                                            int(w * ocell[0]) : int(w * ocell[1]),
                                            :,
                                        ]
                                    c += 1
                        px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                        hidden_states = hidden_states.reshape(xshape)

                    #### Regional Prompting Prompt mode
                    elif "PRO" in mode:
                        # Split the flat hidden states directly: the reshaped tensor exists only in
                        # the Col/Row branch, and the masks are built per flattened position.
                        center = hidden_states.shape[0] // 2
                        px = hidden_states[0:center] if cn else hidden_states[0:-batch]
                        nx = hidden_states[center:] if cn else hidden_states[-batch:]

                        if (h, w) in self.attnmasks and self.maskready:

                            def apply_masks(states):
                                out = torch.multiply(states, self.attnmasks[(h, w)])
                                # Sum the masked regions into the first `batch` entries.
                                for b in range(batch):
                                    for r in range(1, regions):
                                        out[b] = out[b] + out[r * batch + b]
                                return out

                            px, nx = (apply_masks(px), apply_masks(nx)) if cn else (apply_masks(px), nx)
                        px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
                        hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                    return hidden_states

                return forward

            def hook_forwards(root_module: torch.nn.Module):
                for name, module in root_module.named_modules():
                    if "attn2" in name and module.__class__.__name__ == "Attention":
                        # Remember any instance-level forward so it can be restored afterwards.
                        hooked.append((module, module.__dict__.get("forward")))
                        module.forward = hook_forward(module)

            hook_forwards(self.unet)

        try:
            output = StableDiffusionPipeline(**self.components)(
                prompt=prompt,
                prompt_embeds=embs,
                negative_prompt=negative_prompt,
                negative_prompt_embeds=n_embs,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_images_per_prompt,
                eta=eta,
                generator=generator,
                latents=latents,
                output_type=output_type,
                return_dict=return_dict,
                callback_on_step_end=pcallback,
            )
        finally:
            # Remove the hooks so a later call without BREAK does not run stale regional
            # attention that still closes over this call's embeddings.
            for module, previous_forward in hooked:
                if previous_forward is None:
                    del module.forward
                else:
                    module.forward = previous_forward

        save_mask = rp_args.get("save_mask", False)

        if mode == "PROMPT" and save_mask:
            saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)

        return output


### Make the prompt list for each region
def promptsmaker(prompts, batch):
    """
    Expand prompts into per-region prompts.

    Each prompt is split on `KBRK` into its region prompts. If a prompt contains `KCOMM`,
    the text before it (plus a space) is prepended to every region prompt.

    Args:
        prompts: list of prompt strings.
        batch: number of repetitions of every region prompt.

    Returns:
        A tuple `(out, out_p)`. `out_p` holds one list of region prompts per input prompt.
        `out` is the flat list ordered region-major (P1R1B1, P1R1B2, ..., P2R1B1, ..., P1R2B1, ...),
        sized from the region count of the first prompt.
    """
    out_p = []
    plen = len(prompts)
    for prompt in prompts:
        add = ""
        if KCOMM in prompt:
            add, prompt = prompt.split(KCOMM)
            add = add + " "
        # A separate name keeps the `prompts` parameter from being rebound mid-iteration.
        region_prompts = prompt.split(KBRK)
        out_p.append([add + p for p in region_prompts])
    out = [None] * batch * len(out_p[0]) * len(out_p)
    for p, prs in enumerate(out_p):  # input prompts
        for r, pr in enumerate(prs):  # prompts for regions
            start = (p + r * plen) * batch
            out[start : start + batch] = [pr] * batch
    return out, out_p


### Make regions from ratios
### ";" separates outer cells, "," separates inner cells
def make_cells(ratios):
    """
    Turn a ratio string into normalised outer and inner cell boundaries.

    `";"` separates outer cells and `","` separates the values of one outer cell: the first
    value is the outer cell's own ratio, any further values are the ratios of its inner cells.
    A string that has `","` but no `";"` is treated as if every `","` were a `";"`.

    Args:
        ratios: ratio string, e.g. `"1;1;1"` or `"1,1,1;2,1"`.

    Returns:
        A tuple `(ocells, icells, regions)`. `ocells` is a list of `[start, end]` fractions, one
        per outer cell. `icells` holds, per outer cell, a list of `[start, end]` fractions
        (`[[0, 1]]` when the outer cell has no inner ratios). `regions` is the total number of
        inner cells.
    """
    if ";" not in ratios and "," in ratios:
        ratios = ratios.replace(",", ";")
    ratios = ratios.split(";")
    ratios = [inratios.split(",") for inratios in ratios]

    icells = []
    ocells = []

    def startend(cells, array):
        # Append consecutive [start, end] spans whose lengths are the normalised ratios.
        current_start = 0
        array = [float(x) for x in array]
        total = sum(array)  # loop-invariant, so computed once
        for value in array:
            end = current_start + (value / total)
            cells.append([current_start, end])
            current_start = end

    startend(ocells, [r[0] for r in ratios])

    for inratios in ratios:
        if 2 > len(inratios):
            icells.append([[0, 1]])
        else:
            add = []
            startend(add, inratios[1:])
            icells.append(add)

    return ocells, icells, sum(len(cell) for cell in icells)

Patrick von Platen's avatar
Patrick von Platen committed
436

437
438
def make_emblist(self, prompts):
    """
    Encode prompts with the pipeline's tokenizer and text encoder.

    Args:
        self: object providing `tokenizer`, `text_encoder`, `device` and `dtype`.
        prompts: prompt string(s) accepted by the tokenizer.

    Returns:
        The text encoder's `last_hidden_state`, moved to `self.device` with dtype `self.dtype`.
    """
    with torch.no_grad():
        tokens = self.tokenizer(
            prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
        ).input_ids.to(self.device)
        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
    return embs
Patrick von Platen's avatar
Patrick von Platen committed
444

445
446
447

def split_dims(xs, height, width):
    """
    Recover the 2-D size of a flattened feature map.

    The number of halvings is estimated from the ratio between `height * width` and `xs`;
    `height` and `width` are then halved that many times, rounding up at every step.

    Args:
        xs: number of positions in the flattened feature map.
        height: full-resolution height.
        width: full-resolution width.

    Returns:
        A tuple `(dsh, dsw)` with the downscaled height and width.
    """

    def repeat_div(x, y):
        # Halve x (rounding up) y times.
        while y > 0:
            x = math.ceil(x / 2)
            y = y - 1
        return x

    scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
    dsh = repeat_div(height, scale)
    dsw = repeat_div(width, scale)
    return dsh, dsw


##### Helpers for prompt mode
def get_attn_maps(self, attn):
    """
    Accumulate the attention of the target tokens into `self.attnmaps`.

    For every batch index and every target-token group, the attention over the group's token
    span is raised to `self.power`, weighted by the 1-based position of the current `(height,
    width)` in `self.attnmaps_sizes`, summed over the tokens and added to the map stored under
    the key `f"{t}-{b}"`. Maps from a different resolution are resized to the first recorded
    size before being added.

    Args:
        self: object providing `hw`, `target_tokens`, `attnmaps_sizes`, `attnmaps`, `batch`
            and `power`.
        attn: attention weights, indexed as `attn[batch, head, position, token]`.
    """
    height, width = self.hw
    target_tokens = self.target_tokens
    if (height, width) not in self.attnmaps_sizes:
        self.attnmaps_sizes.append((height, width))

    # Both values are constant for the whole call, so they are computed once.
    power = self.power
    weight = self.attnmaps_sizes.index((height, width)) + 1

    for b in range(self.batch):
        for t in target_tokens:
            add = attn[b, :, :, t[0] : t[0] + len(t)] ** power * weight
            add = torch.sum(add, dim=2)
            key = f"{t}-{b}"
            if key not in self.attnmaps:
                self.attnmaps[key] = add
            else:
                if self.attnmaps[key].shape[1] != add.shape[1]:
                    # Use the actual head count instead of assuming 8 heads.
                    add = add.view(add.shape[0], height, width)
                    add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
                    add = add.reshape_as(self.attnmaps[key])

                self.attnmaps[key] = self.attnmaps[key] + add

Patrick von Platen's avatar
Patrick von Platen committed
484
485

def reset_attnmaps(self):  # init parameters in every batch
    """Reset the prompt-mode attention state stored on `self`."""
    self.step = 0
    self.attnmaps = {}  # made from attention maps
    self.attnmaps_sizes = []  # (height, width) of the u-net blocks seen so far
    self.attnmasks = {}  # made from attnmaps, one entry per region layout
    self.maskready = False
    self.history = {}

Patrick von Platen's avatar
Patrick von Platen committed
493
494

def saveattnmaps(self, output, h, w, th, step, regions):
    """
    Append mask previews for the attention maps recorded at `step` to `output.images`.

    Args:
        self: object providing `history` and `ex` (plus whatever `makepmask` reads).
        output: pipeline output whose `images` list is extended in place.
        h: preview height.
        w: preview width.
        th: list of thresholds, cycled over the recorded maps.
        step: key into `self.history` selecting the recorded attention maps.
        regions: number of regions; in exclusive mode masks are flushed in groups of
            `regions - 1`.
    """
    masks = []
    for i, mask in enumerate(self.history[step].values()):
        img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
        if self.ex:
            # Exclusive mode: subtract each new mask from the ones collected so far.
            masks = [x - mask for x in masks]
            masks.append(mask)
            if len(masks) == regions - 1:
                output.images.extend([FF.to_pil_image(mask) for mask in masks])
                masks = []
        else:
            output.images.append(img)

Patrick von Platen's avatar
Patrick von Platen committed
507
508
509
510

def makepmask(
    self, mask, h, w, th, step
):  # make masks from attention cache return [for preview, for attention, for Latent]
    """
    Threshold an accumulated attention map into binary masks.

    The threshold is lowered by `0.005` per step and never drops below `0.05`. The map is
    averaged over its first dimension, normalised by its maximum, thresholded, viewed at the
    first recorded attention-map size and resized to `(h, w)` with nearest-neighbour
    interpolation.

    Args:
        self: object providing `attnmaps_sizes`.
        mask: accumulated attention map.
        h: target height.
        w: target width.
        th: base threshold.
        step: current step, used to lower the threshold.

    Returns:
        A tuple `(img, mask, lmask)`: a PIL preview image resized to `(w, h)`, the flattened
        binary mask of length `h * w`, and the resized mask before flattening.
    """
    th = max(th - step * 0.005, 0.05)
    mask = torch.mean(mask, dim=0)
    mask = mask / mask.max().item()
    mask = torch.where(mask > th, 1, 0)
    mask = mask.float()
    mask = mask.view(1, *self.attnmaps_sizes[0])
    img = FF.to_pil_image(mask)
    img = img.resize((w, h))
    mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
    lmask = mask
    mask = mask.reshape(h * w)
    mask = torch.where(mask > 0.1, 1, 0)
    return img, mask, lmask

Patrick von Platen's avatar
Patrick von Platen committed
527

528
529
def tokendealer(self, all_prompts):
    """
    Locate the target tokens of every region inside the tokenised prompt.

    For each list of region prompts, the target of every region after the first is the text
    after its last `","`. Each target is tokenised and its tokens (excluding the first and
    last token of the encoding) are looked up in the tokens of the first region prompt.

    Args:
        self: object providing `tokenizer`.
        all_prompts: list of region-prompt lists, one per input prompt.

    Returns:
        A list with one list of matching token positions per target that was found. Only the
        result for the last entry of `all_prompts` is returned; an empty input yields `[]`.
    """
    tt = []  # keeps the return value defined when all_prompts is empty
    for prompts in all_prompts:
        targets = [p.split(",")[-1] for p in prompts[1:]]
        tt = []

        # The prompt tokens do not depend on the target, so tokenise them once per prompt.
        ptokens = (
            self.tokenizer(
                prompts,
                max_length=self.tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).input_ids
        )[0]

        for target in targets:
            ttokens = (
                self.tokenizer(
                    target,
                    max_length=self.tokenizer.model_max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids
            )[0]

            tlist = []

            # Skip the first and last token of the target encoding.
            for t in range(ttokens.shape[0] - 2):
                for p in range(ptokens.shape[0]):
                    if ttokens[t + 1] == ptokens[p]:
                        tlist.append(p)
            if tlist != []:
                tt.append(tlist)

    return tt

Patrick von Platen's avatar
Patrick von Platen committed
564
565
566
567

def scaled_dot_product_attention(
    self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
) -> torch.Tensor:
    """
    Explicit scaled dot-product attention that can expose the attention weights.

    Args:
        self: pipeline object; only used to pass to `get_attn_maps` when `getattn` is set.
        query: query tensor whose last two dimensions are `(L, head_dim)`.
        key: key tensor whose last two dimensions are `(S, head_dim)`.
        value: value tensor whose second-to-last dimension is `S`.
        attn_mask: optional mask broadcastable to the attention weights. A boolean mask keeps
            positions that are `True`; any other mask is added to the attention scores.
        dropout_p: dropout probability applied to the attention weights.
        is_causal: apply a lower-triangular causal mask; requires `attn_mask` to be `None`.
        scale: score scaling factor; defaults to `1 / sqrt(head_dim)`.
        getattn: when `True`, the softmaxed weights are recorded via `get_attn_maps`.

    Returns:
        The attention output `softmax(query @ key^T * scale + bias) @ value`.
    """
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Build the bias on the query's own device so it always matches the attention scores.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Convert the boolean mask into an additive bias without mutating the caller's tensor.
            mask_bias = torch.zeros_like(attn_mask, dtype=query.dtype)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + mask_bias
        else:
            # Out-of-place so a batched mask can broadcast against the (L, S) bias.
            attn_bias = attn_bias + attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if getattn:
        get_attn_maps(self, attn_weight)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value