Unverified Commit 7b98c4cc authored by Sayak Paul, committed by GitHub

[Core] Add PAG support for PixArtSigma (#8921)

* feat: add pixart sigma pag.

* inits.

* fixes

* fix

* remove print.

* copy paste methods to the pixart pag mixin

* fix-copies

* add documentation.

* add tests.

* remove correction file.

* remove pag_applied_layers

* empty
parent 27637a54
@@ -54,3 +54,9 @@ The abstract from the paper is:
[[autodoc]] StableDiffusionXLControlNetPAGPipeline
- all
- __call__
## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
- __call__
\ No newline at end of file
@@ -22,7 +22,7 @@ This guide will show you how to use PAG for various tasks and use cases.
You can apply PAG to the [`StableDiffusionXLPipeline`] for tasks such as text-to-image, image-to-image, and inpainting. To enable PAG for a specific task, load the pipeline using the [AutoPipeline](../api/pipelines/auto_pipeline) API with the `enable_pag=True` flag and the `pag_applied_layers` argument.
> [!TIP]
> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines and [`PixArtSigmaPAGPipeline`]. But feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
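For instance, a minimal text-to-image sketch using the new PixArt Sigma support might look like this (the checkpoint id and the `pag_applied_layers` value are illustrative assumptions, not part of this diff):

```py
import torch
from diffusers import AutoPipelineForText2Image

# enable_pag=True makes AutoPipeline pick the PAG variant of the pipeline,
# here PixArtSigmaPAGPipeline via the new "pixart-sigma-pag" mapping entry
pipeline = AutoPipelineForText2Image.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # assumed checkpoint id
    enable_pag=True,
    pag_applied_layers=[14],  # transformer block index to perturb (illustrative)
    torch_dtype=torch.float16,
).to("cuda")

image = pipeline(
    "an astronaut riding a horse on mars",
    guidance_scale=4.5,
    pag_scale=3.0,  # strength of the perturbed attention guidance term
).images[0]
```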
<hfoptions id="tasks"> <hfoptions id="tasks">
<hfoption id="Text-to-image"> <hfoption id="Text-to-image">
...
@@ -295,6 +295,7 @@ else:
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline", "PixArtSigmaPipeline",
"SemanticStableDiffusionPipeline", "SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline", "ShapEImg2ImgPipeline",
@@ -717,6 +718,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union
import torch
from torch import nn
@@ -19,6 +19,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..attention_processor import AttentionProcessor
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -186,6 +187,66 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
...
@@ -151,6 +151,7 @@ else:
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"PixArtSigmaPAGPipeline",
]
)
_import_structure["controlnet_xs"].extend(
@@ -531,6 +532,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
...
@@ -50,6 +50,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pag import (
PixArtSigmaPAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
@@ -98,6 +99,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGPipeline),
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline), ("auraflow", AuraFlowPipeline),
("kolors", KolorsPipeline), ("kolors", KolorsPipeline),
("flux", FluxPipeline), ("flux", FluxPipeline),
...
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
@@ -40,6 +41,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
...
@@ -275,3 +275,185 @@ class PAGMixin:
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
processors[name] = proc
return processors
class PixArtPAGMixin:
@staticmethod
def _check_input_pag_applied_layer(layer):
r"""
Check that each layer input in `pag_applied_layers` is a valid block index (an int or a string of an int).
"""
# Check if the layer index is valid (should be int or str of int)
if isinstance(layer, int):
return # Valid layer index
if isinstance(layer, str):
if layer.isdigit():
return # Valid layer index
# If it is not a valid layer index, raise a ValueError
raise ValueError(f"Pag layer should only contain block index. Accept number string like '3', got {layer}.")
def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
Set the attention processor for the PAG layers.
"""
if do_classifier_free_guidance:
pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
else:
pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()
def is_self_attn(module_name):
r"""
Check if the module is self-attention module based on its name.
"""
return (
"attn1" in module_name and len(module_name.split(".")) == 3
) # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ...
def get_block_index(module_name):
r"""
Get the block index from the module name, e.g. "transformer_blocks.23.attn1" -> "23".
"""
return module_name.split(".")[1]
for pag_layer_input in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the transformer model
target_modules = []
block_index = str(pag_layer_input)
for name, module in self.transformer.named_modules():
if is_self_attn(name) and get_block_index(name) == block_index:
target_modules.append(module)
if len(target_modules) == 0:
raise ValueError(f"Cannot find pag layer to set attention processor for: {pag_layer_input}")
for module in target_modules:
module.processor = pag_attn_proc
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers
def set_pag_applied_layers(self, pag_applied_layers):
r"""
Set the self-attention layers to apply PAG to. Raises a ValueError if the input is invalid.
"""
if not isinstance(pag_applied_layers, list):
pag_applied_layers = [pag_applied_layers]
for pag_layer in pag_applied_layers:
self._check_input_pag_applied_layer(pag_layer)
self.pag_applied_layers = pag_applied_layers
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale
def _get_pag_scale(self, t):
r"""
Get the scale factor for the perturbed attention guidance at timestep `t`.
"""
if self.do_pag_adaptive_scaling:
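# linearly anneal the PAG scale as t decreases; e.g. (illustrative numbers) with
# pag_scale=3.0 and pag_adaptive_scale=0.002, the scale is 3.0 at t=1000 and 1.0 at t=0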
signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
if signal_scale < 0:
signal_scale = 0
return signal_scale
else:
return self.pag_scale
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
r"""
Apply perturbed attention guidance to the noise prediction.
Args:
noise_pred (torch.Tensor): The noise prediction tensor.
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
guidance_scale (float): The scale factor for the guidance term.
t (int): The current time step.
Returns:
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
"""
pag_scale = self._get_pag_scale(t)
if do_classifier_free_guidance:
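# the batch was assembled as [uncond, text, perturbed] (see
# _prepare_perturbed_attention_guidance), so chunk in that order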
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
+ guidance_scale * (noise_pred_text - noise_pred_uncond)
+ pag_scale * (noise_pred_text - noise_pred_perturb)
)
else:
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
return noise_pred
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
"""
Prepares the perturbed attention guidance for the PAG model.
Args:
cond (torch.Tensor): The conditional input tensor.
uncond (torch.Tensor): The unconditional input tensor.
do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.
Returns:
torch.Tensor: The prepared perturbed attention guidance tensor.
"""
cond = torch.cat([cond] * 2, dim=0)
if do_classifier_free_guidance:
cond = torch.cat([uncond, cond], dim=0)
return cond
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale
def pag_scale(self):
"""
Get the scale factor for the perturbed attention guidance.
"""
return self._pag_scale
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale
def pag_adaptive_scale(self):
"""
Get the adaptive scale factor for the perturbed attention guidance.
"""
return self._pag_adaptive_scale
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling
def do_pag_adaptive_scaling(self):
"""
Check if the adaptive scaling is enabled for the perturbed attention guidance.
"""
return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance
def do_perturbed_attention_guidance(self):
"""
Check if the perturbed attention guidance is enabled.
"""
return self._pag_scale > 0 and len(self.pag_applied_layers) > 0
@property
# Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer
def pag_attn_processors(self):
r"""
Returns:
`dict` of PAG attention processors: A dictionary contains all PAG attention processors used in the model
with the key as the name of the layer.
"""
processors = {}
for name, proc in self.transformer.attn_processors.items():
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
processors[name] = proc
return processors
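As a quick illustration of how the mixin resolves block indices to modules, here is a sketch mirroring the new `test_pag_applied_layers` test below (the pipeline instance `pipe` and the chosen index are assumed for illustration):

```py
# with the 2-layer dummy transformer from the tests, block index 1 resolves to
# the self-attention module "transformer_blocks.1.attn1"
pipe._set_pag_attn_processor(pag_applied_layers=[1], do_classifier_free_guidance=False)
print(set(pipe.pag_attn_processors))
# expected: {"transformer_blocks.1.attn1.processor"}
```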
This diff is collapsed.
@@ -977,6 +977,21 @@ class PixArtAlphaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PixArtSigmaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class PixArtSigmaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
import diffusers
from diffusers import (
AutoencoderKL,
DDIMScheduler,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
PixArtTransformer2DModel,
)
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
CaptureLogger,
enable_full_determinism,
print_tensor_test,
torch_device,
)
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference, to_np
enable_full_determinism()
class PixArtSigmaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PixArtSigmaPAGPipeline
params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
params = set(params)
params.remove("cross_attention_kwargs")
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = PipelineTesterMixin.required_optional_params
def get_dummy_components(self):
torch.manual_seed(0)
transformer = PixArtTransformer2DModel(
sample_size=8,
num_layers=2,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
caption_channels=32,
in_channels=4,
cross_attention_dim=24,
out_channels=8,
attention_bias=True,
activation_fn="gelu-approximate",
num_embeds_ada_norm=1000,
norm_type="ada_norm_single",
norm_elementwise_affine=False,
norm_eps=1e-6,
)
torch.manual_seed(0)
vae = AutoencoderKL()
scheduler = DDIMScheduler()
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"pag_scale": 3.0,
"use_resolution_binning": False,
"output_type": "np",
}
return inputs
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline (expect same output when pag is disabled)
pipe = PixArtSigmaPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
assert (
"pag_scale" not in inspect.signature(pipe.__call__).parameters
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
out = pipe(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
components["pag_applied_layers"] = [1]
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 0.0
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
# pag enabled
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
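# PAG disabled (pag_scale=0.0) should match the base pipeline's output,
# while PAG enabled should measurably change it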
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
def test_pag_applied_layers(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
# "attn1" should apply to all self-attention layers.
all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn1" in k]
pag_layers = [0, 1]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
def test_pag_inference(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe_pag(**inputs).images
image_slice = image[0, -3:, -3:, -1]
print_tensor_test(image_slice)
assert image.shape == (
1,
8,
8,
3,
), f"the shape of the output image should be (1, 8, 8, 3) but got {image.shape}"
expected_slice = np.array([0.6499, 0.3250, 0.3572, 0.6780, 0.4453, 0.4582, 0.2770, 0.5168, 0.4594])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
# Copied from tests.pipelines.pixart_sigma.test_pixart.PixArtSigmaPipelineFastTests.test_save_load_optional_components
def test_save_load_optional_components(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs["prompt"]
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = pipe.encode_prompt(prompt)
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
# set all optional components to None
for optional_component in pipe._optional_components:
setattr(pipe, optional_component, None)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1])
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for optional_component in pipe._optional_components:
self.assertTrue(
getattr(pipe_loaded, optional_component) is None,
f"`{optional_component}` did not stay set to None after loading.",
)
inputs = self.get_dummy_inputs(torch_device)
generator = inputs["generator"]
num_inference_steps = inputs["num_inference_steps"]
output_type = inputs["output_type"]
# inputs with prompt converted to embeddings
inputs = {
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"negative_prompt": None,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_prompt_attention_mask": negative_prompt_attention_mask,
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
"use_resolution_binning": False,
}
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, 1e-4)
# Overridden because the PAG PixArt Sigma pipeline takes `pag_applied_layers`,
# and we shouldn't call `set_default_attn_processor()` after loading
# the pipeline with `pag_applied_layers`.
def test_save_load_local(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
with CaptureLogger(logger) as cap_logger:
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, pag_applied_layers=[1])
for name in pipe_loaded.components.keys():
if name not in pipe_loaded._optional_components:
assert name in str(cap_logger)
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
self.assertLess(max_diff, expected_max_difference)
# We shouldn't be calling `set_default_attn_processor` here.
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
if test_mean_pixel_difference:
assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
# Because we have `pag_applied_layers` we cannot directly apply
# `set_default_attn_processor`
def test_dict_tuple_outputs_equivalent(self, expected_slice=None, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
if expected_slice is None:
output = pipe(**self.get_dummy_inputs(generator_device))[0]
else:
output = expected_slice
output_tuple = pipe(**self.get_dummy_inputs(generator_device), return_dict=False)[0]
if expected_slice is None:
max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
else:
if output_tuple.ndim != 5:
max_diff = np.abs(to_np(output) - to_np(output_tuple)[0, -3:, -3:, -1].flatten()).max()
else:
max_diff = np.abs(to_np(output) - to_np(output_tuple)[0, -3:, -3:, -1, -1].flatten()).max()
self.assertLess(max_diff, expected_max_difference)
# Same reason as above
def test_inference_batch_single_identical(
self,
batch_size=2,
expected_max_diff=1e-4,
additional_params_copy_to_batched_inputs=["num_inference_steps"],
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batched_inputs.update(inputs)
for name in self.batch_params:
if name not in inputs:
continue
value = inputs[name]
if name == "prompt":
len_prompt = len(value)
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_inputs[name][-1] = 100 * "very long"
else:
batched_inputs[name] = batch_size * [value]
if "generator" in inputs:
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
if "batch_size" in inputs:
batched_inputs["batch_size"] = batch_size
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
output = pipe(**inputs)
output_batch = pipe(**batched_inputs)
assert output_batch[0].shape[0] == batch_size
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
# Because we're also passing `pag_applied_layers` (a `List`) in the components.
def test_components_function(self):
init_components = self.get_dummy_components()
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float, list))}
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))