Commit 0f55c17e authored by Patrick von Platen

fix style

parent 5058d27f
import math
from typing import Dict, Optional

import torch
import torchvision
import torchvision.transforms.functional as FF
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND


try:
    from compel import Compel
except ImportError:
    Compel = None

KCOMM = "ADDCOMM"
KBRK = "BREAK"


class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
    r"""
    Args for Regional Prompting Pipeline:
        rp_args:dict
        Required
            rp_args["mode"]: cols, rows, prompt, prompt-ex
        for cols, rows mode
            rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
        for prompt, prompt-ex mode
            rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)

        Optional
            rp_args["save_mask"]: True/False (save masks in prompt mode)
...@@ -56,6 +60,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """
    def __init__(
        self,
        vae: AutoencoderKL,
...@@ -67,7 +72,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
    ):
        super().__init__(
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
        )
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
...@@ -93,50 +100,56 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        rp_args: Dict[str, str] = None,
    ):
        active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt  # noqa: E721
        if negative_prompt is None:
            negative_prompt = "" if type(prompt) == str else [""] * len(prompt)  # noqa: E721

        device = self._execution_device
        regions = 0

        self.power = int(rp_args["power"]) if "power" in rp_args else 1

        prompts = prompt if type(prompt) == list else [prompt]  # noqa: E721
        n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt]  # noqa: E721
        self.batch = batch = num_images_per_prompt * len(prompts)
        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)

        cn = len(all_prompts_cn) == len(all_n_prompts_cn)

        if Compel:
            compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)

            def getcompelembs(prps):
                embl = []
                for prp in prps:
                    embl.append(compel.build_conditioning_tensor(prp))
                return torch.cat(embl)

            conds = getcompelembs(all_prompts_cn)
            unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
            embs = getcompelembs(prompts)
            n_embs = getcompelembs(n_prompts)
            prompt = negative_prompt = None
        else:
            conds = self.encode_prompt(prompts, device, 1, True)[0]
            unconds = (
                self.encode_prompt(n_prompts, device, 1, True)[0]
                if cn
                else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
            )
            embs = n_embs = None

        if not active:
            pcallback = None
            mode = None
        else:
            if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
                mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
                ocells, icells, regions = make_cells(rp_args["div"])
            elif "PRO" in rp_args["mode"].upper():
                regions = len(all_prompts_p[0])
                mode = "PROMPT"
...@@ -144,14 +157,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                self.ex = "EX" in rp_args["mode"].upper()
                self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
                thresholds = [float(x) for x in rp_args["th"].split(",")]

            orig_hw = (height, width)
            revers = True

            def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None):
                if "PRO" in mode:  # in Prompt mode, make masks from sum of attention maps
                    self.step = step

                    if len(self.attnmaps_sizes) > 3:
                        self.history[step] = self.attnmaps.copy()
                        for hw in self.attnmaps_sizes:
...@@ -167,7 +180,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                                        allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
                                    allmasks.append(mask)
                                    basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
                            basemasks = [1 - mask for mask in basemasks]
                            basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
                            allmasks = basemasks + allmasks
...@@ -176,7 +189,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                return latents

        def hook_forward(module):
            # diffusers==0.23.2
            def forward(
                hidden_states: torch.FloatTensor,
                encoder_hidden_states: Optional[torch.FloatTensor] = None,
...@@ -184,22 +197,21 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                temb: Optional[torch.FloatTensor] = None,
                scale: float = 1.0,
            ) -> torch.Tensor:
                attn = module
                xshape = hidden_states.shape
                self.hw = (h, w) = split_dims(xshape[1], *orig_hw)

                if revers:
                    nx, px = hidden_states.chunk(2)
                else:
                    px, nx = hidden_states.chunk(2)

                if cn:
                    hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
                    encoder_hidden_states = torch.cat([conds] + [unconds])
                else:
                    hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
                    encoder_hidden_states = torch.cat([conds] + [unconds])

                residual = hidden_states
...@@ -247,12 +259,19 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                hidden_states = scaled_dot_product_attention(
                    self,
                    query,
                    key,
                    value,
                    attn_mask=attention_mask,
                    dropout_p=0.0,
                    is_causal=False,
                    getattn="PRO" in mode,
                )

                hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                hidden_states = hidden_states.to(query.dtype)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states, *args)
                # dropout
...@@ -272,18 +291,38 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                    center = reshaped.shape[0] // 2
                    px = reshaped[0:center] if cn else reshaped[0:-batch]
                    nx = reshaped[center:] if cn else reshaped[-batch:]
                    outs = [px, nx] if cn else [px]
                    for out in outs:
                        c = 0
                        for i, ocell in enumerate(ocells):
                            for icell in icells[i]:
                                if "ROW" in mode:
                                    out[
                                        0:batch,
                                        int(h * ocell[0]) : int(h * ocell[1]),
                                        int(w * icell[0]) : int(w * icell[1]),
                                        :,
                                    ] = out[
                                        c * batch : (c + 1) * batch,
                                        int(h * ocell[0]) : int(h * ocell[1]),
                                        int(w * icell[0]) : int(w * icell[1]),
                                        :,
                                    ]
                                else:
                                    out[
                                        0:batch,
                                        int(h * icell[0]) : int(h * icell[1]),
                                        int(w * ocell[0]) : int(w * ocell[1]),
                                        :,
                                    ] = out[
                                        c * batch : (c + 1) * batch,
                                        int(h * icell[0]) : int(h * icell[1]),
                                        int(w * ocell[0]) : int(w * ocell[1]),
                                        :,
                                    ]
                                c += 1
                    px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
                    hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                    hidden_states = hidden_states.reshape(xshape)

                #### Regional Prompting Prompt mode
...@@ -291,17 +330,19 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                    center = reshaped.shape[0] // 2
                    px = reshaped[0:center] if cn else reshaped[0:-batch]
                    nx = reshaped[center:] if cn else reshaped[-batch:]

                    if (h, w) in self.attnmasks and self.maskready:

                        def mask(input):
                            out = torch.multiply(input, self.attnmasks[(h, w)])
                            for b in range(batch):
                                for r in range(1, regions):
                                    out[b] = out[b] + out[r * batch + b]
                            return out

                        px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
                    px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
                    hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)

                return hidden_states

            return forward
...@@ -328,7 +369,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
            latents=latents,
            output_type=output_type,
            return_dict=return_dict,
            callback_on_step_end=pcallback,
        )

        if "save_mask" in rp_args:
...@@ -336,13 +377,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
        else:
            save_mask = False

        if mode == "PROMPT" and save_mask:
            saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)

        return output


### Make prompt list for each region
def promptsmaker(prompts, batch):
    out_p = []
    plen = len(prompts)
    for prompt in prompts:
...@@ -352,24 +394,26 @@ def promptsmaker(prompts,batch):
            add = add + " "
        prompts = prompt.split(KBRK)
        out_p.append([add + p for p in prompts])

    out = [None] * batch * len(out_p[0]) * len(out_p)
    for p, prs in enumerate(out_p):  # input prompts
        for r, pr in enumerate(prs):  # prompts for regions
            start = (p + r * plen) * batch
            out[start : start + batch] = [pr] * batch  # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
    return out, out_p
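
# Hedged worked example (illustrative values, not part of the committed file): calling
#   promptsmaker(["sky ADDCOMM red ball BREAK blue cube"], 2)
# prepends the text before ADDCOMM to every BREAK-separated region prompt, returning out_p with
# one list of two region prompts and a flat list `out` ordered region-major and repeated per
# batch image: [P1R1, P1R1, P1R2, P1R2].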


### make regions from ratios
### ";" makes outer cells, "," makes inner cells
def make_cells(ratios):
    if ";" not in ratios and "," in ratios:
        ratios = ratios.replace(",", ";")
    ratios = ratios.split(";")
    ratios = [inratios.split(",") for inratios in ratios]

    icells = []
    ocells = []

    def startend(cells, array):
        current_start = 0
        array = [float(x) for x in array]
        for value in array:
...@@ -377,72 +421,80 @@ def make_cells(ratios):
            cells.append([current_start, end])
            current_start = end

    startend(ocells, [r[0] for r in ratios])
    for inratios in ratios:
        if 2 > len(inratios):
            icells.append([[0, 1]])
        else:
            add = []
            startend(add, inratios[1:])
            icells.append(add)
    return ocells, icells, sum(len(cell) for cell in icells)
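
# Hedged worked example (illustrative, and assuming the collapsed startend body normalizes each
# value by the sum of its array): make_cells("1;2;1") yields
#   ocells = [[0, 0.25], [0.25, 0.75], [0.75, 1.0]], icells = [[[0, 1]], [[0, 1]], [[0, 1]]]
# and 3 regions; within a ";"-separated group such as "2,1,1", the first value sets the outer
# cell ratio and the remaining values split that cell into inner cells.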


def make_emblist(self, prompts):
    with torch.no_grad():
        tokens = self.tokenizer(
            prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
        ).input_ids.to(self.device)
        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
    return embs


def split_dims(xs, height, width):
    xs = xs

    def repeat_div(x, y):
        while y > 0:
            x = math.ceil(x / 2)
            y = y - 1
        return x

    scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
    dsh = repeat_div(height, scale)
    dsw = repeat_div(width, scale)
    return dsh, dsw
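
# Hedged worked example (illustrative): split_dims recovers the (height, width) of the U-Net
# feature map that produced a flattened attention sequence. For a 512x512 image and xs = 4096,
# scale = ceil(log2(sqrt(512 * 512 / 4096))) = 3, so each side is halved (rounding up) three
# times and split_dims(4096, 512, 512) == (64, 64).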


##### for prompt mode
def get_attn_maps(self, attn):
    height, width = self.hw
    target_tokens = self.target_tokens
    if (height, width) not in self.attnmaps_sizes:
        self.attnmaps_sizes.append((height, width))

    for b in range(self.batch):
        for t in target_tokens:
            power = self.power
            add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
            add = torch.sum(add, dim=2)
            key = f"{t}-{b}"
            if key not in self.attnmaps:
                self.attnmaps[key] = add
            else:
                if self.attnmaps[key].shape[1] != add.shape[1]:
                    add = add.view(8, height, width)
                    add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
                    add = add.reshape_as(self.attnmaps[key])

                self.attnmaps[key] = self.attnmaps[key] + add


def reset_attnmaps(self):  # init parameters in every batch
    self.step = 0
    self.attnmaps = {}  # made from attention maps
    self.attnmaps_sizes = []  # height,width set of u-net blocks
    self.attnmasks = {}  # made from attnmaps for regions
    self.maskready = False
    self.history = {}


def saveattnmaps(self, output, h, w, th, step, regions):
    masks = []
    for i, mask in enumerate(self.history[step].values()):
        img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
        if self.ex:
            masks = [x - mask for x in masks]
            masks.append(mask)
...@@ -452,46 +504,71 @@ def saveattnmaps(self,output,h,w,th,step,regions):
        else:
            output.images.append(img)


def makepmask(
    self, mask, h, w, th, step
):  # make masks from attention cache return [for preview, for attention, for Latent]
    th = th - step * 0.005
    if 0.05 >= th:
        th = 0.05
    mask = torch.mean(mask, dim=0)
    mask = mask / mask.max().item()
    mask = torch.where(mask > th, 1, 0)
    mask = mask.float()
    mask = mask.view(1, *self.attnmaps_sizes[0])
    img = FF.to_pil_image(mask)
    img = img.resize((w, h))
    mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
    lmask = mask
    mask = mask.reshape(h * w)
    mask = torch.where(mask > 0.1, 1, 0)
    return img, mask, lmask
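
# Editorial note (not part of the committed file): the binarization threshold decays by 0.005
# per denoising step and is clamped at 0.05, so a configured th of 0.5 becomes 0.4 by step 20
# before the attention map is turned into a hard mask.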


def tokendealer(self, all_prompts):
    for prompts in all_prompts:
        targets = [p.split(",")[-1] for p in prompts[1:]]
        tt = []

        for target in targets:
            ptokens = (
                self.tokenizer(
                    prompts,
                    max_length=self.tokenizer.model_max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids
            )[0]
            ttokens = (
                self.tokenizer(
                    target,
                    max_length=self.tokenizer.model_max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids
            )[0]

            tlist = []
            for t in range(ttokens.shape[0] - 2):
                for p in range(ptokens.shape[0]):
                    if ttokens[t + 1] == ptokens[p]:
                        tlist.append(p)
            if tlist != []:
                tt.append(tlist)

    return tt


def scaled_dot_product_attention(
    self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
...@@ -506,6 +583,7 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if getattn:
        get_attn_maps(self, attn_weight)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
\ No newline at end of file