Unverified Commit 05b706c0 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

PAG variant for AnimateDiff (#8789)

* add animatediff pag pipeline

* remove unnecessary print

* make fix-copies

* fix ip-adapter bug

* update docs

* add fast tests and fix bugs

* update

* update

* address review comments

* update ip adapter single test expected slice

* implement test_from_pipe_consistent_config; fix expected slice values

* LoraLoaderMixin->StableDiffusionLoraLoaderMixin; add latest freeinit test
parent ea1b4ea7
......@@ -20,6 +20,11 @@ The abstract from the paper is:
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
## AnimateDiffPAGPipeline
[[autodoc]] AnimateDiffPAGPipeline
- all
- __call__
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
......
......@@ -233,6 +233,7 @@ else:
"AmusedInpaintPipeline",
"AmusedPipeline",
"AnimateDiffControlNetPipeline",
"AnimateDiffPAGPipeline",
"AnimateDiffPipeline",
"AnimateDiffSDXLPipeline",
"AnimateDiffSparseControlNetPipeline",
......@@ -654,6 +655,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AmusedInpaintPipeline,
AmusedPipeline,
AnimateDiffControlNetPipeline,
AnimateDiffPAGPipeline,
AnimateDiffPipeline,
AnimateDiffSDXLPipeline,
AnimateDiffSparseControlNetPipeline,
......
......@@ -143,6 +143,7 @@ else:
)
_import_structure["pag"].extend(
[
"AnimateDiffPAGPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline",
......@@ -527,6 +528,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
......
......@@ -25,6 +25,7 @@ else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
......@@ -40,6 +41,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
......
......@@ -33,7 +33,7 @@ class PAGMixin:
Check if each layer input in `applied_pag_layers` is valid. It should be either one of these 3 formats:
"{block_type}", "{block_type}.{block_index}", or "{block_type}.{block_index}.{attention_index}". `block_type`
can be "down", "mid", "up". `block_index` should be in the format of "block_{i}". `attention_index` should be
in the format of "attentions_{j}".
in the format of "attentions_{j}". `motion_modules_index` should be in the format of "motion_modules_{j}"
"""
layer_splits = layer.split(".")
......@@ -52,8 +52,11 @@ class PAGMixin:
raise ValueError(f"Invalid block_index in pag layer: {layer}. Should start with 'block_'")
if len(layer_splits) == 3:
if not layer_splits[2].startswith("attentions_"):
raise ValueError(f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_'")
layer_2 = layer_splits[2]
if not layer_2.startswith("attentions_") and not layer_2.startswith("motion_modules_"):
raise ValueError(
f"Invalid attention_index in pag layer: {layer}. Should start with 'attentions_' or 'motion_modules_'"
)
def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
......@@ -72,33 +75,46 @@ class PAGMixin:
def get_block_type(module_name):
r"""
Get the block type from the module name. can be "down", "mid", "up".
Get the block type from the module name. Can be "down", "mid", "up".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "down"
# down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "down"
return module_name.split(".")[0].split("_")[0]
def get_block_index(module_name):
r"""
Get the block index from the module name. can be "block_0", "block_1", ... If there is only one block (e.g.
Get the block index from the module name. Can be "block_0", "block_1", ... If there is only one block (e.g.
mid_block) and index is ommited from the name, it will be "block_0".
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "block_1"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "block_0"
if "attentions" in module_name.split(".")[1]:
module_name_splits = module_name.split(".")
block_index = module_name_splits[1]
if "attentions" in block_index or "motion_modules" in block_index:
return "block_0"
else:
return f"block_{module_name.split('.')[1]}"
return f"block_{block_index}"
def get_attn_index(module_name):
r"""
Get the attention index from the module name. can be "attentions_0", "attentions_1", ...
Get the attention index from the module name. Can be "attentions_0", "attentions_1", "motion_modules_0",
"motion_modules_1", ...
"""
# down_blocks.1.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
# mid_block.attentions.0.transformer_blocks.0.attn1 -> "attentions_0"
if "attentions" in module_name.split(".")[2]:
return f"attentions_{module_name.split('.')[3]}"
elif "attentions" in module_name.split(".")[1]:
return f"attentions_{module_name.split('.')[2]}"
# down_blocks.1.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
# mid_block.motion_modules.0.transformer_blocks.0.attn1 -> "motion_modules_0"
module_name_split = module_name.split(".")
mid_name = module_name_split[1]
down_name = module_name_split[2]
if "attentions" in down_name:
return f"attentions_{module_name_split[3]}"
if "attentions" in mid_name:
return f"attentions_{module_name_split[2]}"
if "motion_modules" in down_name:
return f"motion_modules_{module_name_split[3]}"
if "motion_modules" in mid_name:
return f"motion_modules_{module_name_split[2]}"
for pag_layer_input in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the unet model
......@@ -114,7 +130,7 @@ class PAGMixin:
target_modules.append(module)
elif len(pag_layer_input_splits) == 2:
# when the layer inpput contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
# when the layer input contains both block_type and block_index. e.g. "down.block_1", "mid.block_0"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
for name, module in self.unet.named_modules():
......@@ -126,7 +142,8 @@ class PAGMixin:
target_modules.append(module)
elif len(pag_layer_input_splits) == 3:
# when the layer input contains block_type, block_index and attention_index. e.g. "down.blocks_1.attentions_1"
# when the layer input contains block_type, block_index and attention_index.
# e.g. "down.block_1.attentions_1" or "down.block_1.motion_modules_1"
block_type = pag_layer_input_splits[0]
block_index = pag_layer_input_splits[1]
attn_index = pag_layer_input_splits[2]
......
This diff is collapsed.
......@@ -92,6 +92,21 @@ class AnimateDiffControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AnimateDiffPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
import inspect
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AnimateDiffPAGPipeline,
AnimateDiffPipeline,
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
LCMScheduler,
MotionAdapter,
StableDiffusionPipeline,
UNet2DConditionModel,
UNetMotionModel,
)
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
PipelineFromPipeTesterMixin,
PipelineTesterMixin,
SDFunctionTesterMixin,
)
def to_np(tensor):
if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy()
return tensor
class AnimateDiffPAGPipelineFastTests(
IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
):
pipeline_class = AnimateDiffPAGPipeline
params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
def get_dummy_components(self):
cross_attention_dim = 8
block_out_channels = (8, 8)
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=block_out_channels,
layers_per_block=2,
sample_size=8,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
clip_sample=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=block_out_channels,
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
motion_adapter = MotionAdapter(
block_out_channels=block_out_channels,
motion_layers_per_block=2,
motion_norm_num_groups=2,
motion_num_attention_heads=4,
)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"feature_extractor": None,
"image_encoder": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 7.5,
"pag_scale": 3.0,
"output_type": "pt",
}
return inputs
def test_from_pipe_consistent_config(self):
assert self.original_pipeline_class == StableDiffusionPipeline
original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
original_kwargs = {"requires_safety_checker": False}
# create original_pipeline_class(sd)
pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
# original_pipeline_class(sd) -> pipeline_class
pipe_components = self.get_dummy_components()
pipe_additional_components = {}
for name, component in pipe_components.items():
if name not in pipe_original.components:
pipe_additional_components[name] = component
pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
# pipeline_class -> original_pipeline_class(sd)
original_pipe_additional_components = {}
for name, component in pipe_original.components.items():
if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
original_pipe_additional_components[name] = component
pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
# compare the config
original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
assert original_config_2 == original_config
def test_motion_unet_loading(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
assert isinstance(pipe.unet, UNetMotionModel)
@unittest.skip("Attention slicing is not enabled in this pipeline")
def test_attention_slicing_forward_pass(self):
pass
def test_ip_adapter_single(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array(
[
0.5068,
0.5294,
0.4926,
0.4810,
0.4188,
0.5935,
0.5295,
0.3947,
0.5300,
0.4706,
0.3950,
0.4737,
0.4072,
0.3227,
0.5481,
0.4864,
0.4518,
0.5315,
0.5979,
0.5374,
0.3503,
0.5275,
0.6067,
0.4914,
0.5440,
0.4775,
0.5538,
]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
def test_dict_tuple_outputs_equivalent(self):
expected_slice = None
if torch_device == "cpu":
expected_slice = np.array([0.5295, 0.3947, 0.5300, 0.4864, 0.4518, 0.5315, 0.5440, 0.4775, 0.5538])
return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
def test_to_device(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
# pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to("cuda")
model_devices = [
component.device.type for component in pipe.components.values() if hasattr(component, "device")
]
self.assertTrue(all(device == "cuda" for device in model_devices))
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
# pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_prompt_embeds(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
inputs.pop("prompt")
inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
pipe(**inputs)
def test_free_init(self):
components = self.get_dummy_components()
pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
pipe.enable_free_init(
num_iters=2,
use_fast_sampling=True,
method="butterworth",
order=4,
spatial_stop_frequency=0.25,
temporal_stop_frequency=0.25,
)
inputs_enable_free_init = self.get_dummy_inputs(torch_device)
frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
pipe.disable_free_init()
inputs_disable_free_init = self.get_dummy_inputs(torch_device)
frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
self.assertGreater(
sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
)
self.assertLess(
max_diff_disabled,
1e-3,
"Disabling of FreeInit should lead to results similar to the default pipeline results",
)
def test_free_init_with_schedulers(self):
components = self.get_dummy_components()
pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
schedulers_to_test = [
DPMSolverMultistepScheduler.from_config(
components["scheduler"].config,
timestep_spacing="linspace",
beta_schedule="linear",
algorithm_type="dpmsolver++",
steps_offset=1,
clip_sample=False,
),
LCMScheduler.from_config(
components["scheduler"].config,
timestep_spacing="linspace",
beta_schedule="linear",
steps_offset=1,
clip_sample=False,
),
]
components.pop("scheduler")
for scheduler in schedulers_to_test:
components["scheduler"] = scheduler
pipe: AnimateDiffPAGPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
inputs = self.get_dummy_inputs(torch_device)
frames_enable_free_init = pipe(**inputs).frames[0]
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
self.assertGreater(
sum_enabled,
1e1,
"Enabling of FreeInit should lead to results different from the default pipeline results",
)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_without_offload = pipe(**inputs).frames[0]
output_without_offload = (
output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
)
pipe.enable_xformers_memory_efficient_attention()
inputs = self.get_dummy_inputs(torch_device)
output_with_offload = pipe(**inputs).frames[0]
output_with_offload = (
output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
)
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
def test_vae_slicing(self):
return super().test_vae_slicing(image_count=2)
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline (expect same output when pag is disabled)
components.pop("pag_applied_layers", None)
pipe_sd = AnimateDiffPipeline(**components)
pipe_sd = pipe_sd.to(device)
pipe_sd.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
assert (
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).frames[0, -3:, -3:, -1]
components = self.get_dummy_components()
# pag disabled with pag_scale=0.0
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 0.0
out_pag_disabled = pipe_pag(**inputs).frames[0, -3:, -3:, -1]
# pag enabled
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
out_pag_enabled = pipe_pag(**inputs).frames[0, -3:, -3:, -1]
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
def test_pag_applied_layers(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline
components.pop("pag_applied_layers", None)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
# pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers
all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k]
original_attn_procs = pipe.unet.attn_processors
pag_layers = [
"down",
"mid",
"up",
]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
# pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.motion_modules_0"] should apply to all self-attention layers in mid_block, i.e.
# mid_block.motion_modules.0.transformer_blocks.0.attn1.processor
# mid_block.attentions.0.transformer_blocks.0.attn1.processor
all_self_attn_mid_layers = [
"mid_block.motion_modules.0.transformer_blocks.0.attn1.processor",
"mid_block.attentions.0.transformer_blocks.0.attn1.processor",
]
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["mid"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["mid.block_0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["mid.block_0.attentions_0", "mid.block_0.motion_modules_0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["mid.block_0.attentions_1"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
# pag_applied_layers = "down" should apply to all self-attention layers in down_blocks
# down_blocks.1.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor
# down_blocks.1.(attentions|motion_modules).0.transformer_blocks.1.attn1.processor
# down_blocks.1.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["down"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 6
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["down.block_0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert (len(pipe.pag_attn_processors)) == 4
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["down.block_1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 2
pipe.unet.set_attn_processor(original_attn_procs.copy())
pag_layers = ["down.block_1.motion_modules_2"]
with self.assertRaises(ValueError):
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment