Unverified Commit 7e164d98 authored by Paakhhi, committed by GitHub

Fix diffusers import prompt2prompt (#6927)



* Bugfix: correct import for diffusers

* Fix: Prompt2Prompt example

* Format style

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent e6d1728e
@@ -2287,9 +2287,9 @@ Here's a full example for `ReplaceEdit`:
 import torch
 import numpy as np
 import matplotlib.pyplot as plt
-from diffusers.pipelines import Prompt2PromptPipeline
+from diffusers import DiffusionPipeline
-pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")
 prompts = ["A turtle playing with a ball",
            "A monkey playing with a ball"]
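For quick reference, the corrected example above can be run end to end roughly as follows. This is a minimal sketch, not copied from the repository docs: only the `custom_pipeline="pipeline_prompt2prompt"` loading path and the `"edit_type"` key are taken from this diff; the exact `pipe(...)` call signature, the output attribute, and the defaults used when the replace-fraction keys are omitted are assumptions.

# Minimal sketch of the corrected Prompt2Prompt usage (assumptions noted above).
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="pipeline_prompt2prompt",
).to("cuda")

prompts = [
    "A turtle playing with a ball",
    "A monkey playing with a ball",
]

# "replace" selects the ReplaceEdit controller (AttentionReplace) inside the
# community pipeline; the replace-fraction settings are left at their defaults here.
outputs = pipe(
    prompt=prompts,
    num_inference_steps=50,
    cross_attention_kwargs={"edit_type": "replace"},
)
outputs.images[0].save("turtle.png")
outputs.images[1].save("monkey.png")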
@@ -21,8 +21,11 @@ import numpy as np
 import torch
 import torch.nn.functional as F
-from ...src.diffusers.models.attention import Attention
-from ...src.diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput
+from diffusers.models.attention import Attention
+from diffusers.pipelines.stable_diffusion import (
+    StableDiffusionPipeline,
+    StableDiffusionPipelineOutput,
+)
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
@@ -165,7 +168,11 @@ class Prompt2PromptPipeline(StableDiffusionPipeline):
         """
         self.controller = create_controller(
-            prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=self.device
+            prompt,
+            cross_attention_kwargs,
+            num_inference_steps,
+            tokenizer=self.tokenizer,
+            device=self.device,
         )
         self.register_attention_control(self.controller)  # add attention controller
@@ -287,7 +294,7 @@ class Prompt2PromptPipeline(StableDiffusionPipeline):
         attn_procs = {}
         cross_att_count = 0
         for name in self.unet.attn_processors.keys():
-            None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+            (None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim)
             if name.startswith("mid_block"):
                 self.unet.config.block_out_channels[-1]
                 place_in_unet = "mid"
@@ -314,7 +321,13 @@ class P2PCrossAttnProcessor:
         self.controller = controller
         self.place_in_unet = place_in_unet
-    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+    ):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -346,7 +359,11 @@ class P2PCrossAttnProcessor:
 def create_controller(
-    prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device
+    prompts: List[str],
+    cross_attention_kwargs: Dict,
+    num_inference_steps: int,
+    tokenizer,
+    device,
 ) -> AttentionControl:
     edit_type = cross_attention_kwargs.get("edit_type", None)
     local_blend_words = cross_attention_kwargs.get("local_blend_words", None)
@@ -358,27 +375,49 @@ def create_controller(
     # only replace
     if edit_type == "replace" and local_blend_words is None:
         return AttentionReplace(
-            prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device
+            prompts,
+            num_inference_steps,
+            n_cross_replace,
+            n_self_replace,
+            tokenizer=tokenizer,
+            device=device,
         )
     # replace + localblend
     if edit_type == "replace" and local_blend_words is not None:
         lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
         return AttentionReplace(
-            prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device
+            prompts,
+            num_inference_steps,
+            n_cross_replace,
+            n_self_replace,
+            lb,
+            tokenizer=tokenizer,
+            device=device,
         )
     # only refine
     if edit_type == "refine" and local_blend_words is None:
         return AttentionRefine(
-            prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device
+            prompts,
+            num_inference_steps,
+            n_cross_replace,
+            n_self_replace,
+            tokenizer=tokenizer,
+            device=device,
         )
     # refine + localblend
     if edit_type == "refine" and local_blend_words is not None:
         lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device)
         return AttentionRefine(
-            prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device
+            prompts,
+            num_inference_steps,
+            n_cross_replace,
+            n_self_replace,
+            lb,
+            tokenizer=tokenizer,
+            device=device,
         )
     # reweight
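As a side note on the branches reformatted above: `create_controller` dispatches purely on the "edit_type" and "local_blend_words" entries of `cross_attention_kwargs`. A rough cheat sheet follows, with illustrative values only; the expected structure of `local_blend_words` and any additional keys (e.g. for the replace fractions or the reweight branch) are assumptions not shown in this hunk.

# Illustrative cross_attention_kwargs for each branch of create_controller.
# Values are placeholders; only the key names "edit_type" and
# "local_blend_words" appear in this diff.
replace_kwargs = {"edit_type": "replace"}    # -> AttentionReplace
refine_kwargs = {"edit_type": "refine"}      # -> AttentionRefine

replace_local_kwargs = {                     # -> AttentionReplace + LocalBlend
    "edit_type": "replace",
    "local_blend_words": ["ball", "ball"],   # illustrative; words that bound the edited region
}
refine_local_kwargs = {                      # -> AttentionRefine + LocalBlend
    "edit_type": "refine",
    "local_blend_words": ["ball", "ball"],
}

# Each dict is passed straight through the pipeline call, e.g.
# pipe(prompt=prompts, cross_attention_kwargs=replace_kwargs)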
@@ -447,7 +486,14 @@ class EmptyControl(AttentionControl):
 class AttentionStore(AttentionControl):
     @staticmethod
     def get_empty_store():
-        return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
+        return {
+            "down_cross": [],
+            "mid_cross": [],
+            "up_cross": [],
+            "down_self": [],
+            "mid_self": [],
+            "up_self": [],
+        }
     def forward(self, attn, is_cross: bool, place_in_unet: str):
         key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
@@ -497,7 +543,13 @@ class LocalBlend:
         return x_t
     def __init__(
-        self, prompts: List[str], words: [List[List[str]]], tokenizer, device, threshold=0.3, max_num_words=77
+        self,
+        prompts: List[str],
+        words: [List[List[str]]],
+        tokenizer,
+        device,
+        threshold=0.3,
+        max_num_words=77,
     ):
         self.max_num_words = 77
@@ -588,7 +640,13 @@ class AttentionReplace(AttentionControlEdit):
         device=None,
     ):
         super(AttentionReplace, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+            prompts,
+            num_steps,
+            cross_replace_steps,
+            self_replace_steps,
+            local_blend,
+            tokenizer,
+            device,
         )
         self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)
@@ -610,7 +668,13 @@ class AttentionRefine(AttentionControlEdit):
         device=None,
     ):
         super(AttentionRefine, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+            prompts,
+            num_steps,
+            cross_replace_steps,
+            self_replace_steps,
+            local_blend,
+            tokenizer,
+            device,
         )
         self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
         self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)
@@ -637,7 +701,13 @@ class AttentionReweight(AttentionControlEdit):
         device=None,
     ):
         super(AttentionReweight, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+            prompts,
+            num_steps,
+            cross_replace_steps,
+            self_replace_steps,
+            local_blend,
+            tokenizer,
+            device,
         )
         self.equalizer = equalizer.to(self.device)
         self.prev_controller = controller
@@ -645,7 +715,10 @@ class AttentionReweight(AttentionControlEdit):
 ### util functions for all Edits
 def update_alpha_time_word(
-    alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
+    alpha,
+    bounds: Union[float, Tuple[float, float]],
+    prompt_ind: int,
+    word_inds: Optional[torch.Tensor] = None,
 ):
     if isinstance(bounds, float):
         bounds = 0, bounds
@@ -659,7 +732,11 @@ def update_alpha_time_word(
 def get_time_words_attention_alpha(
-    prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
+    prompts,
+    num_steps,
+    cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+    tokenizer,
+    max_num_words=77,
 ):
     if not isinstance(cross_replace_steps, dict):
         cross_replace_steps = {"default_": cross_replace_steps}
@@ -750,7 +827,10 @@ def get_replacement_mapper(prompts, tokenizer, max_len=77):
 ### util functions for ReweightEdit
 def get_equalizer(
-    text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
+    text: str,
+    word_select: Union[int, Tuple[int, ...]],
+    values: Union[List[float], Tuple[float, ...]],
+    tokenizer,
 ):
     if isinstance(word_select, (int, str)):
         word_select = (word_select,)
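For completeness, a hedged sketch of how `get_equalizer` might be called on the reweight path. The choice of prompt, word, and scaling value is purely illustrative, and how the resulting tensor is wired into `cross_attention_kwargs` for `AttentionReweight` is outside this hunk.

# Illustrative only: build an equalizer that boosts attention on one word.
# word_select may be an int index or a token string per the isinstance check above;
# values scales the selected word's cross-attention weights.
equalizer = get_equalizer(
    "A smiling turtle playing with a ball",
    word_select="smiling",
    values=(5.0,),
    tokenizer=pipe.tokenizer,
)
# The result is handed to AttentionReweight (see its __init__ above),
# which moves it to the pipeline device via equalizer.to(self.device).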