Unverified Commit b7058d14 authored by Aryan, committed by GitHub

PAG variant for HunyuanDiT, PAG refactor (#8936)



* copy hunyuandit pipeline

* pag variant of hunyuan dit

* add tests

* update docs

* make style

* make fix-copies

* Update src/diffusers/pipelines/pag/pag_utils.py

* remove incorrect copied from

* remove pag hunyuan attn procs to resolve conflicts

* add pag attn procs again

* new implementation for pag_utils

* revert pag changes

* add pag refactor back; update pixart sigma

* update pixart pag tests

* apply suggestions from review

Co-Authored-By: yixu310@gmail.com

* make style

* update docs, fix tests

* fix tests

* fix test_components_function since list not accepted as valid __init__ param

* apply patch to fix broken tests
Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>

* make style

* fix hunyuan tests

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent e1d508ae
@@ -20,11 +20,29 @@ The abstract from the paper is:
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.
- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
- Partial identifier as a RegEx: `down_blocks.2`, or `attn1`
- List of identifiers (can be a combination of strings and RegEx): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2|3)"]`
<Tip warning={true}>
Since RegEx is supported for matching layer identifiers, it is crucial to use it correctly; otherwise, there may be unexpected behavior. The recommended way to use PAG is to specify layers as `blocks.{layer_index}` or `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while possible, may bypass our basic validation checks and give unexpected results.
</Tip>
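For example, the HunyuanDiT variant added in this PR can be loaded through `AutoPipelineForText2Image` with `enable_pag=True`. The sketch below is illustrative: the `Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers` checkpoint name, the chosen layer indices, and the `pag_scale` value are assumptions that should be tuned per model and prompt.

```python
import torch
from diffusers import AutoPipelineForText2Image

# enable_pag=True selects the PAG variant of the pipeline;
# pag_applied_layers targets the self-attention processors of blocks 14 and 20
pipe = AutoPipelineForText2Image.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",  # assumed checkpoint name
    enable_pag=True,
    pag_applied_layers=["blocks.(14|20)"],
    torch_dtype=torch.float16,
).to("cuda")

# pag_scale controls the strength of the perturbed-attention guidance term
image = pipe(
    "an astronaut riding a horse",
    guidance_scale=5.0,
    pag_scale=3.0,
).images[0]
```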
## AnimateDiffPAGPipeline
[[autodoc]] AnimateDiffPAGPipeline
- all
- __call__
## HunyuanDiTPAGPipeline
[[autodoc]] HunyuanDiTPAGPipeline
- all
- __call__
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
...
@@ -252,6 +252,7 @@ else:
        "CycleDiffusionPipeline",
        "FluxPipeline",
        "HunyuanDiTControlNetPipeline",
        "HunyuanDiTPAGPipeline",
        "HunyuanDiTPipeline",
        "I2VGenXLPipeline",
        "IFImg2ImgPipeline",
...
@@ -675,6 +676,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    CycleDiffusionPipeline,
    FluxPipeline,
    HunyuanDiTControlNetPipeline,
    HunyuanDiTPAGPipeline,
    HunyuanDiTPipeline,
    I2VGenXLPipeline,
    IFImg2ImgPipeline,
...
@@ -2147,6 +2147,253 @@ class FusedHunyuanAttnProcessor2_0:
        return hidden_states
class PAGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key
    vectors. This variant of the processor employs [Perturbed-Attention Guidance](https://arxiv.org/abs/2403.17377).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        # chunk: the pipeline concatenates the original and perturbed latents along the batch dimension
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
# 1. Original Path
batch_size, sequence_length, _ = (
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states_org)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states_org
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 2. Perturbed Path
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
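        # PAG's perturbation replaces the self-attention map with the identity matrix:
        # each token attends only to itself, so the attention output reduces to the
        # value projection below and no query/key computation is needed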
hidden_states_ptb = attn.to_v(hidden_states_ptb)
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class PAGCFGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key
    vectors. This variant of the processor employs [Perturbed-Attention Guidance](https://arxiv.org/abs/2403.17377).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        # chunk: the batch holds the [unconditional, conditional (original), perturbed] latents
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
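        # the unconditional and conditional (original) paths are re-batched so the
        # regular attention below runs over both in a single pass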
# 1. Original Path
batch_size, sequence_length, _ = (
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states_org)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states_org
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
if not attn.is_cross_attention:
key = apply_rotary_emb(key, image_rotary_emb)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
# 2. Perturbed Path
if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
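        # as in PAGHunyuanAttnProcessor2_0, the identity-attention perturbation reduces
        # the attention output to the value projection below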
hidden_states_ptb = attn.to_v(hidden_states_ptb)
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class LuminaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -3468,4 +3715,6 @@ AttentionProcessor = Union[
    CustomDiffusionAttnProcessor2_0,
    PAGCFGIdentitySelfAttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
    PAGCFGHunyuanAttnProcessor2_0,
    PAGHunyuanAttnProcessor2_0,
]
@@ -145,6 +145,7 @@ else:
    _import_structure["pag"].extend(
        [
            "AnimateDiffPAGPipeline",
            "HunyuanDiTPAGPipeline",
            "StableDiffusionPAGPipeline",
            "StableDiffusionControlNetPAGPipeline",
            "StableDiffusionXLPAGPipeline",
@@ -532,6 +533,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .musicldm import MusicLDMPipeline
    from .pag import (
        AnimateDiffPAGPipeline,
        HunyuanDiTPAGPipeline,
        PixArtSigmaPAGPipeline,
        StableDiffusionControlNetPAGPipeline,
        StableDiffusionPAGPipeline,
...
@@ -50,6 +50,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
    HunyuanDiTPAGPipeline,
    PixArtSigmaPAGPipeline,
    StableDiffusionControlNetPAGPipeline,
    StableDiffusionPAGPipeline,
@@ -85,6 +86,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
        ("stable-diffusion-3", StableDiffusion3Pipeline),
        ("if", IFPipeline),
        ("hunyuan", HunyuanDiTPipeline),
        ("hunyuan-pag", HunyuanDiTPAGPipeline),
        ("kandinsky", KandinskyCombinedPipeline),
        ("kandinsky22", KandinskyV22CombinedPipeline),
        ("kandinsky3", Kandinsky3Pipeline),
...
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
    _import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
    _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
    _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
    _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
    _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
@@ -41,6 +42,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
    from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
    from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
    from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
    from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
    from .pipeline_pag_sd import StableDiffusionPAGPipeline
    from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
...
This diff is collapsed.
This diff is collapsed.
@@ -40,7 +40,7 @@ from ..pixart_alpha.pipeline_pixart_alpha import (
    ASPECT_RATIO_1024_BIN,
)
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
-from .pag_utils import PixArtPAGMixin
+from .pag_utils import PAGMixin
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -61,7 +61,7 @@ EXAMPLE_DOC_STRING = """
        >>> pipe = AutoPipelineForText2Image.from_pretrained(
        ...     "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        ...     torch_dtype=torch.float16,
-        ...     pag_applied_layers=[14],
+        ...     pag_applied_layers=["blocks.14"],
        ...     enable_pag=True,
        ... )
        >>> pipe = pipe.to("cuda")
@@ -132,7 +132,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps
-class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin):
+class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
    r"""
    [PAG pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag) for text-to-image generation
    using PixArt-Sigma.
@@ -164,7 +164,7 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PixArtPAGMixin):
        vae: AutoencoderKL,
        transformer: PixArtTransformer2DModel,
        scheduler: KarrasDiffusionSchedulers,
-        pag_applied_layers: Union[str, List[str]] = "1",  # 1st transformer block
+        pag_applied_layers: Union[str, List[str]] = "blocks.1",  # 1st transformer block
    ):
        super().__init__()
...
@@ -129,7 +129,7 @@ class AnimateDiffPAGPipeline(
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor: CLIPImageProcessor = None,
        image_encoder: CLIPVisionModelWithProjection = None,
-        pag_applied_layers: Union[str, List[str]] = "mid",  # ["mid"], ["down.block_1"], ["up.block_0.attentions_0"]
+        pag_applied_layers: Union[str, List[str]] = "mid_block.*attn1",  # ["mid"], ["down_blocks.1"]
    ):
        super().__init__()
        if isinstance(unet, UNet2DConditionModel):
...
@@ -332,6 +332,21 @@ class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanDiTPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
...
@@ -429,7 +429,10 @@ class AnimateDiffPAGPipelineFastTests(
        pipe.set_progress_bar_config(disable=None)
        # pag_applied_layers = ["mid","up","down"] should apply to all self-attention layers
-        all_self_attn_layers = [k for k in pipe.unet.attn_processors.keys() if "attn1" in k]
+        # Note that for motion modules in AnimateDiff, both attn1 and attn2 are self-attention
+        all_self_attn_layers = [
+            k for k in pipe.unet.attn_processors.keys() if "attn1" in k or ("motion_modules" in k and "attn2" in k)
+        ]
        original_attn_procs = pipe.unet.attn_processors
        pag_layers = [
            "down",
@@ -439,12 +442,13 @@ class AnimateDiffPAGPipelineFastTests(
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
-        # pag_applied_layers = ["mid"], or ["mid.block_0"] or ["mid.block_0.motion_modules_0"] should apply to all self-attention layers in mid_block, i.e.
+        # pag_applied_layers = ["mid"], or ["mid_block.0"] should apply to all self-attention layers in mid_block, i.e.
        # mid_block.motion_modules.0.transformer_blocks.0.attn1.processor
        # mid_block.attentions.0.transformer_blocks.0.attn1.processor
        all_self_attn_mid_layers = [
-            "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor",
            "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
+            "mid_block.motion_modules.0.transformer_blocks.0.attn1.processor",
+            "mid_block.motion_modules.0.transformer_blocks.0.attn2.processor",
        ]
        pipe.unet.set_attn_processor(original_attn_procs.copy())
        pag_layers = ["mid"]
@@ -452,17 +456,17 @@ class AnimateDiffPAGPipelineFastTests(
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0"]
+        pag_layers = ["mid_block"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_0", "mid.block_0.motion_modules_0"]
+        pag_layers = ["mid_block.(attentions|motion_modules)"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_1"]
+        pag_layers = ["mid_block.attentions.1"]
        with self.assertRaises(ValueError):
            pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
@@ -474,19 +478,19 @@ class AnimateDiffPAGPipelineFastTests(
        pipe.unet.set_attn_processor(original_attn_procs.copy())
        pag_layers = ["down"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
-        assert len(pipe.pag_attn_processors) == 6
+        assert len(pipe.pag_attn_processors) == 10
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_0"]
+        pag_layers = ["down_blocks.0"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
-        assert (len(pipe.pag_attn_processors)) == 4
+        assert (len(pipe.pag_attn_processors)) == 6
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1"]
+        pag_layers = ["blocks.1"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
-        assert len(pipe.pag_attn_processors) == 2
+        assert len(pipe.pag_attn_processors) == 10
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1.motion_modules_2"]
+        pag_layers = ["motion_modules.42"]
        with self.assertRaises(ValueError):
            pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, BertModel, T5EncoderModel
from diffusers import (
AutoencoderKL,
DDPMScheduler,
HunyuanDiT2DModel,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class HunyuanDiTPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = HunyuanDiTPAGPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
def get_dummy_components(self):
torch.manual_seed(0)
transformer = HunyuanDiT2DModel(
sample_size=16,
num_layers=2,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
in_channels=4,
cross_attention_dim=32,
cross_attention_dim_t5=32,
pooled_projection_dim=16,
hidden_size=24,
activation_fn="gelu-approximate",
)
torch.manual_seed(0)
vae = AutoencoderKL()
scheduler = DDPMScheduler()
text_encoder = BertModel.from_pretrained("hf-internal-testing/tiny-random-BertModel")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BertModel")
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"use_resolution_binning": False,
"pag_scale": 0.0,
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
self.assertEqual(image.shape, (1, 16, 16, 3))
expected_slice = np.array(
[0.56939435, 0.34541583, 0.35915792, 0.46489206, 0.38775963, 0.45004836, 0.5957267, 0.59481275, 0.33287364]
)
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_sequential_cpu_offload_forward_pass(self):
# TODO(YiYi) need to fix later
pass
def test_sequential_offload_forward_pass_twice(self):
# TODO(YiYi) need to fix later
pass
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(
expected_max_diff=1e-3,
)
def test_save_load_optional_components(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs["prompt"]
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
) = pipe.encode_prompt(
prompt,
device=torch_device,
dtype=torch.float32,
text_encoder_index=1,
)
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"prompt_embeds_2": prompt_embeds_2,
"prompt_attention_mask_2": prompt_attention_mask_2,
"negative_prompt_embeds_2": negative_prompt_embeds_2,
"negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
# set all optional components to None
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for optional_component in pipe._optional_components:
self.assertTrue(
getattr(pipe_loaded, optional_component) is None,
f"`{optional_component}` did not stay set to None after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"prompt_embeds_2": prompt_embeds_2,
"prompt_attention_mask_2": prompt_attention_mask_2,
"negative_prompt_embeds_2": negative_prompt_embeds_2,
"negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4)
def test_feed_forward_chunking(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_no_chunking = image[0, -3:, -3:, -1]
pipe.transformer.enable_forward_chunking(chunk_size=1, dim=0)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_chunking = image[0, -3:, -3:, -1]
max_diff = np.abs(to_np(image_slice_no_chunking) - to_np(image_slice_chunking)).max()
self.assertLess(max_diff, 1e-4)
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image = pipe(**inputs)[0]
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image_fused = pipe(**inputs)[0]
image_slice_fused = image_fused[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline (expect same output when pag is disabled)
pipe_sd = HunyuanDiTPipeline(**components)
pipe_sd = pipe_sd.to(device)
pipe_sd.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
assert (
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
# pag disabled with pag_scale=0.0
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 0.0
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
# pag enabled
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 3.0
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
def test_pag_applied_layers(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k]
original_attn_procs = pipe.transformer.attn_processors
pag_layers = ["blocks.0", "blocks.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
# blocks.0
block_0_self_attn = ["blocks.0.attn1.processor"]
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0.attn1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.(0|1)"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert (len(pipe.pag_attn_processors)) == 2
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0", r"blocks\.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 2
@@ -127,7 +127,7 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        out = pipe(**inputs).images[0, -3:, -3:, -1]
        # pag disabled with pag_scale=0.0
-        components["pag_applied_layers"] = [1]
+        components["pag_applied_layers"] = ["blocks.1"]
        pipe_pag = self.pipeline_class(**components)
        pipe_pag = pipe_pag.to(device)
        pipe_pag.set_progress_bar_config(disable=None)
@@ -158,7 +158,7 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        # "attn1" should apply to all self-attention layers.
        all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k]
-        pag_layers = [0, 1]
+        pag_layers = ["blocks.0", "blocks.1"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
@@ -228,7 +228,7 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
-            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1])
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"])
        pipe_loaded.to(torch_device)
        pipe_loaded.set_progress_bar_config(disable=None)
@@ -282,7 +282,7 @@ class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            pipe.save_pretrained(tmpdir, safe_serialization=False)
        with CaptureLogger(logger) as cap_logger:
-            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1])
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=["blocks.1"])
        for name in pipe_loaded.components.keys():
            if name not in pipe_loaded._optional_components:
...
@@ -213,18 +213,18 @@ class StableDiffusionPAGPipelineFastTests(
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0"]
+        pag_layers = ["mid_block"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_0"]
+        pag_layers = ["mid_block.attentions.0"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_1"]
+        pag_layers = ["mid_block.attentions.1"]
        with self.assertRaises(ValueError):
            pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
@@ -239,17 +239,17 @@ class StableDiffusionPAGPipelineFastTests(
        assert len(pipe.pag_attn_processors) == 2
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_0"]
+        pag_layers = ["down_blocks.0"]
        with self.assertRaises(ValueError):
            pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1"]
+        pag_layers = ["down_blocks.1"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert len(pipe.pag_attn_processors) == 2
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1.attentions_1"]
+        pag_layers = ["down_blocks.1.attentions.1"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert len(pipe.pag_attn_processors) == 1
...
@@ -225,18 +225,18 @@ class StableDiffusionXLPAGPipelineFastTests(
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0"]
+        pag_layers = ["mid_block"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_0"]
+        pag_layers = ["mid_block.attentions.0"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert set(pipe.pag_attn_processors) == set(all_self_attn_mid_layers)
        # pag_applied_layers = ["mid.block_0.attentions_1"] does not exist in the model
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["mid.block_0.attentions_1"]
+        pag_layers = ["mid_block.attentions.1"]
        with self.assertRaises(ValueError):
            pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
@@ -251,17 +251,17 @@ class StableDiffusionXLPAGPipelineFastTests(
        assert len(pipe.pag_attn_processors) == 4
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_0"]
+        pag_layers = ["down_blocks.0"]
        with self.assertRaises(ValueError):
            pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1"]
+        pag_layers = ["down_blocks.1"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert len(pipe.pag_attn_processors) == 4
        pipe.unet.set_attn_processor(original_attn_procs.copy())
-        pag_layers = ["down.block_1.attentions_1"]
+        pag_layers = ["down_blocks.1.attentions.1"]
        pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
        assert len(pipe.pag_attn_processors) == 2
...