Unverified Commit cd892041 authored by Junsong Chen, committed by GitHub

[DC-AE] Add the official Deep Compression Autoencoder code (32x, 64x, 128x compression ratios) (#9708)



* first add a script for DC-AE;

* DC-AE init

* replace triton with custom implementation

* rename file and remove unused code;

* no longer rely on omegaconf and dataclass

* replace custom activation with diffusers activation

* remove dc_ae attention in attention_processor.py

* inherit from ModelMixin

* inherit from ConfigMixin

* dc-ae reduce to one file

* update downsample and upsample

* clean code

* support DecoderOutput

* remove get_same_padding and val2tuple

* remove autocast and some assert

* update ResBlock

* remove contents within super().__init__

* Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove opsequential

* update other blocks to support the removal of build_norm

* remove build encoder/decoder project in/out

* remove inheritance of RMSNorm2d from LayerNorm

* remove reset_parameters for RMSNorm2d
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove device and dtype in RMSNorm2d __init__
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove op_list & build_block

* remove build_stage_main

* change file name to autoencoder_dc

* move LiteMLA to attention.py

* align with other vae decode output;

* add DC-AE into init files;

* update

* make quality && make style;

* quick push before dgx disappears again

* update

* make style

* update

* update

* fix

* refactor

* refactor

* refactor

* update

* possibly change to nn.Linear

* refactor

* make fix-copies

* replace vae with ae

* replace get_block_from_block_type to get_block

* replace downsample_block_type from Conv to conv for consistency

* add scaling factors

* incorporate changes for all checkpoints

* make style

* move mla to attention processor file; split qkv conv to linears

* refactor

* add tests

* from original file loader

* add docs

* add standard autoencoder methods

* combine attention processor

* fix tests

* update

* minor fix

* minor fix

* minor fix & in/out shortcut rename

* minor fix

* make style

* fix paper link

* update docs

* update single file loading

* make style

* remove single file loading support; todo for DN6

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add abstract

---------
Co-authored-by: Junyu Chen <chenjydl2003@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: chenjy2003 <70215701+chenjy2003@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 6394d905
@@ -314,6 +314,8 @@
      title: AutoencoderKLMochi
  - local: api/models/asymmetricautoencoderkl
    title: AsymmetricAutoencoderKL
  - local: api/models/autoencoder_dc
    title: AutoencoderDC
  - local: api/models/consistency_decoder_vae
    title: ConsistencyDecoderVAE
  - local: api/models/autoencoder_oobleck
......
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderDC
The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DC-AE](https://huggingface.co/papers/2410.10733) by Junyu Chen\*, Han Cai\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, and Song Han from MIT HAN Lab.
The abstract from the paper is:
*We present Deep Compression Autoencoder (DC-AE), a new family of autoencoder models for accelerating high-resolution diffusion models. Existing autoencoder models have demonstrated impressive results at a moderate spatial compression ratio (e.g., 8x), but fail to maintain satisfactory reconstruction accuracy for high spatial compression ratios (e.g., 64x). We address this challenge by introducing two key techniques: (1) Residual Autoencoding, where we design our models to learn residuals based on the space-to-channel transformed features to alleviate the optimization difficulty of high spatial-compression autoencoders; (2) Decoupled High-Resolution Adaptation, an efficient decoupled three-phases training strategy for mitigating the generalization penalty of high spatial-compression autoencoders. With these designs, we improve the autoencoder's spatial compression ratio up to 128 while maintaining the reconstruction quality. Applying our DC-AE to latent diffusion models, we achieve significant speedup without accuracy drop. For example, on ImageNet 512x512, our DC-AE provides 19.1x inference speedup and 17.9x training speedup on H100 GPU for UViT-H while achieving a better FID, compared with the widely used SD-VAE-f8 autoencoder. Our code is available at [this https URL](https://github.com/mit-han-lab/efficientvit).*
The following DC-AE models are released and supported in Diffusers.

| Diffusers format | Original format |
|:----------------:|:---------------:|
| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0) |
| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0) |
| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0) |
| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0) |
| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0) |
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0) |
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0) |
Load a model in Diffusers format with [`~ModelMixin.from_pretrained`].
```python
import torch
from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
```
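As a quick round-trip sketch (the `latent` and `sample` attribute names follow the `EncoderOutput` and `DecoderOutput` classes added in this PR, and the latent shape assumes the f32c32 checkpoint):

```python
import torch
from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")

# Random stand-in for a normalized RGB image in [-1, 1]
sample = torch.randn(1, 3, 512, 512, device="cuda")

with torch.no_grad():
    latent = ae.encode(sample).latent          # expected: (1, 32, 16, 16) for the f32c32 model
    reconstruction = ae.decode(latent).sample  # expected: (1, 3, 512, 512)
```

When such latents are passed to a diffusion model, they are typically multiplied by `ae.config.scaling_factor`; the per-checkpoint values are set in the conversion script included in this PR.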
## AutoencoderDC
[[autodoc]] AutoencoderDC
- encode
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
import argparse
from typing import Any, Dict
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from diffusers import AutoencoderDC
def remap_qkv_(key: str, state_dict: Dict[str, Any]):
qkv = state_dict.pop(key)
q, k, v = torch.chunk(qkv, 3, dim=0)
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
parent_module, _, _ = key.rpartition(".proj.conv.weight")
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
AE_KEYS_RENAME_DICT = {
# common
"main.": "",
"op_list.": "",
"context_module": "attn",
"local_module": "conv_out",
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
# If there were more scales, there would be more layers, so a loop would be better to handle this
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
"depth_conv.conv": "conv_depth",
"inverted_conv.conv": "conv_inverted",
"point_conv.conv": "conv_point",
"point_conv.norm": "norm",
"conv.conv.": "conv.",
"conv1.conv": "conv1",
"conv2.conv": "conv2",
"conv2.norm": "norm",
"proj.norm": "norm_out",
# encoder
"encoder.project_in.conv": "encoder.conv_in",
"encoder.project_out.0.conv": "encoder.conv_out",
"encoder.stages": "encoder.down_blocks",
# decoder
"decoder.project_in.conv": "decoder.conv_in",
"decoder.project_out.0": "decoder.norm_out",
"decoder.project_out.2.conv": "decoder.conv_out",
"decoder.stages": "decoder.up_blocks",
}
AE_F32C32_KEYS = {
# encoder
"encoder.project_in.conv": "encoder.conv_in.conv",
# decoder
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_F64C128_KEYS = {
# encoder
"encoder.project_in.conv": "encoder.conv_in.conv",
# decoder
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_F128C512_KEYS = {
# encoder
"encoder.project_in.conv": "encoder.conv_in.conv",
# decoder
"decoder.project_out.2.conv": "decoder.conv_out.conv",
}
AE_SPECIAL_KEYS_REMAP = {
"qkv.conv.weight": remap_qkv_,
"proj.conv.weight": remap_proj_conv_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def convert_ae(config_name: str, dtype: torch.dtype):
config = get_ae_config(config_name)
hub_id = f"mit-han-lab/{config_name}"
ckpt_path = hf_hub_download(hub_id, "model.safetensors")
original_state_dict = get_state_dict(load_file(ckpt_path))
ae = AutoencoderDC(**config).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
ae.load_state_dict(original_state_dict, strict=True)
return ae
def get_ae_config(name: str):
if name in ["dc-ae-f32c32-sana-1.0"]:
config = {
"latent_channels": 32,
"encoder_block_types": (
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
"decoder_block_types": (
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
),
"encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
"decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
"encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
"decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
"downsample_block_type": "conv",
"upsample_block_type": "interpolate",
"decoder_norm_types": "rms_norm",
"decoder_act_fns": "silu",
"scaling_factor": 0.41407,
}
elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
config = {
"latent_channels": 32,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), ()),
"decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
}
if name == "dc-ae-f32c32-in-1.0":
config["scaling_factor"] = 0.3189
elif name == "dc-ae-f32c32-mix-1.0":
config["scaling_factor"] = 0.4552
elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
config = {
"latent_channels": 128,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
}
if name == "dc-ae-f64c128-in-1.0":
config["scaling_factor"] = 0.2889
elif name == "dc-ae-f64c128-mix-1.0":
config["scaling_factor"] = 0.4538
elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
config = {
"latent_channels": 512,
"encoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"decoder_block_types": [
"ResBlock",
"ResBlock",
"ResBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
"EfficientViTBlock",
],
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
"encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
"decoder_norm_types": [
"batch_norm",
"batch_norm",
"batch_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
"rms_norm",
],
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
}
if name == "dc-ae-f128c512-in-1.0":
config["scaling_factor"] = 0.4883
elif name == "dc-ae-f128c512-mix-1.0":
config["scaling_factor"] = 0.3620
else:
raise ValueError("Invalid config name provided.")
return config
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_name",
type=str,
default="dc-ae-f32c32-sana-1.0",
choices=[
"dc-ae-f32c32-sana-1.0",
"dc-ae-f32c32-in-1.0",
"dc-ae-f32c32-mix-1.0",
"dc-ae-f64c128-in-1.0",
"dc-ae-f64c128-mix-1.0",
"dc-ae-f128c512-in-1.0",
"dc-ae-f128c512-mix-1.0",
],
help="The DCAE checkpoint to convert",
)
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
args = get_args()
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
ae = convert_ae(args.config_name, dtype)
ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
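For context, a hedged usage sketch (not part of the script): the converter is meant to be run from the command line with the arguments defined above, after which the result loads like any other Diffusers checkpoint. The script filename and output path below are illustrative.

```python
# Run the converter first, e.g. (script filename illustrative):
#   python convert_dc_ae_to_diffusers.py --config_name dc-ae-f32c32-sana-1.0 --output_path ./dc-ae-f32c32-sana-1.0-diffusers --dtype fp32
# Then reload the converted weights through the standard API:
import torch
from diffusers import AutoencoderDC

ae = AutoencoderDC.from_pretrained("./dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
```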
@@ -80,6 +80,7 @@ else:
        "AllegroTransformer3DModel",
        "AsymmetricAutoencoderKL",
        "AuraFlowTransformer2DModel",
        "AutoencoderDC",
        "AutoencoderKL",
        "AutoencoderKLAllegro",
        "AutoencoderKLCogVideoX",
@@ -572,6 +573,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        AllegroTransformer3DModel,
        AsymmetricAutoencoderKL,
        AuraFlowTransformer2DModel,
        AutoencoderDC,
        AutoencoderKL,
        AutoencoderKLAllegro,
        AutoencoderKLCogVideoX,
......
@@ -27,6 +27,7 @@ _import_structure = {}
if is_torch_available():
    _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
    _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
    _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
    _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
    _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
    _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
@@ -88,6 +89,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .adapter import MultiAdapter, T2IAdapter
    from .autoencoders import (
        AsymmetricAutoencoderKL,
        AutoencoderDC,
        AutoencoderKL,
        AutoencoderKLAllegro,
        AutoencoderKLCogVideoX,
......
@@ -752,6 +752,98 @@ class Attention(nn.Module):
        self.fused_projections = fuse
class SanaMultiscaleAttentionProjection(nn.Module):
def __init__(
self,
in_channels: int,
num_attention_heads: int,
kernel_size: int,
) -> None:
super().__init__()
channels = 3 * in_channels
self.proj_in = nn.Conv2d(
channels,
channels,
kernel_size,
padding=kernel_size // 2,
groups=channels,
bias=False,
)
self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_out(hidden_states)
return hidden_states
class SanaMultiscaleLinearAttention(nn.Module):
r"""Lightweight multi-scale linear attention"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 8,
mult: float = 1.0,
norm_type: str = "batch_norm",
kernel_sizes: Tuple[int, ...] = (5,),
eps: float = 1e-15,
residual_connection: bool = False,
):
super().__init__()
# To prevent circular import
from .normalization import get_normalization
self.eps = eps
self.attention_head_dim = attention_head_dim
self.norm_type = norm_type
self.residual_connection = residual_connection
num_attention_heads = (
int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
)
inner_dim = num_attention_heads * attention_head_dim
self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
self.to_qkv_multiscale = nn.ModuleList()
for kernel_size in kernel_sizes:
self.to_qkv_multiscale.append(
SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
)
self.nonlinearity = nn.ReLU()
self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
self.norm_out = get_normalization(norm_type, num_features=out_channels)
self.processor = SanaMultiscaleAttnProcessor2_0()
def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding
scores = torch.matmul(value, key.transpose(-1, -2))
hidden_states = torch.matmul(scores, query)
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
return hidden_states
def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
scores = torch.matmul(key.transpose(-1, -2), query)
scores = scores.to(dtype=torch.float32)
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
hidden_states = torch.matmul(value, scores)
return hidden_states
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.processor(self, hidden_states)
class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
@@ -5007,6 +5099,66 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
        return hidden_states
class SanaMultiscaleAttnProcessor2_0:
r"""
Processor for implementing multiscale quadratic attention.
"""
def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
height, width = hidden_states.shape[-2:]
if height * width > attn.attention_head_dim:
use_linear_attention = True
else:
use_linear_attention = False
residual = hidden_states
batch_size, _, height, width = list(hidden_states.size())
original_dtype = hidden_states.dtype
hidden_states = hidden_states.movedim(1, -1)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
hidden_states = torch.cat([query, key, value], dim=3)
hidden_states = hidden_states.movedim(-1, 1)
multi_scale_qkv = [hidden_states]
for block in attn.to_qkv_multiscale:
multi_scale_qkv.append(block(hidden_states))
hidden_states = torch.cat(multi_scale_qkv, dim=1)
if use_linear_attention:
# for linear attention upcast hidden_states to float32
hidden_states = hidden_states.to(dtype=torch.float32)
hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
query, key, value = hidden_states.chunk(3, dim=2)
query = attn.nonlinearity(query)
key = attn.nonlinearity(key)
if use_linear_attention:
hidden_states = attn.apply_linear_attention(query, key, value)
hidden_states = hidden_states.to(dtype=original_dtype)
else:
hidden_states = attn.apply_quadratic_attention(query, key, value)
hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if attn.norm_type == "rms_norm":
hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
else:
hidden_states = attn.norm_out(hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
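As a rough, unofficial sanity check of the multiscale linear attention added above (argument values are illustrative; the import path assumes the location this diff uses for the class):

```python
import torch
from diffusers.models.attention_processor import SanaMultiscaleLinearAttention

# 64 channels with the default attention_head_dim=8 gives 8 heads;
# the default kernel_sizes=(5,) adds one multiscale depthwise-conv branch.
attn = SanaMultiscaleLinearAttention(in_channels=64, out_channels=64).eval()

x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
with torch.no_grad():
    out = attn(x)

print(out.shape)  # expected: torch.Size([1, 64, 32, 32])
```

With a 32x32 feature map, `height * width` exceeds `attention_head_dim`, so the processor takes the linear-attention branch; on very small maps it falls back to the quadratic formulation.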
class LoRAAttnProcessor:
    def __init__(self):
        pass
......
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
......
@@ -30,6 +30,19 @@ from ..unets.unet_2d_blocks import (
)
@dataclass
class EncoderOutput(BaseOutput):
r"""
Output of encoding method.
Args:
latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`):
The encoded latent.
"""
latent: torch.Tensor
@dataclass
class DecoderOutput(BaseOutput):
    r"""
......
@@ -512,20 +512,24 @@ else:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        self.dim = torch.Size(dim)

        self.weight = None
        self.bias = None

        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
            if bias:
                self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
@@ -537,6 +541,8 @@ class RMSNorm(nn.Module):
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
            if self.bias is not None:
                hidden_states = hidden_states + self.bias
        else:
            hidden_states = hidden_states.to(input_dtype)
@@ -566,3 +572,21 @@ class LpNorm(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
def get_normalization(
norm_type: str = "batch_norm",
num_features: Optional[int] = None,
eps: float = 1e-5,
elementwise_affine: bool = True,
bias: bool = True,
) -> nn.Module:
if norm_type == "rms_norm":
norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
elif norm_type == "layer_norm":
norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
elif norm_type == "batch_norm":
norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
else:
raise ValueError(f"{norm_type=} is not supported.")
return norm
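A minimal usage sketch of the new `get_normalization` helper (direct import from the internal module is assumed; values are illustrative). Note that `rms_norm` and `layer_norm` normalize the trailing feature dimension, while `batch_norm` maps to `nn.BatchNorm2d` and expects channels-first input, which is why `SanaMultiscaleAttnProcessor2_0` moves the channel dimension before applying `norm_out` when the norm type is `rms_norm`.

```python
import torch
from diffusers.models.normalization import get_normalization

# RMSNorm acts on the trailing `num_features` dimension.
rms = get_normalization("rms_norm", num_features=32)
print(rms(torch.randn(2, 128, 32)).shape)  # torch.Size([2, 128, 32])

# BatchNorm2d acts on channels-first (N, C, H, W) tensors.
bn = get_normalization("batch_norm", num_features=32).eval()
print(bn(torch.randn(2, 32, 8, 8)).shape)  # torch.Size([2, 32, 8, 8])
```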
@@ -47,6 +47,21 @@ class AuraFlowTransformer2DModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])
class AutoencoderDC(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKL(metaclass=DummyObject):
    _backends = ["torch"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderDC
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_dc_config(self):
return {
"in_channels": 3,
"latent_channels": 4,
"attention_head_dim": 2,
"encoder_block_types": (
"ResBlock",
"EfficientViTBlock",
),
"decoder_block_types": (
"ResBlock",
"EfficientViTBlock",
),
"encoder_block_out_channels": (8, 8),
"decoder_block_out_channels": (8, 8),
"encoder_qkv_multiscales": ((), (5,)),
"decoder_qkv_multiscales": ((), (5,)),
"encoder_layers_per_block": (1, 1),
"decoder_layers_per_block": [1, 1],
"downsample_block_type": "conv",
"upsample_block_type": "interpolate",
"decoder_norm_types": "rms_norm",
"decoder_act_fns": "silu",
"scaling_factor": 0.41407,
}
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_dc_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
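For readers who want to poke at the tiny model outside the test harness, here is a hedged sketch that instantiates the same configuration directly (values copied from `get_autoencoder_dc_config` above; the `latent` and `sample` attribute names follow the `EncoderOutput` and `DecoderOutput` classes added in this PR):

```python
import torch
from diffusers import AutoencoderDC

model = AutoencoderDC(
    in_channels=3,
    latent_channels=4,
    attention_head_dim=2,
    encoder_block_types=("ResBlock", "EfficientViTBlock"),
    decoder_block_types=("ResBlock", "EfficientViTBlock"),
    encoder_block_out_channels=(8, 8),
    decoder_block_out_channels=(8, 8),
    encoder_qkv_multiscales=((), (5,)),
    decoder_qkv_multiscales=((), (5,)),
    encoder_layers_per_block=(1, 1),
    decoder_layers_per_block=(1, 1),
    downsample_block_type="conv",
    upsample_block_type="interpolate",
    decoder_norm_types="rms_norm",
    decoder_act_fns="silu",
    scaling_factor=0.41407,
).eval()

sample = torch.randn(4, 3, 32, 32)  # matches `dummy_input` above
with torch.no_grad():
    latent = model.encode(sample).latent
    reconstruction = model.decode(latent).sample

print(reconstruction.shape)  # expected: torch.Size([4, 3, 32, 32])
```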