Commit 0f55c17e authored by Patrick von Platen

fix style

parent 5058d27f
import math
from typing import Dict, Optional

import torch
import torchvision.transforms.functional as FF
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
try:
    from compel import Compel
except ImportError:
    Compel = None
KCOMM = "ADDCOMM"
KBRK = "BREAK"
class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
r"""
Args for Regional Prompting Pipeline:
rp_args:dict
Required
Required
rp_args["mode"]: cols, rows, prompt, prompt-ex
for cols, rows mode
rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
for prompt, prompt-ex mode
for prompt, prompt-ex mode
rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
@@ -56,6 +60,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
@@ -67,7 +72,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
        super().__init__(
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
        )
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -93,50 +100,56 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
        rp_args: Dict[str, str] = None,
):
        active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt  # noqa: E721
        if negative_prompt is None:
            negative_prompt = "" if type(prompt) == str else [""] * len(prompt)  # noqa: E721
device = self._execution_device
regions = 0
self.power = int(rp_args["power"]) if "power" in rp_args else 1
        prompts = prompt if type(prompt) == list else [prompt]  # noqa: E721
        n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt]  # noqa: E721
self.batch = batch = num_images_per_prompt * len(prompts)
        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
if Compel:
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
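            # Compel (an optional dependency) parses prompt-weighting syntax; each regional sub-prompt is embedded separately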
def getcompelembs(prps):
embl = []
for prp in prps:
embl.append(compel.build_conditioning_tensor(prp))
return torch.cat(embl)
conds = getcompelembs(all_prompts_cn)
            unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
prompt = negative_prompt = None
else:
conds = self.encode_prompt(prompts, device, 1, True)[0]
unconds = (
self.encode_prompt(n_prompts, device, 1, True)[0]
if cn
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)
embs = n_embs = None
if not active:
pcallback = None
mode = None
else:
if any(x in rp_args["mode"].upper() for x in ["COL","ROW"]):
mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
ocells,icells,regions = make_cells(rp_args["div"])
if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
ocells, icells, regions = make_cells(rp_args["div"])
elif "PRO" in rp_args["mode"].upper():
regions = len(all_prompts_p[0])
mode = "PROMPT"
@@ -144,14 +157,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
self.ex = "EX" in rp_args["mode"].upper()
self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
thresholds = [float(x) for x in rp_args["th"].split(",")]
orig_hw = (height, width)
revers = True
def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None):
if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps
self.step = step
if len(self.attnmaps_sizes) > 3:
self.history[step] = self.attnmaps.copy()
for hw in self.attnmaps_sizes:
@@ -167,7 +180,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
allmasks.append(mask)
basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
basemasks = [1 - mask for mask in basemasks]
basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
allmasks = basemasks + allmasks
@@ -176,7 +189,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
return latents
def hook_forward(module):
# diffusers==0.23.2
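        # custom attention forward (written against diffusers==0.23.2): attention runs once per regional prompt,
        # and the regional outputs are stitched back together by cells (rows/cols) or by prompt-derived masks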
def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
@@ -184,22 +197,21 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
attn = module
xshape = hidden_states.shape
self.hw = (h, w) = split_dims(xshape[1], *orig_hw)
if revers:
nx, px = hidden_states.chunk(2)
else:
px, nx = hidden_states.chunk(2)
if cn:
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
encoder_hidden_states = torch.cat([conds] + [unconds])
else:
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
encoder_hidden_states = torch.cat([conds] + [unconds])
residual = hidden_states
@@ -247,12 +259,19 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = scaled_dot_product_attention(
self,
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
getattn="PRO" in mode,
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
@@ -272,18 +291,38 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
outs = [px, nx] if cn else [px]
for out in outs:
c = 0
for i, ocell in enumerate(ocells):
for icell in icells[i]:
if "ROW" in mode:
out[
0:batch,
int(h * ocell[0]) : int(h * ocell[1]),
int(w * icell[0]) : int(w * icell[1]),
:,
] = out[
c * batch : (c + 1) * batch,
int(h * ocell[0]) : int(h * ocell[1]),
int(w * icell[0]) : int(w * icell[1]),
:,
]
else:
out[
0:batch,
int(h * icell[0]) : int(h * icell[1]),
int(w * ocell[0]) : int(w * ocell[1]),
:,
] = out[
c * batch : (c + 1) * batch,
int(h * icell[0]) : int(h * icell[1]),
int(w * ocell[0]) : int(w * ocell[1]),
:,
]
c += 1
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
hidden_states = hidden_states.reshape(xshape)
#### Regional Prompting Prompt mode
@@ -291,17 +330,19 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
if (h, w) in self.attnmasks and self.maskready:
def mask(input):
out = torch.multiply(input, self.attnmasks[(h, w)])
for b in range(batch):
for r in range(1, regions):
out[b] = out[b] + out[r * batch + b]
return out
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
return hidden_states
return forward
@@ -328,7 +369,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback_on_step_end=pcallback,
)
if "save_mask" in rp_args:
@@ -336,13 +377,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
else:
save_mask = False
if mode == "PROMPT" and save_mask: saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
if mode == "PROMPT" and save_mask:
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
return output
### Make prompt list for each region
def promptsmaker(prompts, batch):
out_p = []
plen = len(prompts)
for prompt in prompts:
@@ -352,24 +394,26 @@ def promptsmaker(prompts, batch):
add = add + " "
prompts = prompt.split(KBRK)
out_p.append([add + p for p in prompts])
    out = [None] * batch * len(out_p[0]) * len(out_p)
    for p, prs in enumerate(out_p):  # input prompts
        for r, pr in enumerate(prs):  # prompts for regions
start = (p + r * plen) * batch
out[start : start + batch] = [pr] * batch # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
return out, out_p
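
# e.g. (illustrative) promptsmaker(["red hair BREAK blue dress"], 2) returns
#   out   == ["red hair ", "red hair ", " blue dress", " blue dress"]  (order P1R1B1, P1R1B2, P1R2B1, P1R2B2)
#   out_p == [["red hair ", " blue dress"]]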
### make regions from ratios
### ";" makes outercells, "," makes inner cells
def make_cells(ratios):
if ";" not in ratios and "," in ratios:ratios = ratios.replace(",",";")
if ";" not in ratios and "," in ratios:
ratios = ratios.replace(",", ";")
ratios = ratios.split(";")
ratios = [inratios.split(",") for inratios in ratios]
icells = []
ocells = []
def startend(cells, array):
current_start = 0
array = [float(x) for x in array]
for value in array:
@@ -377,72 +421,80 @@ def make_cells(ratios):
cells.append([current_start, end])
current_start = end
startend(ocells, [r[0] for r in ratios])
for inratios in ratios:
if 2 > len(inratios):
icells.append([[0, 1]])
else:
add = []
startend(add, inratios[1:])
icells.append(add)
return ocells, icells, sum(len(cell) for cell in icells)
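
# e.g. make_cells("1;2;1") -> ocells == [[0.0, 0.25], [0.25, 0.75], [0.75, 1.0]],
# icells == [[[0, 1]], [[0, 1]], [[0, 1]]] and 3 regions (assuming each ratio is normalized by the sum)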
def make_emblist(self, prompts):
with torch.no_grad():
tokens = self.tokenizer(
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
).input_ids.to(self.device)
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
return embs
def split_dims(xs, height, width):
def repeat_div(x, y):
while y > 0:
x = math.ceil(x / 2)
y = y - 1
return x
scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
dsh = repeat_div(height, scale)
dsw = repeat_div(width, scale)
return dsh, dsw
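
# e.g. split_dims(4096, 512, 512) -> (64, 64): a 4096-token attention layer maps to the 64x64 (8x-downsampled) U-Net block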
##### for prompt mode
def get_attn_maps(self, attn):
height, width = self.hw
target_tokens = self.target_tokens
if (height, width) not in self.attnmaps_sizes:
self.attnmaps_sizes.append((height, width))
for b in range(self.batch):
for t in target_tokens:
power = self.power
add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
add = torch.sum(add, dim=2)
key = f"{t}-{b}"
if key not in self.attnmaps:
self.attnmaps[key] = add
else:
if self.attnmaps[key].shape[1] != add.shape[1]:
add = add.view(8, height, width)
add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
add = add.reshape_as(self.attnmaps[key])
self.attnmaps[key] = self.attnmaps[key] + add
def reset_attnmaps(self): # init parameters in every batch
self.step = 0
    self.attnmaps = {}  # made from attention maps
    self.attnmaps_sizes = []  # height, width set of U-Net blocks
    self.attnmasks = {}  # made from attnmaps for regions
self.maskready = False
self.history = {}
def saveattnmaps(self, output, h, w, th, step, regions):
masks = []
for i, mask in enumerate(self.history[step].values()):
img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
if self.ex:
masks = [x - mask for x in masks]
masks.append(mask)
@@ -452,46 +504,71 @@ def saveattnmaps(self, output, h, w, th, step, regions):
else:
output.images.append(img)
def makepmask(
    self, mask, h, w, th, step
):  # make masks from the attention cache; returns [for preview, for attention, for latent]
th = th - step * 0.005
if 0.05 >= th:
th = 0.05
mask = torch.mean(mask, dim=0)
mask = mask / mask.max().item()
mask = torch.where(mask > th, 1, 0)
mask = mask.float()
mask = mask.view(1, *self.attnmaps_sizes[0])
img = FF.to_pil_image(mask)
img = img.resize((w, h))
mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
lmask = mask
mask = mask.reshape(h * w)
mask = torch.where(mask > 0.1, 1, 0)
return img, mask, lmask
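
# find the token positions of each region's "target" phrase (the last comma-separated chunk of its prompt)
# in the tokenized full prompt, so attention to those tokens can be accumulated into region masks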
def tokendealer(self, all_prompts):
for prompts in all_prompts:
targets = [p.split(",")[-1] for p in prompts[1:]]
tt = []
for target in targets:
ptokens = (
self.tokenizer(
prompts,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids
)[0]
ttokens = (
self.tokenizer(
target,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids
)[0]
tlist = []
for t in range(ttokens.shape[0] - 2):
for p in range(ptokens.shape[0]):
if ttokens[t + 1] == ptokens[p]:
tlist.append(p)
if tlist != []:
tt.append(tlist)
return tt
def scaled_dot_product_attention(
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
) -> torch.Tensor:
# Efficient implementation equivalent to the following:
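    # (re)implemented here instead of calling torch.nn.functional.scaled_dot_product_attention so that the
    # softmaxed attention weights remain accessible and can be recorded via get_attn_maps when getattn=True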
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
@@ -506,6 +583,7 @@ def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropou
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
if getattn:
get_attn_maps(self, attn_weight)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value