Unverified Commit c2217142 authored by Hyoungwon Cho's avatar Hyoungwon Cho Committed by GitHub
Browse files

Modification on the PAG community pipeline (re) (#7876)



* edited_pag_implementation

* update

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
parent 8edaf3b7
# Implementation of StableDiffusionPAGPipeline # Implementation of StableDiffusionPipeline with PAG
# https://ku-cvlab.github.io/Perturbed-Attention-Guidance
import inspect import inspect
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
...@@ -134,8 +135,8 @@ class PAGIdentitySelfAttnProcessor: ...@@ -134,8 +135,8 @@ class PAGIdentitySelfAttnProcessor:
value = attn.to_v(hidden_states_ptb) value = attn.to_v(hidden_states_ptb)
hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
# hidden_states_ptb = value hidden_states_ptb = value
hidden_states_ptb = hidden_states_ptb.to(query.dtype) hidden_states_ptb = hidden_states_ptb.to(query.dtype)
...@@ -1045,7 +1046,7 @@ class StableDiffusionPAGPipeline( ...@@ -1045,7 +1046,7 @@ class StableDiffusionPAGPipeline(
return self._pag_scale return self._pag_scale
@property @property
def do_adversarial_guidance(self): def do_perturbed_attention_guidance(self):
return self._pag_scale > 0 return self._pag_scale > 0
@property @property
...@@ -1056,14 +1057,6 @@ class StableDiffusionPAGPipeline( ...@@ -1056,14 +1057,6 @@ class StableDiffusionPAGPipeline(
def do_pag_adaptive_scaling(self): def do_pag_adaptive_scaling(self):
return self._pag_adaptive_scaling > 0 return self._pag_adaptive_scaling > 0
@property
def pag_drop_rate(self):
return self._pag_drop_rate
@property
def pag_applied_layers(self):
return self._pag_applied_layers
@property @property
def pag_applied_layers_index(self): def pag_applied_layers_index(self):
return self._pag_applied_layers_index return self._pag_applied_layers_index
...@@ -1080,8 +1073,6 @@ class StableDiffusionPAGPipeline( ...@@ -1080,8 +1073,6 @@ class StableDiffusionPAGPipeline(
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
pag_scale: float = 0.0, pag_scale: float = 0.0,
pag_adaptive_scaling: float = 0.0, pag_adaptive_scaling: float = 0.0,
pag_drop_rate: float = 0.5,
pag_applied_layers: List[str] = ["down"], # ['down', 'mid', 'up']
pag_applied_layers_index: List[str] = ["d4"], # ['d4', 'd5', 'm0'] pag_applied_layers_index: List[str] = ["d4"], # ['d4', 'd5', 'm0']
negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1, num_images_per_prompt: Optional[int] = 1,
...@@ -1221,8 +1212,6 @@ class StableDiffusionPAGPipeline( ...@@ -1221,8 +1212,6 @@ class StableDiffusionPAGPipeline(
self._pag_scale = pag_scale self._pag_scale = pag_scale
self._pag_adaptive_scaling = pag_adaptive_scaling self._pag_adaptive_scaling = pag_adaptive_scaling
self._pag_drop_rate = pag_drop_rate
self._pag_applied_layers = pag_applied_layers
self._pag_applied_layers_index = pag_applied_layers_index self._pag_applied_layers_index = pag_applied_layers_index
# 2. Define call parameters # 2. Define call parameters
...@@ -1257,13 +1246,13 @@ class StableDiffusionPAGPipeline( ...@@ -1257,13 +1246,13 @@ class StableDiffusionPAGPipeline(
# to avoid doing two forward passes # to avoid doing two forward passes
# cfg # cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance: if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# pag # pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
# both # both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance: elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])
if ip_adapter_image is not None or ip_adapter_image_embeds is not None: if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
...@@ -1306,7 +1295,7 @@ class StableDiffusionPAGPipeline( ...@@ -1306,7 +1295,7 @@ class StableDiffusionPAGPipeline(
).to(device=device, dtype=latents.dtype) ).to(device=device, dtype=latents.dtype)
# 7. Denoising loop # 7. Denoising loop
if self.do_adversarial_guidance: if self.do_perturbed_attention_guidance:
down_layers = [] down_layers = []
mid_layers = [] mid_layers = []
up_layers = [] up_layers = []
...@@ -1322,6 +1311,29 @@ class StableDiffusionPAGPipeline( ...@@ -1322,6 +1311,29 @@ class StableDiffusionPAGPipeline(
else: else:
raise ValueError(f"Invalid layer type: {layer_type}") raise ValueError(f"Invalid layer type: {layer_type}")
# change attention layer in UNet if use PAG
if self.do_perturbed_attention_guidance:
if self.do_classifier_free_guidance:
replace_processor = PAGCFGIdentitySelfAttnProcessor()
else:
replace_processor = PAGIdentitySelfAttnProcessor()
drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers:
try:
if drop_layer[0] == "d":
down_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "m":
mid_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "u":
up_layers[int(drop_layer[1])].processor = replace_processor
else:
raise ValueError(f"Invalid layer type: {drop_layer[0]}")
except IndexError:
raise ValueError(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps) self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
...@@ -1330,41 +1342,18 @@ class StableDiffusionPAGPipeline( ...@@ -1330,41 +1342,18 @@ class StableDiffusionPAGPipeline(
continue continue
# cfg # cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance: if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 2) latent_model_input = torch.cat([latents] * 2)
# pag # pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 2) latent_model_input = torch.cat([latents] * 2)
# both # both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance: elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
latent_model_input = torch.cat([latents] * 3) latent_model_input = torch.cat([latents] * 3)
# no # no
else: else:
latent_model_input = latents latent_model_input = latents
# change attention layer in UNet if use PAG
if self.do_adversarial_guidance:
if self.do_classifier_free_guidance:
replace_processor = PAGCFGIdentitySelfAttnProcessor()
else:
replace_processor = PAGIdentitySelfAttnProcessor()
drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers:
try:
if drop_layer[0] == "d":
down_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "m":
mid_layers[int(drop_layer[1])].processor = replace_processor
elif drop_layer[0] == "u":
up_layers[int(drop_layer[1])].processor = replace_processor
else:
raise ValueError(f"Invalid layer type: {drop_layer[0]}")
except IndexError:
raise ValueError(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual # predict the noise residual
...@@ -1381,14 +1370,14 @@ class StableDiffusionPAGPipeline( ...@@ -1381,14 +1370,14 @@ class StableDiffusionPAGPipeline(
# perform guidance # perform guidance
# cfg # cfg
if self.do_classifier_free_guidance and not self.do_adversarial_guidance: if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
delta = noise_pred_text - noise_pred_uncond delta = noise_pred_text - noise_pred_uncond
noise_pred = noise_pred_uncond + self.guidance_scale * delta noise_pred = noise_pred_uncond + self.guidance_scale * delta
# pag # pag
elif not self.do_classifier_free_guidance and self.do_adversarial_guidance: elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
noise_pred_original, noise_pred_perturb = noise_pred.chunk(2) noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)
signal_scale = self.pag_scale signal_scale = self.pag_scale
...@@ -1400,7 +1389,7 @@ class StableDiffusionPAGPipeline( ...@@ -1400,7 +1389,7 @@ class StableDiffusionPAGPipeline(
noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb) noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)
# both # both
elif self.do_classifier_free_guidance and self.do_adversarial_guidance: elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3) noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
signal_scale = self.pag_scale signal_scale = self.pag_scale
...@@ -1458,11 +1447,8 @@ class StableDiffusionPAGPipeline( ...@@ -1458,11 +1447,8 @@ class StableDiffusionPAGPipeline(
# Offload all models # Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
# change attention layer in UNet if use PAG # change attention layer in UNet if use PAG
if self.do_adversarial_guidance: if self.do_perturbed_attention_guidance:
drop_layers = self.pag_applied_layers_index drop_layers = self.pag_applied_layers_index
for drop_layer in drop_layers: for drop_layer in drop_layers:
try: try:
...@@ -1479,4 +1465,7 @@ class StableDiffusionPAGPipeline( ...@@ -1479,4 +1465,7 @@ class StableDiffusionPAGPipeline(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers." f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
) )
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment