"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9f669e7b5d212f74f3fb1183af178435d785c0b4"
Unverified Commit ff43dba7 authored by hako-mikan, committed by GitHub

[Fix] Fix Regional Prompting Pipeline (#6188)



* Update regional_prompting_stable_diffusion.py

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* Update regional_prompting_stable_diffusion.py

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 54339629
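
For context before the diff: a minimal, hedged usage sketch of the community pipeline this commit fixes. The checkpoint ID is a placeholder, and the `rp_args` keys ("mode", "div", "power") follow the pipeline's community README examples; only "power" appears in the hunks below, where it is parsed with `int(rp_args["power"])`.

```python
# Hedged sketch: loading the community pipeline and passing regional
# prompts separated by the BREAK keyword (KBRK in the diff below).
# Model ID and rp_args values are illustrative, not prescriptive.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
    custom_pipeline="regional_prompting_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# Each BREAK-separated clause is routed to one region; "div" splits the
# canvas into bands and "power" scales the attention-mask weighting.
images = pipe(
    prompt="green hair twintail BREAK red blouse BREAK blue skirt",
    negative_prompt="lowres, bad anatomy",
    guidance_scale=7.0,
    rp_args={"mode": "rows", "div": "1;1;1", "power": "1"},
).images
images[0].save("regional.png")
```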
@@ -73,7 +73,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
         requires_safety_checker: bool = True,
     ):
         super().__init__(
-            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker,
+            feature_extractor,
+            requires_safety_checker,
         )
         self.register_modules(
             vae=vae,
@@ -102,22 +109,22 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
         return_dict: bool = True,
         rp_args: Dict[str, str] = None,
     ):
-        active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt  # noqa: E721
+        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
         if negative_prompt is None:
-            negative_prompt = "" if type(prompt) == str else [""] * len(prompt)  # noqa: E721
+            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

         device = self._execution_device
         regions = 0

         self.power = int(rp_args["power"]) if "power" in rp_args else 1

-        prompts = prompt if type(prompt) == list else [prompt]  # noqa: E721
-        n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt]  # noqa: E721
+        prompts = prompt if isinstance(prompt, list) else [prompt]
+        n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
         self.batch = batch = num_images_per_prompt * len(prompts)
         all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
         all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
-        cn = len(all_prompts_cn) == len(all_n_prompts_cn)
+        equal = len(all_prompts_cn) == len(all_n_prompts_cn)

         if Compel:
             compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
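
The `type(x) == y` comparisons removed above are exactly what flake8's E721 rule flags (hence the deleted `# noqa: E721` markers): an exact-type check rejects subclasses, while `isinstance` accepts them. A standalone illustration, using a hypothetical `list` subclass:

```python
class PromptList(list):  # hypothetical subclass for illustration
    pass

p = PromptList(["a cat BREAK a dog"])
print(type(p) == list)      # False: exact type check fails on subclasses
print(isinstance(p, list))  # True: isinstance respects inheritance
```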
@@ -129,7 +136,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                 return torch.cat(embl)

             conds = getcompelembs(all_prompts_cn)
-            unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
+            unconds = getcompelembs(all_n_prompts_cn)
             embs = getcompelembs(prompts)
             n_embs = getcompelembs(n_prompts)
             prompt = negative_prompt = None
@@ -137,7 +144,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
             conds = self.encode_prompt(prompts, device, 1, True)[0]
             unconds = (
                 self.encode_prompt(n_prompts, device, 1, True)[0]
-                if cn
+                if equal
                 else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
             )
             embs = n_embs = None
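
The rename from `cn` to `equal` clarifies what the flag tracks: whether the number of per-region positive prompts matches the number of per-region negative prompts, which decides whether negatives are paired per region or a single negative is broadcast. A toy sketch (the `split_regions` helper is a stand-in for the pipeline's `promptsmaker`):

```python
# Toy illustration of the renamed flag; split_regions emulates a plain
# per-region split on the BREAK keyword.
def split_regions(prompts):
    return [p.strip() for prompt in prompts for p in prompt.split("BREAK")]

prompts = ["red hat BREAK blue coat"]
n_prompts = ["lowres"]

all_prompts_cn = split_regions(prompts)      # 2 regional positives
all_n_prompts_cn = split_regions(n_prompts)  # 1 shared negative
equal = len(all_prompts_cn) == len(all_n_prompts_cn)
print(equal)  # False -> the single negative is applied to both regions
```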
@@ -206,8 +213,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                 else:
                     px, nx = hidden_states.chunk(2)

-                if cn:
-                    hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
+                if equal:
+                    hidden_states = torch.cat(
+                        [px for i in range(regions)] + [nx for i in range(regions)],
+                        0,
+                    )
                     encoder_hidden_states = torch.cat([conds] + [unconds])
                 else:
                     hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
@@ -289,9 +299,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
             if any(x in mode for x in ["COL", "ROW"]):
                 reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
                 center = reshaped.shape[0] // 2
-                px = reshaped[0:center] if cn else reshaped[0:-batch]
-                nx = reshaped[center:] if cn else reshaped[-batch:]
-                outs = [px, nx] if cn else [px]
+                px = reshaped[0:center] if equal else reshaped[0:-batch]
+                nx = reshaped[center:] if equal else reshaped[-batch:]
+                outs = [px, nx] if equal else [px]
                 for out in outs:
                     c = 0
                     for i, ocell in enumerate(ocells):
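
In this COL/ROW branch the flattened token sequence of shape `(batch, h*w, channels)` is reshaped to `(batch, h, w, channels)` so that each region can be sliced as a horizontal or vertical band of the latent grid. A minimal sketch with dummy tensors (shapes are illustrative only):

```python
import torch

batch, h, w, ch = 2, 8, 8, 320  # dummy latent-token shape
hidden_states = torch.randn(batch, h * w, ch)

# Same reshape as in the hunk above: expose the 2D latent grid.
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])

# A "COL" region is then a vertical band; e.g. the left half:
left = reshaped[:, :, : w // 2, :]
print(left.shape)  # torch.Size([2, 8, 4, 320])
```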
@@ -321,15 +331,16 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                                 :,
                             ]
                             c += 1
-                px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
+                px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                 hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
                 hidden_states = hidden_states.reshape(xshape)

             #### Regional Prompting Prompt mode
             elif "PRO" in mode:
-                center = reshaped.shape[0] // 2
-                px = reshaped[0:center] if cn else reshaped[0:-batch]
-                nx = reshaped[center:] if cn else reshaped[-batch:]
+                px, nx = (
+                    torch.chunk(hidden_states) if equal else hidden_states[0:-batch],
+                    hidden_states[-batch:],
+                )

                 if (h, w) in self.attnmasks and self.maskready:
@@ -340,8 +351,8 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                             out[b] = out[b] + out[r * batch + b]
                         return out

-                    px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
-                px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
+                    px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
+                px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
                 hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
             return hidden_states
@@ -378,7 +389,15 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
         save_mask = False

         if mode == "PROMPT" and save_mask:
-            saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
+            saveattnmaps(
+                self,
+                output,
+                height,
+                width,
+                thresholds,
+                num_inference_steps // 2,
+                regions,
+            )

         return output
@@ -437,7 +456,11 @@ def make_cells(ratios):
 def make_emblist(self, prompts):
     with torch.no_grad():
         tokens = self.tokenizer(
-            prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+            prompts,
+            max_length=self.tokenizer.model_max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
         ).input_ids.to(self.device)
         embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
     return embs
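
The reformatted `make_emblist` call is the standard `transformers` tokenize-then-encode pattern: pad the batch to a common length, truncate at the model maximum, and return PyTorch tensors. A self-contained equivalent, assuming the usual CLIP text stack used by Stable Diffusion 1.x:

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

with torch.no_grad():
    tokens = tokenizer(
        ["red hat", "blue coat"],
        max_length=tokenizer.model_max_length,  # 77 for CLIP
        padding=True,       # pad batch entries to a common length
        truncation=True,    # clip prompts longer than max_length
        return_tensors="pt",
    ).input_ids
    embs = text_encoder(tokens, output_hidden_states=True).last_hidden_state
print(embs.shape)  # (2, seq_len, 768)
```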
@@ -563,7 +586,15 @@ def tokendealer(self, all_prompts):
 def scaled_dot_product_attention(
-    self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
+    self,
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    getattn=False,
 ) -> torch.Tensor:
     # Efficient implementation equivalent to the following:
     L, S = query.size(-2), key.size(-2)
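
The comment kept in this hunk refers to the reference formula documented for PyTorch's `torch.nn.functional.scaled_dot_product_attention`; the pipeline reimplements it so that `getattn=True` can expose the attention weights for building region masks. For comparison, the core computation (non-causal, no dropout, no mask) reduces to the following sketch:

```python
import math
import torch

def sdpa_reference(query, key, value, scale=None):
    # Reference math from the PyTorch docs: softmax(QK^T / sqrt(d)) V.
    scale = scale if scale is not None else 1 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    # Returning attn_weight as well is what the getattn flag enables.
    return attn_weight @ value

q = k = v = torch.randn(1, 4, 16, 64)  # (batch, heads, tokens, head_dim)
out = sdpa_reference(q, k, v)
print(out.shape)  # torch.Size([1, 4, 16, 64])
```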
...