Unverified Commit 50d21f7c authored by Sayak Paul, committed by GitHub

[Core] fix QKV fusion for attention (#8829)

* start debugging the problem,

* start

* fix

* fix

* fix imports.

* handle hunyuan

* remove residuals.

* add a check for making sure there's appropriate procs.

* add more rigor to the tests.

* fix test

* remove redundant check

* fix-copies

* move check_qkv_fusion_matches_attn_procs_length and check_qkv_fusion_processors_exist.
parent 3bb1fd6f
@@ -677,6 +677,21 @@ class Attention(nn.Module):
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                self.to_kv.bias.copy_(concatenated_bias)

        # handle added projections for SD3 and others.
        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
            self.to_added_qkv.weight.copy_(concatenated_weights)
            concatenated_bias = torch.cat(
                [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
            )
            self.to_added_qkv.bias.copy_(concatenated_bias)

        self.fused_projections = fuse
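For context on why this concatenation is safe: nn.Linear computes x @ W.T + b, so stacking the Q/K/V weight matrices (and biases) along the output dimension and splitting the fused output reproduces the three separate projections exactly. A standalone sketch of that equivalence (illustrative shapes, not diffusers code):

import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# Stack the three weight matrices (and biases) along the output dimension.
fused = nn.Linear(dim, 3 * dim, bias=True)
with torch.no_grad():
    fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    fused.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

x = torch.randn(2, 16, dim)
q, k, v = torch.split(fused(x), dim, dim=-1)
assert torch.allclose(q, to_q(x)) and torch.allclose(k, to_k(x)) and torch.allclose(v, to_v(x))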
@@ -1708,6 +1723,109 @@ class HunyuanAttnProcessor2_0:
        return hidden_states

class FusedHunyuanAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
    projection layers. This is used in the HunyuanDiT model. It applies normalization layers and rotary embeddings to
    the query and key vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            qkv = attn.to_qkv(hidden_states)
            split_size = qkv.shape[-1] // 3
            query, key, value = torch.split(qkv, split_size, dim=-1)
        else:
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
            query = attn.to_q(hidden_states)

            kv = attn.to_kv(encoder_hidden_states)
            split_size = kv.shape[-1] // 2
            key, value = torch.split(kv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
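This processor is not meant to be instantiated by hand; HunyuanDiT2DModel.fuse_qkv_projections() (patched below) installs it. A hedged usage sketch, assuming the published HunyuanDiT Diffusers checkpoint and a CUDA device:

import torch
from diffusers import HunyuanDiTPipeline

# Checkpoint id is illustrative; any HunyuanDiT Diffusers checkpoint should work.
pipe = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")

# Fuses to_q/to_k/to_v into to_qkv and installs FusedHunyuanAttnProcessor2_0.
pipe.transformer.fuse_qkv_projections()
image = pipe(prompt="a photo of an astronaut riding a horse").images[0]

# Restore the separate projections and the default processor when done.
pipe.transformer.unfuse_qkv_projections()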
class LuminaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
...
@@ -26,6 +26,7 @@ from ..attention_processor import (
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
@@ -492,6 +493,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
@@ -22,7 +22,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
from ..models.attention import JointTransformerBlock
from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ..models.modeling_outputs import Transformer2DModelOutput
from ..models.modeling_utils import ModelMixin
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -196,7 +196,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -220,6 +220,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
@@ -29,6 +29,7 @@ from .attention_processor import (
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from .controlnet import ControlNetConditioningEmbedding
from .embeddings import TimestepEmbedding, Timesteps
@@ -1001,6 +1002,8 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
@@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
from ..embeddings import (
    HunyuanCombinedTimestepTextSizeStyleEmbedding,
    PatchEmbed,
@@ -317,7 +317,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -341,6 +341,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedHunyuanAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
@@ -23,7 +23,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import JointTransformerBlock
from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -211,7 +211,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -235,6 +235,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
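The recurring one-line addition in these hunks is the substance of the fix: module.fuse_projections(fuse=True) only rewires the Linear layers, so without a matching set_attn_processor(...) call the previous processors would keep reading to_q/to_k/to_v and the fused to_qkv would never be used. A hedged sketch of the fixed behavior on the SD3 transformer (checkpoint id is illustrative):

import torch
from diffusers import SD3Transformer2DModel
from diffusers.models.attention_processor import FusedJointAttnProcessor2_0

# Checkpoint id is illustrative; any SD3 checkpoint with a "transformer" subfolder should work.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch.float16
)

transformer.fuse_qkv_projections()
# After the fix, every attention processor is the fused joint variant.
assert all(isinstance(proc, FusedJointAttnProcessor2_0) for proc in transformer.attn_processors.values())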
@@ -30,6 +30,7 @@ from ..attention_processor import (
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..embeddings import (
    GaussianFourierProjection,
@@ -890,6 +891,8 @@ class UNet2DConditionModel(
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
@@ -31,6 +31,7 @@ from ..attention_processor import (
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
@@ -532,6 +533,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
@@ -29,6 +29,7 @@ from ..attention_processor import (
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
    FusedAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
@@ -498,6 +499,8 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
@@ -29,6 +29,7 @@ from ..attention_processor import (
    AttnAddedKVProcessor,
    AttnProcessor,
    AttnProcessor2_0,
    FusedAttnProcessor2_0,
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)
@@ -929,6 +930,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
...
@@ -36,7 +36,12 @@ from diffusers.utils.testing_utils import (
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
    to_np,
)

enable_full_determinism()
@@ -261,6 +266,16 @@ class HunyuanDiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(
            pipe.transformer
        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        inputs["return_dict"] = False
        image_fused = pipe(**inputs)[0]
...
@@ -13,7 +13,11 @@ from diffusers.utils.testing_utils import (
    torch_device,
)

from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)


class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -191,7 +195,16 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(
            pipe.transformer
        ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]
...
@@ -13,6 +13,7 @@ from typing import Any, Callable, Dict, Union

import numpy as np
import PIL.Image
import torch
import torch.nn as nn
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -40,7 +41,12 @@ from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
from diffusers.utils.testing_utils import (
    CaptureLogger,
    require_torch,
    skip_mps,
    torch_device,
)

from ..models.autoencoders.test_models_vae import (
    get_asym_autoencoder_kl_config,
@@ -67,6 +73,17 @@ def check_same_shape(tensor_list):
    return all(shape == shapes[0] for shape in shapes[1:])


def check_qkv_fusion_matches_attn_procs_length(model, original_attn_processors):
    current_attn_processors = model.attn_processors
    return len(current_attn_processors) == len(original_attn_processors)


def check_qkv_fusion_processors_exist(model):
    current_attn_processors = model.attn_processors
    proc_names = [v.__class__.__name__ for _, v in current_attn_processors.items()]
    return all(p.startswith("Fused") for p in proc_names)
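Both helpers are intentionally lightweight: one checks that fusion neither added nor dropped processors, the other that every processor class name starts with "Fused". A minimal sketch of how they compose, with the helpers above in scope (the tiny UNet config mirrors the library's own test fixtures; requires PyTorch >= 2.0):

import torch
from diffusers import UNet2DConditionModel

# A deliberately tiny UNet so the sketch runs quickly on CPU.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

original_attn_processors = unet.attn_processors
unet.fuse_qkv_projections()  # installs FusedAttnProcessor2_0

assert check_qkv_fusion_processors_exist(unet)
assert check_qkv_fusion_matches_attn_procs_length(unet, original_attn_processors)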
class SDFunctionTesterMixin:
    """
    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
@@ -196,6 +213,19 @@ class SDFunctionTesterMixin:
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.fuse_qkv_projections()
        for _, component in pipe.components.items():
            if (
                isinstance(component, nn.Module)
                and hasattr(component, "original_attn_processors")
                and component.original_attn_processors is not None
            ):
                assert check_qkv_fusion_processors_exist(
                    component
                ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
                assert check_qkv_fusion_matches_attn_procs_length(
                    component, component.original_attn_processors
                ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        inputs["return_dict"] = False
        image_fused = pipe(**inputs)[0]
...