Unverified Commit 2261510b authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Core] Add AuraFlow (#8796)



* add lavender flow transformer

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 87b9db64
import argparse
import torch
from huggingface_hub import hf_hub_download
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel
def load_original_state_dict(args):
model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
state_dict = torch.load(model_pt, map_location="cpu")
return state_dict
def calculate_layers(state_dict_keys, key_prefix):
dit_layers = set()
for k in state_dict_keys:
if key_prefix in k:
dit_layers.add(int(k.split(".")[2]))
print(f"{key_prefix}: {len(dit_layers)}")
return len(dit_layers)
# similar to SD3 but only for the last norm layer
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_transformer(state_dict):
converted_state_dict = {}
state_dict_keys = list(state_dict.keys())
converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")
converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")
converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
# MMDiT blocks 🎸.
for i in range(mmdit_layers):
# feed-forward
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for orig_k, diffuser_k in path_mapping.items():
for k, v in weight_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
f"model.double_layers.{i}.{orig_k}.{k}.weight"
)
# norms
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
for orig_k, diffuser_k in path_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
f"model.double_layers.{i}.{orig_k}.1.weight"
)
# attns
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
for k, v in attn_mapping.items():
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
f"model.double_layers.{i}.attn.{k}.weight"
)
# Single-DiT blocks.
for i in range(single_dit_layers):
# feed-forward
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
for k, v in mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
f"model.single_layers.{i}.mlp.{k}.weight"
)
# norms
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
f"model.single_layers.{i}.modCX.1.weight"
)
# attns
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
for k, v in x_attn_mapping.items():
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
f"model.single_layers.{i}.attn.{k}.weight"
)
# Final blocks.
converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)
return converted_state_dict
@torch.no_grad()
def populate_state_dict(args):
original_state_dict = load_original_state_dict(args)
state_dict_keys = list(original_state_dict.keys())
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
converted_state_dict = convert_transformer(original_state_dict)
model_diffusers = AuraFlowTransformer2DModel(
num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
)
model_diffusers.load_state_dict(converted_state_dict, strict=True)
return model_diffusers
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
parser.add_argument("--dump_path", default="aura-flow", type=str)
parser.add_argument("--hub_id", default=None, type=str)
args = parser.parse_args()
model_diffusers = populate_state_dict(args)
model_diffusers.save_pretrained(args.dump_path)
if args.hub_id is not None:
model_diffusers.push_to_hub(args.hub_id)
......@@ -76,6 +76,7 @@ else:
_import_structure["models"].extend(
[
"AsymmetricAutoencoderKL",
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
......@@ -235,6 +236,7 @@ else:
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
"AudioLDMPipeline",
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"ChatGLMModel",
......@@ -507,6 +509,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .models import (
AsymmetricAutoencoderKL,
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
......@@ -646,6 +649,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
ChatGLMModel,
ChatGLMTokenizer,
CLIPImageProjection,
......
......@@ -38,6 +38,7 @@ if is_torch_available():
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
......@@ -84,6 +85,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
AuraFlowTransformer2DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
HunyuanDiT2DModel,
......
......@@ -22,7 +22,7 @@ from torch import nn
from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......@@ -104,6 +104,7 @@ class Attention(nn.Module):
cross_attention_norm_num_groups: int = 32,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
......@@ -118,6 +119,10 @@ class Attention(nn.Module):
context_pre_only=None,
):
super().__init__()
# To prevent circular import.
from .normalization import FP32LayerNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
......@@ -170,6 +175,9 @@ class Attention(nn.Module):
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
elif qk_norm == "fp32_layer_norm":
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "layer_norm_across_heads":
# Lumina applys qk norm across all heads
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
......@@ -211,10 +219,10 @@ class Attention(nn.Module):
self.to_v = None
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
......@@ -223,6 +231,14 @@ class Attention(nn.Module):
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
else:
self.norm_added_q = None
self.norm_added_k = None
# set attention processor
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
......@@ -1137,6 +1153,100 @@ class FusedJointAttnProcessor2_0:
return hidden_states, encoder_hidden_states
class AuraFlowAttnProcessor2_0:
"""Attention processor used typically in processing Aura Flow."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
raise ImportError(
"AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
i=0,
*args,
**kwargs,
) -> torch.FloatTensor:
batch_size = hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# `context` projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
# Reshape.
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
# Apply QK norm.
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Concatenate the projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Attention.
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, encoder_hidden_states.shape[1] :],
hidden_states[:, : encoder_hidden_states.shape[1]],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
......
......@@ -473,11 +473,12 @@ class TimestepEmbedding(nn.Module):
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
......@@ -485,6 +486,7 @@ class Timesteps(nn.Module):
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
......
......@@ -51,6 +51,18 @@ class AdaLayerNorm(nn.Module):
return x
class FP32LayerNorm(nn.LayerNorm):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(
inputs.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
).to(origin_dtype)
class AdaLayerNormZero(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
......@@ -60,7 +72,7 @@ class AdaLayerNormZero(nn.Module):
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__()
if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
......@@ -68,8 +80,15 @@ class AdaLayerNormZero(nn.Module):
self.emb = None
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
......
......@@ -2,6 +2,7 @@ from ...utils import is_torch_available
if is_torch_available():
from .auraflow_transformer_2d import AuraFlowTransformer2DModel
from .dit_transformer_2d import DiTTransformer2DModel
from .dual_transformer_2d import DualTransformer2DModel
from .hunyuan_transformer_2d import HunyuanDiT2DModel
......
# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import Attention, AuraFlowAttnProcessor2_0
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormZero, FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Taken from the original aura flow inference code.
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
# Aura Flow patch embed doesn't use convs for projections.
# Additionally, it uses learned positional embeddings.
class AuraFlowPatchEmbed(nn.Module):
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
pos_embed_max_size=None,
):
super().__init__()
self.num_patches = (height // patch_size) * (width // patch_size)
self.pos_embed_max_size = pos_embed_max_size
self.proj = nn.Linear(patch_size * patch_size * in_channels, embed_dim)
self.pos_embed = nn.Parameter(torch.randn(1, pos_embed_max_size, embed_dim) * 0.1)
self.patch_size = patch_size
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
def forward(self, latent):
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
batch_size,
num_channels,
height // self.patch_size,
self.patch_size,
width // self.patch_size,
self.patch_size,
)
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
latent = self.proj(latent)
return latent + self.pos_embed
# Taken from the original Aura flow inference code.
# Our feedforward only has GELU but Aura uses SiLU.
class AuraFlowFeedForward(nn.Module):
def __init__(self, dim, hidden_dim=None) -> None:
super().__init__()
if hidden_dim is None:
hidden_dim = 4 * dim
final_hidden_dim = int(2 * hidden_dim / 3)
final_hidden_dim = find_multiple(final_hidden_dim, 256)
self.linear_1 = nn.Linear(dim, final_hidden_dim, bias=False)
self.linear_2 = nn.Linear(dim, final_hidden_dim, bias=False)
self.out_projection = nn.Linear(final_hidden_dim, dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.silu(self.linear_1(x)) * self.linear_2(x)
x = self.out_projection(x)
return x
class AuraFlowPreFinalBlock(nn.Module):
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=False)
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = x * (1 + scale)[:, None, :] + shift[:, None, :]
return x
@maybe_allow_in_graph
class AuraFlowSingleTransformerBlock(nn.Module):
"""Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""
def __init__(self, dim, num_attention_heads, attention_head_dim):
super().__init__()
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
processor = AuraFlowAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="fp32_layer_norm",
out_dim=dim,
bias=False,
out_bias=False,
processor=processor,
)
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
residual = hidden_states
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(hidden_states)
hidden_states = gate_mlp.unsqueeze(1) * ff_output
hidden_states = residual + hidden_states
return hidden_states
@maybe_allow_in_graph
class AuraFlowJointTransformerBlock(nn.Module):
r"""
Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
* QK Norm in the attention blocks
* No bias in the attention blocks
* Most LayerNorms are in FP32
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
is_last (`bool`): Boolean to determine if this is the last block in the model.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim):
super().__init__()
self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")
processor = AuraFlowAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
added_proj_bias=False,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="fp32_layer_norm",
out_dim=dim,
bias=False,
out_bias=False,
processor=processor,
context_pre_only=False,
)
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
self.norm2_context = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
):
residual = hidden_states
residual_context = encoder_hidden_states
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
)
# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
hidden_states = hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
hidden_states = gate_mlp.unsqueeze(1) * self.ff(hidden_states)
hidden_states = residual + hidden_states
# Process attention outputs for the `encoder_hidden_states`.
encoder_hidden_states = self.norm2_context(residual_context + c_gate_msa.unsqueeze(1) * context_attn_output)
encoder_hidden_states = encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
encoder_hidden_states = c_gate_mlp.unsqueeze(1) * self.ff_context(encoder_hidden_states)
encoder_hidden_states = residual_context + encoder_hidden_states
return encoder_hidden_states, hidden_states
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: int = 64,
patch_size: int = 2,
in_channels: int = 4,
num_mmdit_layers: int = 4,
num_single_dit_layers: int = 32,
attention_head_dim: int = 256,
num_attention_heads: int = 12,
joint_attention_dim: int = 2048,
caption_projection_dim: int = 3072,
out_channels: int = 4,
pos_embed_max_size: int = 1024,
):
super().__init__()
default_out_channels = in_channels
self.out_channels = out_channels if out_channels is not None else default_out_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = AuraFlowPatchEmbed(
height=self.config.sample_size,
width=self.config.sample_size,
patch_size=self.config.patch_size,
in_channels=self.config.in_channels,
embed_dim=self.inner_dim,
pos_embed_max_size=pos_embed_max_size,
)
self.context_embedder = nn.Linear(
self.config.joint_attention_dim, self.config.caption_projection_dim, bias=False
)
self.time_step_embed = Timesteps(num_channels=256, downscale_freq_shift=0, scale=1000, flip_sin_to_cos=True)
self.time_step_proj = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
self.joint_transformer_blocks = nn.ModuleList(
[
AuraFlowJointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_mmdit_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
AuraFlowSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for _ in range(self.config.num_single_dit_layers)
]
)
self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
# https://arxiv.org/abs/2309.16588
# prevents artifacts in the attention maps
self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
temb = self.time_step_proj(temb)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
encoder_hidden_states = torch.cat(
[self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
)
# MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
)
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
if len(self.single_transformer_blocks) > 0:
encoder_seq_len = encoder_hidden_states.size(1)
combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
combined_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
combined_hidden_states,
temb,
**ckpt_kwargs,
)
else:
combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
hidden_states = combined_hidden_states[:, encoder_seq_len:]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# unpatchify
patch_size = self.config.patch_size
out_channels = self.config.out_channels
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
......@@ -14,7 +14,6 @@
from typing import Dict, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
......@@ -29,20 +28,12 @@ from ..embeddings import (
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FP32LayerNorm(nn.LayerNorm):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
return F.layer_norm(
inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
).to(origin_dtype)
class AdaLayerNormShift(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
......
......@@ -250,6 +250,7 @@ else:
"StableDiffusionLDM3DPipeline",
]
)
_import_structure["aura_flow"] = ["AuraFlowPipeline"]
_import_structure["stable_diffusion_3"] = [
"StableDiffusion3Pipeline",
"StableDiffusion3Img2ImgPipeline",
......@@ -418,6 +419,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_aura_flow"] = ["AuraFlowPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_aura_flow import AuraFlowPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
This diff is collapsed.
......@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -158,7 +158,12 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
......@@ -168,17 +173,19 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
if sigmas is None:
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
......
......@@ -17,6 +17,21 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AuraFlowTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKL(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -182,6 +182,21 @@ class AudioLDMPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class AuraFlowPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ChatGLMModel(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import AuraFlowTransformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = AuraFlowTransformer2DModel
main_input_name = "hidden_states"
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
sequence_length = 256
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 32,
"patch_size": 2,
"in_channels": 4,
"num_mmdit_layers": 1,
"num_single_dit_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,
"joint_attention_dim": 32,
"out_channels": 4,
"pos_embed_max_size": 256,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
class AuraFlowPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = AuraFlowPipeline
params = frozenset(
[
"prompt",
"height",
"width",
"guidance_scale",
"negative_prompt",
"prompt_embeds",
"negative_prompt_embeds",
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
def get_dummy_components(self):
torch.manual_seed(0)
transformer = AuraFlowTransformer2DModel(
sample_size=32,
patch_size=2,
in_channels=4,
num_mmdit_layers=1,
num_single_dit_layers=1,
attention_head_dim=8,
num_attention_heads=4,
caption_projection_dim=32,
joint_attention_dim=32,
out_channels=4,
pos_embed_max_size=256,
)
text_encoder = UMT5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-umt5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=32,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"height": None,
"width": None,
}
return inputs
def test_aura_flow_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
do_classifier_free_guidance = inputs["guidance_scale"] > 1
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = pipe.encode_prompt(
prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
device=torch_device,
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
def test_attention_slicing_forward_pass(self):
# Attention slicing needs to implemented differently for this because how single DiT and MMDiT
# blocks interfere with each other.
return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment