Unverified Commit 5a196e3d authored by Junsong Chen's avatar Junsong Chen Committed by GitHub
Browse files

[Sana] Add Sana, including `SanaPipeline`, `SanaPAGPipeline`,...


[Sana] Add Sana, including `SanaPipeline`, `SanaPAGPipeline`, `LinearAttentionProcessor`, `Flow-based DPM-sovler` and so on. (#9982)

* first add a script for DC-AE;

* DC-AE init

* replace triton with custom implementation

* 1. rename file and remove un-used codes;

* no longer rely on omegaconf and dataclass

* replace custom activation with diffuers activation

* remove dc_ae attention in attention_processor.py

* iinherit from ModelMixin

* inherit from ConfigMixin

* dc-ae reduce to one file

* update downsample and upsample

* clean code

* support DecoderOutput

* remove get_same_padding and val2tuple

* remove autocast and some assert

* update ResBlock

* remove contents within super().__init__

* Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* remove opsequential

* update other blocks to support the removal of build_norm

* remove build encoder/decoder project in/out

* remove inheritance of RMSNorm2d from LayerNorm

* remove reset_parameters for RMSNorm2d
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* remove device and dtype in RMSNorm2d __init__
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/autoencoders/dc_ae.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* remove op_list & build_block

* remove build_stage_main

* change file name to autoencoder_dc

* move LiteMLA to attention.py

* align with other vae decode output;

* add DC-AE into init files;

* update

* make quality && make style;

* quick push before dgx disappears again

* update

* make style

* update

* update

* fix

* refactor

* refactor

* refactor

* update

* possibly change to nn.Linear

* refactor

* make fix-copies

* replace vae with ae

* replace get_block_from_block_type to get_block

* replace downsample_block_type from Conv to conv for consistency

* add scaling factors

* incorporate changes for all checkpoints

* make style

* move mla to attention processor file; split qkv conv to linears

* refactor

* add tests

* from original file loader

* add docs

* add standard autoencoder methods

* combine attention processor

* fix tests

* update

* minor fix

* minor fix

* minor fix & in/out shortcut rename

* minor fix

* make style

* fix paper link

* update docs

* update single file loading

* make style

* remove single file loading support; todo for DN6

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* add abstract

* 1. add DCAE into diffusers;
2. make style and make quality;

* add DCAE_HF into diffusers;

* bug fixed;

* add SanaPipeline, SanaTransformer2D into diffusers;

* add sanaLinearAttnProcessor2_0;

* first update for SanaTransformer;

* first update for SanaPipeline;

* first success run SanaPipeline;

* model output finally match with original model with the same intput;

* code update;

* code update;

* add a flow dpm-solver scripts

* 🎉[important update]
1. Integrate flow-dpm-sovler into diffusers;
2. finally run successfully on both `FlowMatchEulerDiscreteScheduler` and `FlowDPMSolverMultistepScheduler`;

* 🎉🔧

[important update & fix huge bugs!!]
1. add SanaPAGPipeline & several related Sana linear attention operators;
2. `SanaTransformer2DModel` not supports multi-resolution input;
2. fix the multi-scale HW bugs in SanaPipeline and SanaPAGPipeline;
3. fix the flow-dpm-solver set_timestep() init `model_output` and `lower_order_nums` bugs;

* remove prints;

* add convert sana official checkpoint to diffusers format Safetensor.

* Update src/diffusers/models/transformers/sana_transformer_2d.py
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/sana_transformer_2d.py
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/sana_transformer_2d.py
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/pipelines/pag/pipeline_pag_sana.py
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/sana_transformer_2d.py
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/models/transformers/sana_transformer_2d.py
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana.py
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/pipelines/sana/pipeline_sana.py
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update Sana for DC-AE's recent commit;

* make style && make quality

* Add StableDiffusion3PAGImg2Img Pipeline + Fix SD3 Unconditional PAG (#9932)

* fix progress bar updates in SD 1.5 PAG Img2Img pipeline

---------
Co-authored-by: default avatarVinh H. Pham <phamvinh257@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* make the vae can be None in `__init__` of `SanaPipeline`

* Update src/diffusers/models/transformers/sana_transformer_2d.py
Co-authored-by: default avatarhlky <hlky@hlky.ac>

* change the ae related code due to the latest update of DCAE branch;

* change the ae related code due to the latest update of DCAE branch;

* 1. change code based on AutoencoderDC;
2. fix the bug of new GLUMBConv;
3. run success;

* update for solving conversation.

* 1. fix bugs and run convert script success;
2. Downloading ckpt from hub automatically;

* make style && make quality;

* 1. remove un-unsed parameters in init;
2. code update;

* remove test file

* refactor; add docs; add tests; update conversion script

* make style

* make fix-copies

* refactor

* udpate pipelines

* pag tests and refactor

* remove sana pag conversion script

* handle weight casting in conversion script

* update conversion script

* add a processor

* 1. add bf16 pth file path;
2. add complex human instruct in pipeline;

* fix fast \tests

* change gemma-2-2b-it ckpt to a non-gated repo;

* fix the pth path bug in conversion script;

* change grad ckpt to original; make style

* fix the complex_human_instruct bug and typo;

* remove dpmsolver flow scheduler

* apply review suggestions

* change the `FlowMatchEulerDiscreteScheduler` to default `DPMSolverMultistepScheduler` with flow matching scheduler.

* fix the tokenizer.padding_side='right' bug;

* update docs

* make fix-copies

* fix imports

* fix docs

* add integration test

* update docs

* update examples

* fix convert_model_output in schedulers

* fix failing tests

---------
Co-authored-by: default avatarJunyu Chen <chenjydl2003@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarchenjy2003 <70215701+chenjy2003@users.noreply.github.com>
Co-authored-by: default avatarAryan <aryan@huggingface.co>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: default avatarhlky <hlky@hlky.ac>
parent 22c4f079
...@@ -284,6 +284,8 @@ ...@@ -284,6 +284,8 @@
title: PriorTransformer title: PriorTransformer
- local: api/models/sd3_transformer2d - local: api/models/sd3_transformer2d
title: SD3Transformer2DModel title: SD3Transformer2DModel
- local: api/models/sana_transformer2d
title: SanaTransformer2DModel
- local: api/models/stable_audio_transformer - local: api/models/stable_audio_transformer
title: StableAudioDiTModel title: StableAudioDiTModel
- local: api/models/transformer2d - local: api/models/transformer2d
...@@ -434,6 +436,8 @@ ...@@ -434,6 +436,8 @@
title: PixArt-α title: PixArt-α
- local: api/pipelines/pixart_sigma - local: api/pipelines/pixart_sigma
title: PixArt-Σ title: PixArt-Σ
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/self_attention_guidance - local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion - local: api/pipelines/semantic_stable_diffusion
......
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# SanaTransformer2DModel
A Diffusion Transformer model for 2D data from [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) was introduced from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
The abstract from the paper is:
*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
The model can be loaded with the following code snippet.
```python
from diffusers import SanaTransformer2DModel
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
```
## SanaTransformer2DModel
[[autodoc]] SanaTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaPipeline
[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
The abstract from the paper is:
*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model]https://huggingface.co/Efficient-Large-Model).
Available models:
| Model | Recommended dtype |
|:-----:|:-----------------:|
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
| [`Efficient-Large-Model/Sana_600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_512px_diffusers) | `torch.float16` |
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) collection for more information.
<Tip>
Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `"fp16"` for models with recommended dtype as `torch.float16`, and `"bf16"` for models with recommended dtype as `torch.bfloat16`. By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcasted on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained).
</Tip>
## SanaPipeline
[[autodoc]] SanaPipeline
- all
- __call__
## SanaPAGPipeline
[[autodoc]] SanaPAGPipeline
- all
- __call__
## SanaPipelineOutput
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
#!/usr/bin/env python
from __future__ import annotations
import argparse
import os
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderDC,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaPipeline,
SanaTransformer2DModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
"Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
"Efficient-Large-Model/Sana_1600M_512px/checkpoints/Sana_1600M_512px.pth",
"Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth",
]
# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py
def main(args):
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
snapshot_download(
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
cache_dir=cache_dir_path,
repo_type="model",
)
file_path = hf_hub_download(
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
cache_dir=cache_dir_path,
repo_type="model",
)
else:
file_path = args.orig_ckpt_path
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
all_state_dict = torch.load(file_path, weights_only=True)
state_dict = all_state_dict.pop("state_dict")
converted_state_dict = {}
# Patch embeddings.
converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
# Caption projection.
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
# y norm
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
flow_shift = 3.0
if args.model_type == "SanaMS_1600M_P1_D20":
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
layer_num = 28
else:
raise ValueError(f"{args.model_type} is not supported.")
for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
f"blocks.{depth}.scale_shift_table"
)
# Linear Attention is all you need 🤘
# Self attention.
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.attn.proj.bias"
)
# Feed-forward.
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.depth_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
f"blocks.{depth}.mlp.depth_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
f"blocks.{depth}.mlp.point_conv.conv.weight"
)
# Cross-attention.
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.bias"
)
# Final block.
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
# Transformer
with CTX():
transformer = SanaTransformer2DModel(
in_channels=32,
out_channels=32,
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
num_layers=model_kwargs[args.model_type]["num_layers"],
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
caption_channels=2304,
mlp_ratio=2.5,
attention_bias=False,
sample_size=args.image_size // 32,
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_state_dict)
else:
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
except KeyError:
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
transformer = transformer.to(weight_dtype)
if not args.save_full_pipeline:
print(
colored(
f"Only saving transformer model of {args.model_type}. "
f"Set --save_full_pipeline to save the whole SanaPipeline",
"green",
attrs=["bold"],
)
)
transformer.save_pretrained(
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
)
else:
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
text_encoder_model_path, torch_dtype=torch.bfloat16
).get_decoder()
# Scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--image_size",
default=1024,
type=int,
choices=[512, 1024],
required=False,
help="Image size of pretrained model, 512 or 1024.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
)
parser.add_argument(
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
args = parser.parse_args()
model_kwargs = {
"SanaMS_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
"SanaMS_600M_P1_D28": {
"num_attention_heads": 36,
"attention_head_dim": 32,
"num_cross_attention_heads": 16,
"cross_attention_head_dim": 72,
"cross_attention_dim": 1152,
"num_layers": 28,
},
}
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
main(args)
...@@ -114,6 +114,7 @@ else: ...@@ -114,6 +114,7 @@ else:
"MultiControlNetModel", "MultiControlNetModel",
"PixArtTransformer2DModel", "PixArtTransformer2DModel",
"PriorTransformer", "PriorTransformer",
"SanaTransformer2DModel",
"SD3ControlNetModel", "SD3ControlNetModel",
"SD3MultiControlNetModel", "SD3MultiControlNetModel",
"SD3Transformer2DModel", "SD3Transformer2DModel",
...@@ -332,6 +333,8 @@ else: ...@@ -332,6 +333,8 @@ else:
"PixArtSigmaPAGPipeline", "PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline", "PixArtSigmaPipeline",
"ReduxImageEncoder", "ReduxImageEncoder",
"SanaPAGPipeline",
"SanaPipeline",
"SemanticStableDiffusionPipeline", "SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline", "ShapEImg2ImgPipeline",
"ShapEPipeline", "ShapEPipeline",
...@@ -345,6 +348,7 @@ else: ...@@ -345,6 +348,7 @@ else:
"StableDiffusion3Img2ImgPipeline", "StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline", "StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGImg2ImgPipeline", "StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusion3PAGPipeline", "StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline", "StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline", "StableDiffusionAdapterPipeline",
...@@ -616,6 +620,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -616,6 +620,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiControlNetModel, MultiControlNetModel,
PixArtTransformer2DModel, PixArtTransformer2DModel,
PriorTransformer, PriorTransformer,
SanaTransformer2DModel,
SD3ControlNetModel, SD3ControlNetModel,
SD3MultiControlNetModel, SD3MultiControlNetModel,
SD3Transformer2DModel, SD3Transformer2DModel,
...@@ -813,6 +818,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -813,6 +818,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtSigmaPAGPipeline, PixArtSigmaPAGPipeline,
PixArtSigmaPipeline, PixArtSigmaPipeline,
ReduxImageEncoder, ReduxImageEncoder,
SanaPAGPipeline,
SanaPipeline,
SemanticStableDiffusionPipeline, SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline, ShapEImg2ImgPipeline,
ShapEPipeline, ShapEPipeline,
......
...@@ -60,6 +60,7 @@ if is_torch_available(): ...@@ -60,6 +60,7 @@ if is_torch_available():
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"] _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
...@@ -135,6 +136,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -135,6 +136,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MochiTransformer3DModel, MochiTransformer3DModel,
PixArtTransformer2DModel, PixArtTransformer2DModel,
PriorTransformer, PriorTransformer,
SanaTransformer2DModel,
SD3Transformer2DModel, SD3Transformer2DModel,
StableAudioDiTModel, StableAudioDiTModel,
T5FilmDecoder, T5FilmDecoder,
......
...@@ -5441,6 +5441,165 @@ class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0): ...@@ -5441,6 +5441,165 @@ class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
super().__init__() super().__init__()
class SanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
original_dtype = hidden_states.dtype
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
query = F.relu(query)
key = F.relu(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
scores = torch.matmul(value, key)
hidden_states = torch.matmul(scores, query)
hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
hidden_states = hidden_states.to(original_dtype)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if original_dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class PAGCFGSanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
original_dtype = hidden_states.dtype
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
query = F.relu(query)
key = F.relu(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
scores = torch.matmul(value, key)
hidden_states_org = torch.matmul(scores, query)
hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
hidden_states_org = hidden_states_org.to(original_dtype)
hidden_states_org = attn.to_out[0](hidden_states_org)
hidden_states_org = attn.to_out[1](hidden_states_org)
# perturbed path (identity attention)
hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if original_dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
class PAGIdentitySanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
original_dtype = hidden_states.dtype
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
query = F.relu(query)
key = F.relu(key)
query, key, value = query.float(), key.float(), value.float()
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
scores = torch.matmul(value, key)
hidden_states_org = torch.matmul(scores, query)
if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
hidden_states_org = hidden_states_org.float()
hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
hidden_states_org = hidden_states_org.to(original_dtype)
hidden_states_org = attn.to_out[0](hidden_states_org)
hidden_states_org = attn.to_out[1](hidden_states_org)
# perturbed path (identity attention)
hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
if original_dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
ADDED_KV_ATTENTION_PROCESSORS = ( ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor, AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor, SlicedAttnAddedKVProcessor,
...@@ -5493,6 +5652,12 @@ AttentionProcessor = Union[ ...@@ -5493,6 +5652,12 @@ AttentionProcessor = Union[
CustomDiffusionAttnProcessor2_0, CustomDiffusionAttnProcessor2_0,
SlicedAttnProcessor, SlicedAttnProcessor,
SlicedAttnAddedKVProcessor, SlicedAttnAddedKVProcessor,
SanaLinearAttnProcessor2_0,
PAGCFGSanaLinearAttnProcessor2_0,
PAGIdentitySanaLinearAttnProcessor2_0,
SanaMultiscaleLinearAttention,
SanaMultiscaleAttnProcessor2_0,
SanaMultiscaleAttentionProjection,
IPAdapterAttnProcessor, IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0, IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor, IPAdapterXFormersAttnProcessor,
......
...@@ -26,39 +26,10 @@ from ..activations import get_activation ...@@ -26,39 +26,10 @@ from ..activations import get_activation
from ..attention_processor import SanaMultiscaleLinearAttention from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
from .vae import DecoderOutput, EncoderOutput from .vae import DecoderOutput, EncoderOutput
class GLUMBConv(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__()
hidden_channels = 4 * in_channels
self.nonlinearity = nn.SiLU()
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
return hidden_states + residual
class ResBlock(nn.Module): class ResBlock(nn.Module):
def __init__( def __init__(
self, self,
...@@ -115,6 +86,7 @@ class EfficientViTBlock(nn.Module): ...@@ -115,6 +86,7 @@ class EfficientViTBlock(nn.Module):
self.conv_out = GLUMBConv( self.conv_out = GLUMBConv(
in_channels=in_channels, in_channels=in_channels,
out_channels=in_channels, out_channels=in_channels,
norm_type="rms_norm",
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
......
...@@ -11,6 +11,7 @@ if is_torch_available(): ...@@ -11,6 +11,7 @@ if is_torch_available():
from .lumina_nextdit2d import LuminaNextDiT2DModel from .lumina_nextdit2d import LuminaNextDiT2DModel
from .pixart_transformer_2d import PixArtTransformer2DModel from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer from .prior_transformer import PriorTransformer
from .sana_transformer import SanaTransformer2DModel
from .stable_audio_transformer import StableAudioDiTModel from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel from .transformer_2d import Transformer2DModel
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention_processor import (
Attention,
AttentionProcessor,
AttnProcessor2_0,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class GLUMBConv(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: float = 4,
norm_type: Optional[str] = None,
residual_connection: bool = True,
) -> None:
super().__init__()
hidden_channels = int(expand_ratio * in_channels)
self.norm_type = norm_type
self.residual_connection = residual_connection
self.nonlinearity = nn.SiLU()
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
self.norm = None
if norm_type == "rms_norm":
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.residual_connection:
residual = hidden_states
hidden_states = self.conv_inverted(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_depth(hidden_states)
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
hidden_states = hidden_states * self.nonlinearity(gate)
hidden_states = self.conv_point(hidden_states)
if self.norm_type == "rms_norm":
# move channel to the last dimension so we apply RMSnorm across channel dimension
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.residual_connection:
hidden_states = hidden_states + residual
return hidden_states
class SanaTransformerBlock(nn.Module):
r"""
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
"""
def __init__(
self,
dim: int = 2240,
num_attention_heads: int = 70,
attention_head_dim: int = 32,
dropout: float = 0.0,
num_cross_attention_heads: Optional[int] = 20,
cross_attention_head_dim: Optional[int] = 112,
cross_attention_dim: Optional[int] = 2240,
attention_bias: bool = True,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
attention_out_bias: bool = True,
mlp_ratio: float = 2.5,
) -> None:
super().__init__()
# 1. Self Attention
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
processor=SanaLinearAttnProcessor2_0(),
)
# 2. Cross Attention
if cross_attention_dim is not None:
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_cross_attention_heads,
dim_head=cross_attention_head_dim,
dropout=dropout,
bias=True,
out_bias=attention_out_bias,
processor=AttnProcessor2_0(),
)
# 3. Feed-forward
self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
height: int = None,
width: int = None,
) -> torch.Tensor:
batch_size = hidden_states.shape[0]
# 1. Modulation
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
# 2. Self Attention
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
attn_output = self.attn1(norm_hidden_states)
hidden_states = hidden_states + gate_msa * attn_output
# 3. Cross Attention
if self.attn2 is not None:
attn_output = self.attn2(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
ff_output = self.ff(norm_hidden_states)
ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
hidden_states = hidden_states + gate_mlp * ff_output
return hidden_states
class SanaTransformer2DModel(ModelMixin, ConfigMixin):
r"""
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
Args:
in_channels (`int`, defaults to `32`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `32`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `70`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `32`):
The number of channels in each head.
num_layers (`int`, defaults to `20`):
The number of layers of Transformer blocks to use.
num_cross_attention_heads (`int`, *optional*, defaults to `20`):
The number of heads to use for cross-attention.
cross_attention_head_dim (`int`, *optional*, defaults to `112`):
The number of channels in each head for cross-attention.
cross_attention_dim (`int`, *optional*, defaults to `2240`):
The number of channels in the cross-attention output.
caption_channels (`int`, defaults to `2304`):
The number of channels in the caption embeddings.
mlp_ratio (`float`, defaults to `2.5`):
The expansion ratio to use in the GLUMBConv layer.
dropout (`float`, defaults to `0.0`):
The dropout probability.
attention_bias (`bool`, defaults to `False`):
Whether to use bias in the attention layer.
sample_size (`int`, defaults to `32`):
The base size of the input latent.
patch_size (`int`, defaults to `1`):
The size of the patches to use in the patch embedding layer.
norm_elementwise_affine (`bool`, defaults to `False`):
Whether to use elementwise affinity in the normalization layer.
norm_eps (`float`, defaults to `1e-6`):
The epsilon value for the normalization layer.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
@register_to_config
def __init__(
self,
in_channels: int = 32,
out_channels: Optional[int] = 32,
num_attention_heads: int = 70,
attention_head_dim: int = 32,
num_layers: int = 20,
num_cross_attention_heads: Optional[int] = 20,
cross_attention_head_dim: Optional[int] = 112,
cross_attention_dim: Optional[int] = 2240,
caption_channels: int = 2304,
mlp_ratio: float = 2.5,
dropout: float = 0.0,
attention_bias: bool = False,
sample_size: int = 32,
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
inner_dim = num_attention_heads * attention_head_dim
# 1. Patch Embedding
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=None,
pos_embed_type=None,
)
# 2. Additional condition embeddings
self.time_embed = AdaLayerNormSingle(inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
SanaTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
num_cross_attention_heads=num_cross_attention_heads,
cross_attention_head_dim=cross_attention_head_dim,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
)
for _ in range(num_layers)
]
)
# 4. Output blocks
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size, num_channels, height, width = hidden_states.shape
p = self.config.patch_size
post_patch_height, post_patch_width = height // p, width // p
hidden_states = self.patch_embed(hidden_states)
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
# 2. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
for block in self.transformer_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
post_patch_height,
post_patch_width,
**ckpt_kwargs,
)
else:
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
post_patch_height,
post_patch_width,
)
# 3. Normalization
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# 4. Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
hidden_states = hidden_states.reshape(
batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
)
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
...@@ -185,6 +185,7 @@ else: ...@@ -185,6 +185,7 @@ else:
"StableDiffusionXLControlNetPAGPipeline", "StableDiffusionXLControlNetPAGPipeline",
"StableDiffusionXLPAGImg2ImgPipeline", "StableDiffusionXLPAGImg2ImgPipeline",
"PixArtSigmaPAGPipeline", "PixArtSigmaPAGPipeline",
"SanaPAGPipeline",
] ]
) )
_import_structure["controlnet_xs"].extend( _import_structure["controlnet_xs"].extend(
...@@ -263,6 +264,7 @@ else: ...@@ -263,6 +264,7 @@ else:
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"] _import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
_import_structure["sana"] = ["SanaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_audio"] = [ _import_structure["stable_audio"] = [
...@@ -599,6 +601,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -599,6 +601,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTPAGPipeline, HunyuanDiTPAGPipeline,
KolorsPAGPipeline, KolorsPAGPipeline,
PixArtSigmaPAGPipeline, PixArtSigmaPAGPipeline,
SanaPAGPipeline,
StableDiffusion3PAGImg2ImgPipeline, StableDiffusion3PAGImg2ImgPipeline,
StableDiffusion3PAGPipeline, StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline, StableDiffusionControlNetPAGInpaintPipeline,
...@@ -615,6 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -615,6 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .sana import SanaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
......
...@@ -29,6 +29,7 @@ else: ...@@ -29,6 +29,7 @@ else:
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"] _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
_import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"] _import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"] _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sana"] = ["SanaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"] _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"] _import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
_import_structure["pipeline_pag_sd_3_img2img"] = ["StableDiffusion3PAGImg2ImgPipeline"] _import_structure["pipeline_pag_sd_3_img2img"] = ["StableDiffusion3PAGImg2ImgPipeline"]
...@@ -55,6 +56,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -55,6 +56,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
from .pipeline_pag_kolors import KolorsPAGPipeline from .pipeline_pag_kolors import KolorsPAGPipeline
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sana import SanaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
from .pipeline_pag_sd_3_img2img import StableDiffusion3PAGImg2ImgPipeline from .pipeline_pag_sd_3_img2img import StableDiffusion3PAGImg2ImgPipeline
......
This diff is collapsed.
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_sana"] = ["SanaPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_sana import SanaPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class SanaPipelineOutput(BaseOutput):
"""
Output class for Sana pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
This diff is collapsed.
...@@ -149,6 +149,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -149,6 +149,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace", timestep_spacing: str = "linspace",
steps_offset: int = 0, steps_offset: int = 0,
): ):
...@@ -282,6 +284,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -282,6 +284,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
timesteps = (sigmas * self.config.num_train_timesteps).copy()
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
...@@ -362,8 +369,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -362,8 +369,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma): def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5) if self.config.use_flow_sigmas:
sigma_t = sigma * alpha_t alpha_t = 1 - sigma
sigma_t = sigma
else:
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t return alpha_t, sigma_t
...@@ -490,10 +501,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -490,10 +501,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
x0_pred = model_output x0_pred = model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
" `v_prediction` for the DEISMultistepScheduler." "`v_prediction`, or `flow_prediction` for the DEISMultistepScheduler."
) )
if self.config.thresholding: if self.config.thresholding:
......
...@@ -218,6 +218,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -218,6 +218,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
use_exponential_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False, use_lu_lambdas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"), lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None, variance_type: Optional[str] = None,
...@@ -407,6 +409,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -407,6 +409,11 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.flip(sigmas).copy() sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
timesteps = (sigmas * self.config.num_train_timesteps).copy()
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
...@@ -495,8 +502,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -495,8 +502,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return t return t
def _sigma_to_alpha_sigma_t(self, sigma): def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5) if self.config.use_flow_sigmas:
sigma_t = sigma * alpha_t alpha_t = 1 - sigma
sigma_t = sigma
else:
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t return alpha_t, sigma_t
...@@ -650,10 +661,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): ...@@ -650,10 +661,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma = self.sigmas[self.step_index] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
" `v_prediction` for the DPMSolverMultistepScheduler." "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
) )
if self.config.thresholding: if self.config.thresholding:
......
...@@ -169,6 +169,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -169,6 +169,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
lambda_min_clipped: float = -float("inf"), lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None, variance_type: Optional[str] = None,
timestep_spacing: str = "linspace", timestep_spacing: str = "linspace",
...@@ -292,6 +294,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -292,6 +294,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
elif self.config.use_beta_sigmas: elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
timesteps = (sigmas * self.config.num_train_timesteps).copy()
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = ( sigma_max = (
...@@ -379,8 +386,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -379,8 +386,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma): def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5) if self.config.use_flow_sigmas:
sigma_t = sigma * alpha_t alpha_t = 1 - sigma
sigma_t = sigma
else:
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t return alpha_t, sigma_t
...@@ -522,10 +533,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): ...@@ -522,10 +533,13 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
sigma = self.sigmas[self.step_index] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
" `v_prediction` for the DPMSolverMultistepScheduler." "`v_prediction`, or `flow_prediction` for the DPMSolverMultistepScheduler."
) )
if self.config.thresholding: if self.config.thresholding:
......
...@@ -164,6 +164,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -164,6 +164,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas: Optional[bool] = False, use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"), lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None, variance_type: Optional[str] = None,
...@@ -356,6 +358,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -356,6 +358,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.flip(sigmas).copy() sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1]
timesteps = (sigmas * self.config.num_train_timesteps).copy()
else: else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
...@@ -454,8 +461,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -454,8 +461,12 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma): def _sigma_to_alpha_sigma_t(self, sigma):
alpha_t = 1 / ((sigma**2 + 1) ** 0.5) if self.config.use_flow_sigmas:
sigma_t = sigma * alpha_t alpha_t = 1 - sigma
sigma_t = sigma
else:
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t return alpha_t, sigma_t
...@@ -595,10 +606,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): ...@@ -595,10 +606,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigma = self.sigmas[self.step_index] sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
x0_pred = alpha_t * sample - sigma_t * model_output x0_pred = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
sigma_t = self.sigmas[self.step_index]
x0_pred = sample - sigma_t * model_output
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
" `v_prediction` for the DPMSolverSinglestepScheduler." "`v_prediction`, or `flow_prediction` for the DPMSolverSinglestepScheduler."
) )
if self.config.thresholding: if self.config.thresholding:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment