Unverified Commit 926daa30 authored by Ahn Donghoon (안동훈 / suno), committed by GitHub

add PAG support for Stable Diffusion 3 (#8861)



add pag sd3


---------
Co-authored-by: HyoungwonCho <jhw9811@korea.ac.kr>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: crepejung00 <jaewoojung00@naver.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
parent 325a5de3
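With this change, perturbed-attention guidance (PAG) for SD3 becomes available through StableDiffusion3PAGPipeline and the "stable-diffusion-3-pag" auto-pipeline mapping added below. A minimal usage sketch, assuming the SD3 medium checkpoint and an illustrative layer selection (the exact call arguments follow the existing PAG docs, not this diff):

import torch
from diffusers import AutoPipelineForText2Image

# enable_pag=True routes the SD3 checkpoint to StableDiffusion3PAGPipeline
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    enable_pag=True,
    pag_applied_layers=["blocks.13"],  # illustrative; RegEx layer identifiers are supported
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    guidance_scale=5.0,
    pag_scale=3.0,  # pag_scale=0.0 effectively disables the perturbed branch
).images[0]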
...@@ -74,6 +74,12 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- __call__
## StableDiffusion3PAGPipeline
[[autodoc]] StableDiffusion3PAGPipeline
- all
- __call__
## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
...
...@@ -308,6 +308,7 @@ else:
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline", "StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline", "StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline", "StableDiffusionAttendAndExcitePipeline",
...@@ -741,6 +742,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -741,6 +742,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusion3ControlNetPipeline, StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline, StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline, StableDiffusion3InpaintPipeline,
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
StableDiffusionAdapterPipeline,
StableDiffusionAttendAndExcitePipeline,
...
...@@ -1106,6 +1106,326 @@ class JointAttnProcessor2_0:
return hidden_states, encoder_hidden_states
class PAGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# store the length of image patch sequences to create a mask that prevents interaction between patches
# similar to making the self-attention map an identity matrix
identity_block_size = hidden_states.shape[1]
# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
################## original path ##################
batch_size = encoder_hidden_states_org.shape[0]
# `sample` projections.
query_org = attn.to_q(hidden_states_org)
key_org = attn.to_k(hidden_states_org)
value_org = attn.to_v(hidden_states_org)
# `context` projections.
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
# attention
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
inner_dim = key_org.shape[-1]
head_dim = inner_dim // attn.heads
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states_org = F.scaled_dot_product_attention(
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query_org.dtype)
# Split the attention outputs.
hidden_states_org, encoder_hidden_states_org = (
hidden_states_org[:, : residual.shape[1]],
hidden_states_org[:, residual.shape[1] :],
)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if not attn.context_pre_only:
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
################## perturbed path ##################
batch_size = encoder_hidden_states_ptb.shape[0]
# `sample` projections.
query_ptb = attn.to_q(hidden_states_ptb)
key_ptb = attn.to_k(hidden_states_ptb)
value_ptb = attn.to_v(hidden_states_ptb)
# `context` projections.
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
# attention
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
inner_dim = key_ptb.shape[-1]
head_dim = inner_dim // attn.heads
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# create a full mask with all entries set to 0
seq_len = query_ptb.size(2)
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
# set the attention value between image patches to -inf
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
# set the diagonal of the attention value between image patches to 0
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
# expand the mask to match the attention weights shape
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
hidden_states_ptb = F.scaled_dot_product_attention(
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
)
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
# split the attention outputs.
hidden_states_ptb, encoder_hidden_states_ptb = (
hidden_states_ptb[:, : residual.shape[1]],
hidden_states_ptb[:, residual.shape[1] :],
)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if not attn.context_pre_only:
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
################ concat ###############
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
return hidden_states, encoder_hidden_states
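In the perturbed path above, the additive mask turns image-to-image self-attention into an identity map (each patch attends only to itself) while text tokens remain free to attend everywhere. A standalone sketch of the same mask construction with made-up sizes, for illustration only:

import torch

# illustrative sizes: 4 image-patch tokens followed by 2 text tokens
identity_block_size, num_text_tokens = 4, 2
seq_len = identity_block_size + num_text_tokens

# additive mask: 0 means "attend normally", -inf means "blocked"
full_mask = torch.zeros((seq_len, seq_len))
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

print(full_mask)
# rows 0-3 (image patches): -inf over the image columns except the diagonal, 0 over the text columns
# rows 4-5 (text tokens): all zeros, i.e. unrestricted attention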
class PAGCFGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
identity_block_size = hidden_states.shape[1]  # patch embeddings width * height (corresponds to the self-attention map width or height)
# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
(
encoder_hidden_states_uncond,
encoder_hidden_states_org,
encoder_hidden_states_ptb,
) = encoder_hidden_states.chunk(3)
encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
################## original path ##################
batch_size = encoder_hidden_states_org.shape[0]
# `sample` projections.
query_org = attn.to_q(hidden_states_org)
key_org = attn.to_k(hidden_states_org)
value_org = attn.to_v(hidden_states_org)
# `context` projections.
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
# attention
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
inner_dim = key_org.shape[-1]
head_dim = inner_dim // attn.heads
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states_org = F.scaled_dot_product_attention(
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query_org.dtype)
# Split the attention outputs.
hidden_states_org, encoder_hidden_states_org = (
hidden_states_org[:, : residual.shape[1]],
hidden_states_org[:, residual.shape[1] :],
)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if not attn.context_pre_only:
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
################## perturbed path ##################
batch_size = encoder_hidden_states_ptb.shape[0]
# `sample` projections.
query_ptb = attn.to_q(hidden_states_ptb)
key_ptb = attn.to_k(hidden_states_ptb)
value_ptb = attn.to_v(hidden_states_ptb)
# `context` projections.
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
# attention
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
inner_dim = key_ptb.shape[-1]
head_dim = inner_dim // attn.heads
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# create a full mask with all entries set to 0
seq_len = query_ptb.size(2)
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
# set the attention value between image patches to -inf
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
# set the diagonal of the attention value between image patches to 0
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
# expand the mask to match the attention weights shape
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
hidden_states_ptb = F.scaled_dot_product_attention(
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
)
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
# split the attention outputs.
hidden_states_ptb, encoder_hidden_states_ptb = (
hidden_states_ptb[:, : residual.shape[1]],
hidden_states_ptb[:, residual.shape[1] :],
)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if not attn.context_pre_only:
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
################ concat ###############
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
return hidden_states, encoder_hidden_states
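PAGCFGJointAttnProcessor2_0 expects the batch to be chunked into (unconditional, conditional, perturbed) branches. A hedged sketch of how a PAG pipeline typically recombines the resulting noise predictions; the names are illustrative and this is not the exact pipeline code:

import torch

def combine_cfg_pag(noise_pred: torch.Tensor, guidance_scale: float, pag_scale: float) -> torch.Tensor:
    # noise_pred stacks the three branches along the batch dimension, in the same
    # order as the chunk(3) above: unconditional, conditional (text), perturbed
    noise_uncond, noise_text, noise_perturb = noise_pred.chunk(3)
    return (
        noise_uncond
        + guidance_scale * (noise_text - noise_uncond)  # classifier-free guidance
        + pag_scale * (noise_text - noise_perturb)      # perturbed-attention guidance
    )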
class FusedJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
...
...@@ -147,6 +147,7 @@ else:
[
"AnimateDiffPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusionPAGPipeline", "StableDiffusionPAGPipeline",
"StableDiffusionControlNetPAGPipeline", "StableDiffusionControlNetPAGPipeline",
"StableDiffusionXLPAGPipeline", "StableDiffusionXLPAGPipeline",
...@@ -540,6 +541,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -540,6 +541,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AnimateDiffPAGPipeline, AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline, HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline, PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
...
...@@ -52,6 +52,7 @@ from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, La
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGPipeline,
StableDiffusionPAGPipeline,
StableDiffusionXLControlNetPAGPipeline,
...@@ -84,6 +85,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion", StableDiffusionPipeline),
("stable-diffusion-xl", StableDiffusionXLPipeline),
("stable-diffusion-3", StableDiffusion3Pipeline),
("stable-diffusion-3-pag", StableDiffusion3PAGPipeline),
("if", IFPipeline), ("if", IFPipeline),
("hunyuan", HunyuanDiTPipeline), ("hunyuan", HunyuanDiTPipeline),
("hunyuan-pag", HunyuanDiTPAGPipeline), ("hunyuan-pag", HunyuanDiTPAGPipeline),
......
...@@ -27,6 +27,7 @@ else: ...@@ -27,6 +27,7 @@ else:
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"] _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"] _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"] _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"] _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
...@@ -45,6 +46,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -45,6 +46,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
...
...@@ -1127,6 +1127,21 @@ class StableDiffusion3InpaintPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusion3PAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusion3Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
import inspect
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
SD3Transformer2DModel,
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
)
from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = StableDiffusion3PAGPipeline
params = frozenset(
[
"prompt",
"height",
"width",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
def get_dummy_components(self):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
patch_size=1,
in_channels=4,
num_layers=2,
attention_head_dim=8,
num_attention_heads=4,
caption_projection_dim=32,
joint_attention_dim=32,
pooled_projection_dim=64,
out_channels=4,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=4,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"text_encoder_3": text_encoder_3,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"pag_scale": 0.0,
}
return inputs
def test_stable_diffusion_3_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
inputs["prompt_3"] = "another different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
assert max_diff > 1e-2
def test_stable_diffusion_3_different_negative_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt_2"] = "deformed"
inputs["negative_prompt_3"] = "blurry"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
assert max_diff > 1e-2
def test_stable_diffusion_3_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
do_classifier_free_guidance = inputs["guidance_scale"] > 1
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
prompt,
prompt_2=None,
prompt_3=None,
do_classifier_free_guidance=do_classifier_free_guidance,
device=torch_device,
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(
pipe.transformer
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline (expect same output when pag is disabled)
pipe_sd = StableDiffusion3Pipeline(**components)
pipe_sd = pipe_sd.to(device)
pipe_sd.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
assert (
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
# pag disabled with pag_scale=0.0
pipe_pag = self.pipeline_class(**components)
pipe_pag = pipe_pag.to(device)
pipe_pag.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["pag_scale"] = 0.0
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
def test_pag_applied_layers(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
# base pipeline
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
all_self_attn_layers = [k for k in pipe.transformer.attn_processors.keys() if "attn" in k]
original_attn_procs = pipe.transformer.attn_processors
pag_layers = ["blocks.0", "blocks.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(all_self_attn_layers)
# blocks.0
block_0_self_attn = ["transformer_blocks.0.attn.processor"]
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0.attn"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert set(pipe.pag_attn_processors) == set(block_0_self_attn)
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.(0|1)"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 2
pipe.transformer.set_attn_processor(original_attn_procs.copy())
pag_layers = ["blocks.0", r"blocks\.1"]
pipe._set_pag_attn_processor(pag_applied_layers=pag_layers, do_classifier_free_guidance=False)
assert len(pipe.pag_attn_processors) == 2
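test_pag_applied_layers exercises the RegEx-based matching of layer identifiers such as "blocks.0" or "blocks.(0|1)" against attention-processor names. A small standalone illustration of that matching idea (not the library's implementation):

import re

processor_names = [
    "transformer_blocks.0.attn.processor",
    "transformer_blocks.1.attn.processor",
]

def select_pag_processors(pag_layers, names):
    # a layer identifier selects every processor whose name it matches as a RegEx
    return {name for name in names if any(re.search(pattern, name) for pattern in pag_layers)}

assert select_pag_processors(["blocks.0"], processor_names) == {"transformer_blocks.0.attn.processor"}
assert len(select_pag_processors(["blocks.(0|1)"], processor_names)) == 2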