Unverified Commit a2bc2e14 authored by Sayak Paul, committed by GitHub

[feat] allow SDXL pipeline to run with fused QKV projections (#6030)



* debug

* from step

* print

* turn sigma a list

* make str

* init_noise_sigma

* comment

* remove prints

* feat: introduce fused projections

* change to a better name

* no grad

* device.

* device

* dtype

* okay

* print

* more print

* fix: unbind -> split

* fix: qkv >-> k

* enable disable

* apply attention processor within the method

* attn processors

* _enable_fused_qkv_projections

* remove print

* add fused projection to vae

* add todos.

* add: documentation and cleanups.

* add: test for qkv projection fusion.

* relax assertions.

* relax further

* fix: docs

* fix-copies

* correct error message.

* Empty-Commit

* better conditioning on disable_fused_qkv_projections

* check

* check processor

* bfloat16 computation.

* check latent dtype

* style

* remove copy temporarily

* cast latent to bfloat16

* fix: vae -> self.vae

* remove print.

* add _change_to_group_norm_32

* comment out stuff that didn't work

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* reflect patrick's suggestions.

* fix imports

* fix: disable call.

* fix more

* fix device and dtype

* fix conditions.

* fix more

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent f427345a
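As a quick illustration of the feature this commit adds, here is a minimal usage sketch of the pipeline-level API. The checkpoint id, prompt, dtype, and device are illustrative assumptions, not taken from the diff below.

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Illustrative checkpoint; any SDXL checkpoint should work the same way.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Fuse the QKV projections in the UNet and VAE; internally this also swaps
# their attention processors to FusedAttnProcessor2_0.
pipe.fuse_qkv_projections()

image = pipe("an astronaut riding a horse on mars", num_inference_steps=30).images[0]

# Restore the original, unfused attention processors.
pipe.unfuse_qkv_projections()
```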
@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech
## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0
## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor
......
@@ -113,12 +113,14 @@ class Attention(nn.Module):
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
@@ -180,6 +182,7 @@ class Attention(nn.Module):
else:
linear_cls = LoRACompatibleLinear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
@@ -692,6 +695,32 @@ class Attention(nn.Module):
return encoder_hidden_states
@torch.no_grad()
def fuse_projections(self, fuse=True):
is_cross_attention = self.cross_attention_dim != self.query_dim
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
self.fused_projections = fuse
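To see why the weight concatenation in `fuse_projections` is equivalent to the separate projections, here is a small standalone check with plain `nn.Linear` layers. This is a sketch of the idea only, not the `Attention` class itself; like the code above, the fused layer drops biases (`bias=False`).

```python
import torch
import torch.nn as nn

dim = 64
to_q = nn.Linear(dim, dim, bias=False)
to_k = nn.Linear(dim, dim, bias=False)
to_v = nn.Linear(dim, dim, bias=False)

# Fuse: stack the three weight matrices along the output dimension.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))

x = torch.randn(2, 77, dim)
qkv = to_qkv(x)
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

# One matmul instead of three, identical results.
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```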
class AttnProcessor:
r"""
@@ -1184,9 +1213,6 @@ class AttnProcessor2_0:
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1253,6 +1279,103 @@ class AttnProcessor2_0:
return hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is currently 🧪 experimental and can change in the future.
</Tip>
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states, *args)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states, *args)
kv = attn.to_kv(encoder_hidden_states, *args)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
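At the model level, `FusedAttnProcessor2_0` is meant to be used together with `fuse_qkv_projections()`, which is what the pipeline does further down in this diff. A minimal sketch (the checkpoint id and subfolder are illustrative):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import FusedAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

# First fuse the projection weights, then swap in the processor that expects
# `to_qkv` / `to_kv` to exist on each Attention module.
unet.fuse_qkv_projections()
unet.set_attn_processor(FusedAttnProcessor2_0())
```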
class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -2251,6 +2374,7 @@ CROSS_ATTENTION_PROCESSORS = (
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,
......
@@ -22,6 +22,7 @@ from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -448,3 +449,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return (dec,)
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@@ -25,6 +25,7 @@ from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -794,6 +795,42 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: torch.FloatTensor,
......
@@ -34,6 +34,7 @@ from ...loaders import (
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -681,7 +682,6 @@ class StableDiffusionXLPipeline(
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
@@ -692,6 +692,7 @@ class StableDiffusionXLPipeline(
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
@@ -729,6 +730,65 @@ class StableDiffusionXLPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
......
@@ -24,6 +24,7 @@ from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, Te
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -610,6 +611,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
......
@@ -10,10 +10,10 @@ from diffusers.utils import deprecate
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.activations import get_activation
from ...models.attention import Attention
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
@@ -1000,6 +1000,42 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: torch.FloatTensor,
......
@@ -191,10 +191,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return max_sigma
return (self.sigmas.max() ** 2 + 1) ** 0.5
return (max_sigma**2 + 1) ** 0.5
@property
def step_index(self):
@@ -289,6 +290,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if sigmas.device.type == "cuda":
self.sigmas = self.sigmas.tolist()
self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas):
......
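The scheduler change above converts `sigmas` to a plain Python list when they live on CUDA, so `init_noise_sigma` now has to handle both containers. A minimal standalone sketch of the new branch, with illustrative values:

```python
import torch

def init_noise_sigma(sigmas, timestep_spacing="linspace"):
    # Works for both a Python list (the new CUDA path) and a torch tensor.
    max_sigma = max(sigmas) if isinstance(sigmas, list) else sigmas.max()
    if timestep_spacing in ["linspace", "trailing"]:
        return max_sigma
    return (max_sigma**2 + 1) ** 0.5

print(init_noise_sigma([14.6, 3.2, 0.0]))                # list path
print(init_noise_sigma(torch.tensor([14.6, 3.2, 0.0])))  # tensor path
```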
@@ -938,6 +938,37 @@ class StableDiffusionXLPipelineFastTests(
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_stable_diffusion_xl_with_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
sd_pipe.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
sd_pipe.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
......