Unverified Commit 2261510b authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] Add AuraFlow (#8796)



* add lavender flow transformer

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 87b9db64
import argparse
import torch
from huggingface_hub import hf_hub_download
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel
def load_original_state_dict(args):
model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
state_dict = torch.load(model_pt, map_location="cpu")
return state_dict
def calculate_layers(state_dict_keys, key_prefix):
dit_layers = set()
for k in state_dict_keys:
if key_prefix in k:
dit_layers.add(int(k.split(".")[2]))
print(f"{key_prefix}: {len(dit_layers)}")
return len(dit_layers)
# similar to SD3 but only for the last norm layer
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_transformer(state_dict):
converted_state_dict = {}
state_dict_keys = list(state_dict.keys())
converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")
converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")
converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
# MMDiT blocks 🎸.
for i in range(mmdit_layers):
# feed-forward
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for orig_k, diffuser_k in path_mapping.items():
for k, v in weight_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
f"model.double_layers.{i}.{orig_k}.{k}.weight"
)
# norms
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
for orig_k, diffuser_k in path_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
f"model.double_layers.{i}.{orig_k}.1.weight"
)
# attns
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
for k, v in attn_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
f"model.double_layers.{i}.attn.{k}.weight"
)
# Single-DiT blocks.
for i in range(single_dit_layers):
# feed-forward
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for k, v in mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
f"model.single_layers.{i}.mlp.{k}.weight"
)
# norms
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
f"model.single_layers.{i}.modCX.1.weight"
)
# attns
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
for k, v in x_attn_mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
f"model.single_layers.{i}.attn.{k}.weight"
)
# Final blocks.
converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)
return converted_state_dict
@torch.no_grad()
def populate_state_dict(args):
original_state_dict = load_original_state_dict(args)
state_dict_keys = list(original_state_dict.keys())
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
converted_state_dict = convert_transformer(original_state_dict)
model_diffusers = AuraFlowTransformer2DModel(
num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
)
model_diffusers.load_state_dict(converted_state_dict, strict=True)
return model_diffusers
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
parser.add_argument("--dump_path", default="aura-flow", type=str)
parser.add_argument("--hub_id", default=None, type=str)
args = parser.parse_args()
model_diffusers = populate_state_dict(args)
model_diffusers.save_pretrained(args.dump_path)
if args.hub_id is not None:
model_diffusers.push_to_hub(args.hub_id)
...@@ -76,6 +76,7 @@ else: ...@@ -76,6 +76,7 @@ else:
_import_structure["models"].extend( _import_structure["models"].extend(
[ [
"AsymmetricAutoencoderKL", "AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderKL", "AutoencoderKL",
"AutoencoderKLTemporalDecoder", "AutoencoderKLTemporalDecoder",
"AutoencoderTiny", "AutoencoderTiny",
...@@ -235,6 +236,7 @@ else: ...@@ -235,6 +236,7 @@ else:
"AudioLDM2ProjectionModel", "AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel", "AudioLDM2UNet2DConditionModel",
"AudioLDMPipeline", "AudioLDMPipeline",
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline", "BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline", "BlipDiffusionPipeline",
"ChatGLMModel", "ChatGLMModel",
...@@ -507,6 +509,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -507,6 +509,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else: else:
from .models import ( from .models import (
AsymmetricAutoencoderKL, AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderKL, AutoencoderKL,
AutoencoderKLTemporalDecoder, AutoencoderKLTemporalDecoder,
AutoencoderTiny, AutoencoderTiny,
...@@ -646,6 +649,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -646,6 +649,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2ProjectionModel, AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel, AudioLDM2UNet2DConditionModel,
AudioLDMPipeline, AudioLDMPipeline,
AuraFlowPipeline,
ChatGLMModel, ChatGLMModel,
ChatGLMTokenizer, ChatGLMTokenizer,
CLIPImageProjection, CLIPImageProjection,
......
...@@ -38,6 +38,7 @@ if is_torch_available(): ...@@ -38,6 +38,7 @@ if is_torch_available():
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["embeddings"] = ["ImageProjection"] _import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"] _import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
...@@ -84,6 +85,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -84,6 +85,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .embeddings import ImageProjection from .embeddings import ImageProjection
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .transformers import ( from .transformers import (
AuraFlowTransformer2DModel,
DiTTransformer2DModel, DiTTransformer2DModel,
DualTransformer2DModel, DualTransformer2DModel,
HunyuanDiT2DModel, HunyuanDiT2DModel,
......
...@@ -22,7 +22,7 @@ from torch import nn ...@@ -22,7 +22,7 @@ from torch import nn
from ..image_processor import IPAdapterMaskProcessor from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
...@@ -104,6 +104,7 @@ class Attention(nn.Module): ...@@ -104,6 +104,7 @@ class Attention(nn.Module):
cross_attention_norm_num_groups: int = 32, cross_attention_norm_num_groups: int = 32,
qk_norm: Optional[str] = None, qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None, added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
norm_num_groups: Optional[int] = None, norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None, spatial_norm_dim: Optional[int] = None,
out_bias: bool = True, out_bias: bool = True,
...@@ -118,6 +119,10 @@ class Attention(nn.Module): ...@@ -118,6 +119,10 @@ class Attention(nn.Module):
context_pre_only=None, context_pre_only=None,
): ):
super().__init__() super().__init__()
# To prevent circular import.
from .normalization import FP32LayerNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim self.query_dim = query_dim
...@@ -170,6 +175,9 @@ class Attention(nn.Module): ...@@ -170,6 +175,9 @@ class Attention(nn.Module):
elif qk_norm == "layer_norm": elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps) self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps) self.norm_k = nn.LayerNorm(dim_head, eps=eps)
elif qk_norm == "fp32_layer_norm":
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "layer_norm_across_heads": elif qk_norm == "layer_norm_across_heads":
# Lumina applys qk norm across all heads # Lumina applys qk norm across all heads
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
...@@ -211,10 +219,10 @@ class Attention(nn.Module): ...@@ -211,10 +219,10 @@ class Attention(nn.Module):
self.to_v = None self.to_v = None
if self.added_kv_proj_dim is not None: if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim) self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim) self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
if self.context_pre_only is not None: if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim) self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_out = nn.ModuleList([]) self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
...@@ -223,6 +231,14 @@ class Attention(nn.Module): ...@@ -223,6 +231,14 @@ class Attention(nn.Module):
if self.context_pre_only is not None and not self.context_pre_only: if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias) self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
else:
self.norm_added_q = None
self.norm_added_k = None
# set attention processor # set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
...@@ -1137,6 +1153,100 @@ class FusedJointAttnProcessor2_0: ...@@ -1137,6 +1153,100 @@ class FusedJointAttnProcessor2_0:
return hidden_states, encoder_hidden_states return hidden_states, encoder_hidden_states
class AuraFlowAttnProcessor2_0:
"""Attention processor used typically in processing Aura Flow."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
raise ImportError(
"AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
i=0,
*args,
**kwargs,
) -> torch.FloatTensor:
batch_size = hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# Reshape.
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
# Apply QK norm.
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Concatenate the projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Attention.
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, encoder_hidden_states.shape[1] :],
hidden_states[:, : encoder_hidden_states.shape[1]],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class XFormersAttnAddedKVProcessor: class XFormersAttnAddedKVProcessor:
r""" r"""
Processor for implementing memory efficient attention using xFormers. Processor for implementing memory efficient attention using xFormers.
......
...@@ -473,11 +473,12 @@ class TimestepEmbedding(nn.Module): ...@@ -473,11 +473,12 @@ class TimestepEmbedding(nn.Module):
class Timesteps(nn.Module): class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__() super().__init__()
self.num_channels = num_channels self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps): def forward(self, timesteps):
t_emb = get_timestep_embedding( t_emb = get_timestep_embedding(
...@@ -485,6 +486,7 @@ class Timesteps(nn.Module): ...@@ -485,6 +486,7 @@ class Timesteps(nn.Module):
self.num_channels, self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos, flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift, downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
) )
return t_emb return t_emb
......
...@@ -51,6 +51,18 @@ class AdaLayerNorm(nn.Module): ...@@ -51,6 +51,18 @@ class AdaLayerNorm(nn.Module):
return x return x
class FP32LayerNorm(nn.LayerNorm):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(
inputs.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
).to(origin_dtype)
class AdaLayerNormZero(nn.Module): class AdaLayerNormZero(nn.Module):
r""" r"""
Norm layer adaptive layer norm zero (adaLN-Zero). Norm layer adaptive layer norm zero (adaLN-Zero).
...@@ -60,7 +72,7 @@ class AdaLayerNormZero(nn.Module): ...@@ -60,7 +72,7 @@ class AdaLayerNormZero(nn.Module):
num_embeddings (`int`): The size of the embeddings dictionary. num_embeddings (`int`): The size of the embeddings dictionary.
""" """
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__() super().__init__()
if num_embeddings is not None: if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
...@@ -68,8 +80,15 @@ class AdaLayerNormZero(nn.Module): ...@@ -68,8 +80,15 @@ class AdaLayerNormZero(nn.Module):
self.emb = None self.emb = None
self.silu = nn.SiLU() self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward( def forward(
self, self,
......
...@@ -2,6 +2,7 @@ from ...utils import is_torch_available ...@@ -2,6 +2,7 @@ from ...utils import is_torch_available
if is_torch_available(): if is_torch_available():
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
from .dit_transformer_2d import DiTTransformer2DModel from .dit_transformer_2d import DiTTransformer2DModel
from .dual_transformer_2d import DualTransformer2DModel from .dual_transformer_2d import DualTransformer2DModel
from .hunyuan_transformer_2d import HunyuanDiT2DModel from .hunyuan_transformer_2d import HunyuanDiT2DModel
......
# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import Attention, AuraFlowAttnProcessor2_0
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormZero, FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Taken from the original aura flow inference code.
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
# Aura Flow patch embed doesn't use convs for projections.
# Additionally, it uses learned positional embeddings.
class AuraFlowPatchEmbed(nn.Module):
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
pos_embed_max_size=None,
):
super().__init__()
self.num_patches = (height // patch_size) * (width // patch_size)
self.pos_embed_max_size = pos_embed_max_size
self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)
self.patch_size = patch_size
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
def forward(self, latent):
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
batch_size,
num_channels,
height // self.patch_size,
self.patch_size,
width // self.patch_size,
self.patch_size,
)
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
latent = self.proj(latent)
return latent + self.pos_embed
# Taken from the original Aura flow inference code.
# Our feedforward only has GELU but Aura uses SiLU.
class AuraFlowFeedForward(nn.Module):
def __init__(self, dim, hidden_dim=None) -> None:
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
final_hidden_dim = int(2 * hidden_dim / 3)
final_hidden_dim = find_multiple(final_hidden_dim, 256)
self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False)
self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False)
self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.silu(self.linear_1(x)) * self.linear_2(x)
x = self.out_projection(x)
return x
class AuraFlowPreFinalBlock(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False)
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = x * (1 + scale)[:, None, :] + shift[:, None, :]
return x
@maybe_allow_in_graph
class AuraFlowSingleTransformerBlock(nn.Module):
"""Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""
def __init__(self, dim, num_attention_heads, attention_head_dim):
super().__init__()
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
processor = AuraFlowAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="fp32_layer_norm",
out_dim=dim,
bias=False,
out_bias=False,
processor=processor,
)
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
residual = hidden_states
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(hidden_states)
hidden_states = gate_mlp.unsqueeze(1) * ff_output
hidden_states = residual + hidden_states
return hidden_states
@maybe_allow_in_graph
class AuraFlowJointTransformerBlock(nn.Module):
r"""
Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
* QK Norm in the attention blocks
* No bias in the attention blocks
* Most LayerNorms are in FP32
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
is_last (`bool`): Boolean to determine if this is the last block in the model.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim):
super().__init__()
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
processor = AuraFlowAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
added_proj_bias=False,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="fp32_layer_norm",
out_dim=dim,
bias=False,
out_bias=False,
processor=processor,
context_pre_only=False,
)
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
):
residual = hidden_states
residual_context = encoder_hidden_states
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
)
# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states)
hidden_states = residual + hidden_states
# Process attention outputs for the `encoder_hidden_states`.
encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output)
encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states)
encoder_hidden_states = residual_context + encoder_hidden_states
return encoder_hidden_states, hidden_states
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: int = 64,
patch_size: int = 2,
in_channels: int = 4,
num_mmdit_layers: int = 4,
num_single_dit_layers: int = 32,
attention_head_dim: int = 256,
num_attention_heads: int = 12,
joint_attention_dim: int = 2048,
caption_projection_dim: int = 3072,
out_channels: int = 4,
pos_embed_max_size: int = 1024,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = AuraFlowPatchEmbed(
height=self.config.sample_size,
width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
)
self.context_embedder = nn.Linear(
self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False
)
self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True)
self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
self.joint_transformer_blocks = nn.ModuleList(
[
AuraFlowJointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_mmdit_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
AuraFlowSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for _ in range(self.config.num_single_dit_layers)
]
)
self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
# https://arxiv.org/abs/2309.16588
# prevents artifacts in the attention maps
self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
temb = self.time_step_proj(temb)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
encoder_hidden_states = torch.cat(
[self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
)
# MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
)
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
if len(self.single_transformer_blocks) > 0:
encoder_seq_len = encoder_hidden_states.size(1)
combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
combined_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
combined_hidden_states,
temb,
**ckpt_kwargs,
)
else:
combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
hidden_states = combined_hidden_states[:, encoder_seq_len:]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# unpatchify
patch_size = self.config.patch_size
out_channels = self.config.out_channels
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
...@@ -29,20 +28,12 @@ from ..embeddings import ( ...@@ -29,20 +28,12 @@ from ..embeddings import (
) )
from ..modeling_outputs import Transformer2DModelOutput from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FP32LayerNorm(nn.LayerNorm):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(
inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
).to(origin_dtype)
class AdaLayerNormShift(nn.Module): class AdaLayerNormShift(nn.Module):
r""" r"""
Norm layer modified to incorporate timestep embeddings. Norm layer modified to incorporate timestep embeddings.
......
...@@ -250,6 +250,7 @@ else: ...@@ -250,6 +250,7 @@ else:
"StableDiffusionLDM3DPipeline", "StableDiffusionLDM3DPipeline",
] ]
) )
_import_structure["aura_flow"] = ["AuraFlowPipeline"]
_import_structure["stable_diffusion_3"] = [ _import_structure["stable_diffusion_3"] = [
"StableDiffusion3Pipeline", "StableDiffusion3Pipeline",
"StableDiffusion3Img2ImgPipeline", "StableDiffusion3Img2ImgPipeline",
...@@ -418,6 +419,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -418,6 +419,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2ProjectionModel, AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel, AudioLDM2UNet2DConditionModel,
) )
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import ( from .controlnet import (
BlipDiffusionControlNetPipeline, BlipDiffusionControlNetPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_aura_flow import AuraFlowPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
This diff is collapsed.
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -158,7 +158,12 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -158,7 +158,12 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _sigma_to_t(self, sigma): def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
):
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
...@@ -168,17 +173,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): ...@@ -168,17 +173,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, *optional*): device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
""" """
self.num_inference_steps = num_inference_steps
timesteps = np.linspace( if sigmas is None:
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps self.num_inference_steps = num_inference_steps
) timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
sigmas = timesteps / self.config.num_train_timesteps sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device) self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
......
...@@ -17,6 +17,21 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject): ...@@ -17,6 +17,21 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class AuraFlowTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKL(metaclass=DummyObject): class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
...@@ -182,6 +182,21 @@ class AudioLDMPipeline(metaclass=DummyObject): ...@@ -182,6 +182,21 @@ class AudioLDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"]) requires_backends(cls, ["torch", "transformers"])
class AuraFlowPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ChatGLMModel(metaclass=DummyObject): class ChatGLMModel(metaclass=DummyObject):
_backends = ["torch", "transformers"] _backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import AuraFlowTransformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = AuraFlowTransformer2DModel
main_input_name = "hidden_states"
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
sequence_length = 256
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 32,
"patch_size": 2,
"in_channels": 4,
"num_mmdit_layers": 1,
"num_single_dit_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,
"joint_attention_dim": 32,
"out_channels": 4,
"pos_embed_max_size": 256,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = AuraFlowPipeline
params = frozenset(
[
"prompt",
"height",
"width",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
def get_dummy_components(self):
torch.manual_seed(0)
transformer = AuraFlowTransformer2DModel(
sample_size=32,
patch_size=2,
in_channels=4,
num_mmdit_layers=1,
num_single_dit_layers=1,
attention_head_dim=8,
num_attention_heads=4,
caption_projection_dim=32,
joint_attention_dim=32,
out_channels=4,
pos_embed_max_size=256,
)
text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=32,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"height": None,
"width": None,
}
return inputs
def test_aura_flow_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
do_classifier_free_guidance = inputs["guidance_scale"] > 1
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = pipe.encode_prompt(
prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
device=torch_device,
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_attention_slicing_forward_pass(self):
# Attention slicing needs to implemented differently for this because how single DiT and MMDiT
# blocks interfere with each other.
return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment