"...text-generation-inference.git" did not exist on "c38a7d7ddd9c612e368adec1ef94583be602fc7e"
Unverified Commit 69c83d6e authored by cjkangme's avatar cjkangme Committed by GitHub
Browse files

[Community Pipeline] Add some feature for regional prompting pipeline (#9874)



* [Fix] fix bugs of regional_prompting pipeline

* [Feat] add base prompt feature

* [Fix] fix __init__ pipeline error

* [Fix] delete unused args

* [Fix] improve string handling

* [Docs] add docs for use_base in regional_prompting

* make style

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent e44fc75a
@@ -3379,6 +3379,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
best quality, 3persons in garden, an old man red suit
```
### Use base prompt
You can use a base prompt to apply a shared prompt to all areas. Set a base prompt by adding `ADDBASE` to the end of its line. Base prompts can also be combined with common prompts, but the base prompt must be specified first.
```
2d animation style ADDBASE
masterpiece, high quality ADDCOMM
(blue sky)++ BREAK
green hair twintail BREAK
book shelf BREAK
messy desk BREAK
orange++ dress and sofa
```
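Putting it together, here is a minimal end-to-end sketch (the checkpoint ID and the `div` split are illustrative; `mode` and `div` follow the row/column syntax described earlier in this README):
```python
import torch
from diffusers import DiffusionPipeline

# Any SD 1.5-style checkpoint should work; this ID is only an example.
pipe = DiffusionPipeline.from_pretrained(
    "stablediffusionapi/anything-v5",
    custom_pipeline="regional_prompting_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")

# "2d animation style" is blended into every region via ADDBASE, and
# "masterpiece, high quality" is prepended to every region via ADDCOMM.
prompt = """
2d animation style ADDBASE
masterpiece, high quality ADDCOMM
(blue sky)++ BREAK
green hair twintail BREAK
book shelf BREAK
messy desk BREAK
orange++ dress and sofa
"""

rp_args = {
    "mode": "rows",
    "div": "1;1;1;1;1",   # five rows, one per BREAK-separated region
    "base_ratio": "0.2",  # 20% base prompt + 80% regional prompts
}

image = pipe(prompt=prompt, num_inference_steps=20, rp_args=rp_args).images[0]
image.save("regional_base.png")
```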
### Negative prompt
Negative prompts are equally effective across all regions, but you can also set region-specific negative prompts. The number of BREAKs must match the number of positive prompts; if the counts do not match, the negative prompt is applied as a whole, without being divided into regions.
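For example, the three-region garden prompt above pairs with a region-wise negative prompt containing the same number of `BREAK`s (the wording here is illustrative):
```
worst quality, low resolution BREAK
bad anatomy, bad hands BREAK
blurry, jpeg artifacts
```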
@@ -3409,6 +3423,7 @@ pipe(prompt=prompt, rp_args=rp_args)
### Optional Parameters
- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
- `base_ratio`: Used with `ADDBASE`. Sets the weight of the base prompt; if `base_ratio` is set to 0.2, the resulting image will consist of `20% * BASE_PROMPT + 80% * REGION_PROMPT` (see the sketch after this section).
The pipeline supports `compel` syntax. Prompts using the `compel` structure will be automatically parsed and applied.
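As the pipeline diff below shows, `base_ratio` is a plain linear interpolation between the encoded base prompt and the encoded regional prompts. Here is a self-contained illustration with dummy tensors (the `2 x 77 x 768` shape assumes SD 1.5's CLIP text encoder):
```python
import torch

base_ratio = 0.2
# Stand-ins for the encoded ADDBASE prompt and the encoded regional prompts.
base_embs = torch.randn(2, 77, 768)
region_conds = torch.randn(2, 77, 768)

# The same blend the pipeline applies when a base prompt is active:
conds = base_ratio * base_embs + (1 - base_ratio) * region_conds
```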
@@ -3,13 +3,12 @@ from typing import Dict, Optional
import torch
import torchvision.transforms.functional as FF
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers

try:
@@ -17,6 +16,7 @@ try:
except ImportError:
    Compel = None

KBASE = "ADDBASE"
KCOMM = "ADDCOMM"
KBRK = "BREAK"
@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
    Optional
    rp_args["save_mask"]: True/False (save masks in prompt mode)
    rp_args["power"]: int (power for attention maps in prompt mode)
    rp_args["base_ratio"]:
        float (sets the ratio of the base prompt)
        ex) 0.2 (20% * BASE_PROMPT + 80% * REGION_PROMPT)
        [Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)

    Pipeline for text-to-image generation using Stable Diffusion.
@@ -70,6 +75,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__(
@@ -80,6 +86,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
            scheduler,
            safety_checker,
            feature_extractor,
            image_encoder,
            requires_safety_checker,
        )
        self.register_modules(
@@ -90,6 +97,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )

    @torch.no_grad()
@@ -110,17 +118,40 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
        rp_args: Dict[str, str] = None,
    ):
        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
        use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
        if negative_prompt is None:
            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)

        device = self._execution_device
        regions = 0

        self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
        self.power = int(rp_args["power"]) if "power" in rp_args else 1

        prompts = prompt if isinstance(prompt, list) else [prompt]
        n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
        self.batch = batch = num_images_per_prompt * len(prompts)

        if use_base:
            bases = prompts.copy()
            n_bases = n_prompts.copy()

            for i, prompt in enumerate(prompts):
                parts = prompt.split(KBASE)
                if len(parts) == 2:
                    bases[i], prompts[i] = parts
                elif len(parts) > 2:
                    raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")

            for i, prompt in enumerate(n_prompts):
                n_parts = prompt.split(KBASE)
                if len(n_parts) == 2:
                    n_bases[i], n_prompts[i] = n_parts
                elif len(n_parts) > 2:
                    raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")

            all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
            all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)

        all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
        all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
@@ -137,8 +168,16 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
            conds = getcompelembs(all_prompts_cn)
            unconds = getcompelembs(all_n_prompts_cn)
            base_embs = getcompelembs(all_bases_cn) if use_base else None
            base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
            # When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
            embs = getcompelembs(prompts) if not use_base else base_embs
            n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs

            if use_base and self.base_ratio > 0:
                conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
                unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

            prompt = negative_prompt = None
        else:
            conds = self.encode_prompt(prompts, device, 1, True)[0]
@@ -147,6 +186,18 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                if equal
                else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
            )

            if use_base and self.base_ratio > 0:
                base_embs = self.encode_prompt(bases, device, 1, True)[0]
                base_n_embs = (
                    self.encode_prompt(n_bases, device, 1, True)[0]
                    if equal
                    else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
                )

                conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
                unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds

            embs = n_embs = None

        if not active:
@@ -225,8 +276,6 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                residual = hidden_states

                if attn.spatial_norm is not None:
                    hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -247,16 +296,15 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                if attn.group_norm is not None:
                    hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

                query = attn.to_q(hidden_states)

                if encoder_hidden_states is None:
                    encoder_hidden_states = hidden_states
                elif attn.norm_cross:
                    encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                inner_dim = key.shape[-1]
                head_dim = inner_dim // attn.heads
@@ -283,7 +331,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
                hidden_states = hidden_states.to(query.dtype)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)
@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
        add = ""
        if KCOMM in prompt:
            add, prompt = prompt.split(KCOMM)
            add = add.strip() + " "
        prompts = [p.strip() for p in prompt.split(KBRK)]
        out_p.append([add + p for i, p in enumerate(prompts)])
    out = [None] * batch * len(out_p[0]) * len(out_p)
    for p, prs in enumerate(out_p):  # inputs prompts
        for r, pr in enumerate(prs):  # prompts for regions
@@ -449,7 +497,6 @@ def make_cells(ratios):
        add = []
        startend(add, inratios[1:])
        icells.append(add)

    return ocells, icells, sum(len(cell) for cell in icells)