Unverified Commit 98388670 authored by PommesPeter's avatar PommesPeter Committed by GitHub
Browse files

[Alpha-VLLM Team] Add Lumina-T2X to diffusers (#8652)




---------
Co-authored-by: default avatarzhuole1025 <zhuole1025@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 9e9ed353
......@@ -249,6 +249,8 @@
title: DiTTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/sd3_transformer2d
......@@ -324,6 +326,8 @@
title: Latent Diffusion
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
title: Marigold
- local: api/pipelines/panorama
......@@ -435,6 +439,8 @@
title: EulerDiscreteScheduler
- local: api/schedulers/flow_match_euler_discrete
title: FlowMatchEulerDiscreteScheduler
- local: api/schedulers/flow_match_heun_discrete
title: FlowMatchHeunDiscreteScheduler
- local: api/schedulers/heun
title: HeunDiscreteScheduler
- local: api/schedulers/ipndm
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LuminaNextDiT2DModel
A Next Version of Diffusion Transformer model for 2D data from [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X).
## LuminaNextDiT2DModel
[[autodoc]] LuminaNextDiT2DModel
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Lumina-T2X
![concepts](https://github.com/Alpha-VLLM/Lumina-T2X/assets/54879512/9f52eabb-07dc-4881-8257-6d8a5f2a0a5a)
[Lumina-Next : Making Lumina-T2X Stronger and Faster with Next-DiT](https://github.com/Alpha-VLLM/Lumina-T2X/blob/main/assets/lumina-next.pdf) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.
The abstract from the paper is:
*Lumina-T2X is a nascent family of Flow-based Large Diffusion Transformers (Flag-DiT) that establishes a unified framework for transforming noise into various modalities, such as images and videos, conditioned on text instructions. Despite its promising capabilities, Lumina-T2X still encounters challenges including training instability, slow inference, and extrapolation artifacts. In this paper, we present Lumina-Next, an improved version of Lumina-T2X, showcasing stronger generation performance with increased training and inference efficiency. We begin with a comprehensive analysis of the Flag-DiT architecture and identify several suboptimal components, which we address by introducing the Next-DiT architecture with 3D RoPE and sandwich normalizations. To enable better resolution extrapolation, we thoroughly compare different context extrapolation methods applied to text-to-image generation with 3D RoPE, and propose Frequency- and Time-Aware Scaled RoPE tailored for diffusion transformers. Additionally, we introduce a sigmoid time discretization schedule to reduce sampling steps in solving the Flow ODE and the Context Drop method to merge redundant visual tokens for faster network evaluation, effectively boosting the overall sampling speed. Thanks to these improvements, Lumina-Next not only improves the quality and efficiency of basic text-to-image generation but also demonstrates superior resolution extrapolation capabilities and multilingual generation using decoder-based LLMs as the text encoder, all in a zero-shot manner. To further validate Lumina-Next as a versatile generative framework, we instantiate it on diverse tasks including visual recognition, multi-view, audio, music, and point cloud generation, showcasing strong performance across these domains. By releasing all codes and model weights at https://github.com/Alpha-VLLM/Lumina-T2X, we aim to advance the development of next-generation generative AI capable of universal modeling.*
**Highlights**: Lumina-Next is a next-generation Diffusion Transformer that significantly enhances text-to-image generation, multilingual generation, and multitask performance by introducing the Next-DiT architecture, 3D RoPE, and frequency- and time-aware RoPE, among other improvements.
Lumina-Next has the following components:
* It improves sampling efficiency with fewer and faster Steps.
* It uses a Next-DiT as a transformer backbone with Sandwichnorm 3D RoPE, and Grouped-Query Attention.
* It uses a Frequency- and Time-Aware Scaled RoPE.
---
[Lumina-T2X: Transforming Text into Any Modality, Resolution, and Duration via Flow-based Large Diffusion Transformers](https://arxiv.org/abs/2405.05945) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.
The abstract from the paper is:
*Sora unveils the potential of scaling Diffusion Transformer for generating photorealistic images and videos at arbitrary resolutions, aspect ratios, and durations, yet it still lacks sufficient implementation details. In this technical report, we introduce the Lumina-T2X family - a series of Flow-based Large Diffusion Transformers (Flag-DiT) equipped with zero-initialized attention, as a unified framework designed to transform noise into images, videos, multi-view 3D objects, and audio clips conditioned on text instructions. By tokenizing the latent spatial-temporal space and incorporating learnable placeholders such as [nextline] and [nextframe] tokens, Lumina-T2X seamlessly unifies the representations of different modalities across various spatial-temporal resolutions. This unified approach enables training within a single framework for different modalities and allows for flexible generation of multimodal data at any resolution, aspect ratio, and length during inference. Advanced techniques like RoPE, RMSNorm, and flow matching enhance the stability, flexibility, and scalability of Flag-DiT, enabling models of Lumina-T2X to scale up to 7 billion parameters and extend the context window to 128K tokens. This is particularly beneficial for creating ultra-high-definition images with our Lumina-T2I model and long 720p videos with our Lumina-T2V model. Remarkably, Lumina-T2I, powered by a 5-billion-parameter Flag-DiT, requires only 35% of the training computational costs of a 600-million-parameter naive DiT. Our further comprehensive analysis underscores Lumina-T2X's preliminary capability in resolution extrapolation, high-resolution editing, generating consistent 3D views, and synthesizing videos with seamless transitions. We expect that the open-sourcing of Lumina-T2X will further foster creativity, transparency, and diversity in the generative AI community.*
You can find the original codebase at [Alpha-VLLM](https://github.com/Alpha-VLLM/Lumina-T2X) and all the available checkpoints at [Alpha-VLLM Lumina Family](https://huggingface.co/collections/Alpha-VLLM/lumina-family-66423205bedb81171fd0644b).
**Highlights**: Lumina-T2X supports Any Modality, Resolution, and Duration.
Lumina-T2X has the following components:
* It uses a Flow-based Large Diffusion Transformer as the backbone
* It supports different any modalities with one backbone and corresponding encoder, decoder.
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
### Inference (Text-to-Image)
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
First, load the pipeline:
```python
from diffusers import LuminaText2ImgPipeline
import torch
pipeline = LuminaText2ImgPipeline.from_pretrained(
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")
```
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
```
Finally, compile the components and run inference:
```python
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures").images[0]
```
## LuminaText2ImgPipeline
[[autodoc]] LuminaText2ImgPipeline
- all
- __call__
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# FlowMatchHeunDiscreteScheduler
`FlowMatchHeunDiscreteScheduler` is based on the flow-matching sampling introduced in [EDM](https://arxiv.org/abs/2403.03206).
## FlowMatchHeunDiscreteScheduler
[[autodoc]] FlowMatchHeunDiscreteScheduler
import argparse
import os
import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
def main(args):
# checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I
all_sd = load_file(args.origin_ckpt_path, device="cpu")
converted_state_dict = {}
# pad token
converted_state_dict["pad_token"] = all_sd["pad_token"]
# patch embed
converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]
# time and caption embed
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]
for i in range(24):
# adaln
converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]
# qkv
converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]
# cap
converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]
# output
converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]
# attention
# qk norm
converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]
converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]
# attention norm
converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]
# feed forward
converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]
# feed forward norm
converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]
# final layer
converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]
converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]
# Lumina-Next-SFT 2B
transformer = LuminaNextDiT2DModel(
sample_size=128,
patch_size=2,
in_channels=4,
hidden_size=2304,
num_layers=24,
num_attention_heads=32,
num_kv_heads=8,
multiple_of=256,
ffn_dim_multiplier=None,
norm_eps=1e-5,
learn_sigma=True,
qk_norm=True,
cross_attention_dim=2048,
scaling_factor=1.0,
)
transformer.load_state_dict(converted_state_dict, strict=True)
num_model_params = sum(p.numel() for p in transformer.parameters())
print(f"Total number of transformer parameters: {num_model_params}")
if args.only_transformer:
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
else:
scheduler = FlowMatchEulerDiscreteScheduler()
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
pipeline = LuminaText2ImgPipeline(
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
)
pipeline.save_pretrained(args.dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--image_size",
default=1024,
type=int,
choices=[256, 512, 1024],
required=False,
help="Image size of pretrained model, either 512 or 1024.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--only_transformer", default=True, type=bool, required=True)
args = parser.parse_args()
main(args)
......@@ -88,6 +88,7 @@ else:
"HunyuanDiT2DMultiControlNetModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
"LuminaNextDiT2DModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
......@@ -162,6 +163,7 @@ else:
"EulerAncestralDiscreteScheduler",
"EulerDiscreteScheduler",
"FlowMatchEulerDiscreteScheduler",
"FlowMatchHeunDiscreteScheduler",
"HeunDiscreteScheduler",
"IPNDMScheduler",
"KarrasVeScheduler",
......@@ -270,6 +272,7 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
"MusicLDMPipeline",
......@@ -509,6 +512,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DMultiControlNetModel,
I2VGenXLUNet,
Kandinsky3UNet,
LuminaNextDiT2DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
......@@ -580,6 +584,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
......@@ -669,6 +674,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
MusicLDMPipeline,
......
......@@ -41,6 +41,7 @@ if is_torch_available():
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
......@@ -85,6 +86,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
DiTTransformer2DModel,
DualTransformer2DModel,
HunyuanDiT2DModel,
LuminaNextDiT2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,
......
......@@ -19,7 +19,7 @@ from torch import nn
from ..utils import deprecate, logging
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
......@@ -527,6 +527,56 @@ class BasicTransformerBlock(nn.Module):
return hidden_states
class LuminaFeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
hidden_size (`int`):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
of this value.
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
dimension. Defaults to None.
"""
def __init__(
self,
dim: int,
inner_dim: int,
multiple_of: Optional[int] = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
inner_dim = int(2 * inner_dim / 3)
# custom hidden_size factor multiplier
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = nn.Linear(
dim,
inner_dim,
bias=False,
)
self.linear_2 = nn.Linear(
inner_dim,
dim,
bias=False,
)
self.linear_3 = nn.Linear(
dim,
inner_dim,
bias=False,
)
self.silu = FP32SiLU()
def forward(self, x):
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
......
......@@ -94,6 +94,7 @@ class Attention(nn.Module):
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
......@@ -118,6 +119,7 @@ class Attention(nn.Module):
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
......@@ -168,6 +170,10 @@ class Attention(nn.Module):
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
elif qk_norm == "layer_norm_across_heads":
# Lumina applys qk norm across all heads
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
else:
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
......@@ -198,15 +204,15 @@ class Attention(nn.Module):
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
......@@ -1594,6 +1600,102 @@ class HunyuanAttnProcessor2_0:
return hidden_states
class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[torch.Tensor] = None,
key_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
from .embeddings import apply_rotary_emb
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype
# Get key-value heads
kv_heads = inner_dim // head_dim
# Apply Query-Key Norm if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)
# Apply RoPE if needed
if query_rotary_emb is not None:
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
if key_rotary_emb is not None:
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
query, key = query.to(dtype), key.to(dtype)
# Apply proportional attention if true
if key_rotary_emb is None:
softmax_scale = None
else:
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale
# perform Grouped-qurey Attention (GQA)
n_rep = attn.heads // kv_heads
if n_rep >= 1:
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).to(dtype)
return hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
......
......@@ -230,6 +230,52 @@ class PatchEmbed(nn.Module):
return (latent + pos_embed).to(latent.dtype)
class LuminaPatchEmbed(nn.Module):
"""2D Image to Patch Embedding with support for Lumina-T2X"""
def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Linear(
in_features=patch_size * patch_size * in_channels,
out_features=embed_dim,
bias=bias,
)
def forward(self, x, freqs_cis):
"""
Patchifies and embeds the input tensor(s).
Args:
x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
Returns:
Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
frequency tensor(s).
"""
freqs_cis = freqs_cis.to(x[0].device)
patch_height = patch_width = self.patch_size
batch_size, channel, height, width = x.size()
height_tokens, width_tokens = height // patch_height, width // patch_width
x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
0, 2, 4, 1, 3, 5
)
x = x.flatten(3)
x = self.proj(x)
x = x.flatten(1, 2)
mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
return (
x,
mask,
[(height, width)] * batch_size,
freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
)
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
......@@ -274,7 +320,25 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
return emb
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
assert embed_dim % 4 == 0
emb_h = get_1d_rotary_pos_embed(
embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
) # (H, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
) # (W, D/4)
emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
return emb
def get_1d_rotary_pos_embed(
dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
......@@ -289,13 +353,17 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
......@@ -310,6 +378,7 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
def apply_rotary_emb(
x: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
......@@ -325,16 +394,23 @@ def apply_rotary_emb(
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
return out
else:
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
class TimestepEmbedding(nn.Module):
......@@ -778,6 +854,40 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
return conditioning
class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
super().__init__()
self.time_proj = Timesteps(
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
self.caption_embedder = nn.Sequential(
nn.LayerNorm(cross_attention_dim),
nn.Linear(
cross_attention_dim,
hidden_size,
bias=True,
),
)
def forward(self, timestep, caption_feat, caption_mask):
# timestep embedding:
time_freq = self.time_proj(timestep)
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
# caption condition embedding:
caption_mask_float = caption_mask.float().unsqueeze(-1)
caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
caption_feats_pool = caption_feats_pool.to(caption_feat)
caption_embed = self.caption_embedder(caption_feats_pool)
conditioning = time_embed + caption_embed
return conditioning
class TextTimeEmbedding(nn.Module):
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
super().__init__()
......
......@@ -22,7 +22,10 @@ import torch.nn.functional as F
from ..utils import is_torch_version
from .activations import get_activation
from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
from .embeddings import (
CombinedTimestepLabelEmbeddings,
PixArtAlphaCombinedTimestepSizeEmbeddings,
)
class AdaLayerNorm(nn.Module):
......@@ -84,6 +87,37 @@ class AdaLayerNormZero(nn.Module):
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class LuminaRMSNormZero(nn.Module):
"""
Norm layer adaptive RMS normalization zero.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""
def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(
min(embedding_dim, 1024),
4 * embedding_dim,
bias=True,
)
self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
return x, gate_msa, scale_mlp, gate_mlp
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
......@@ -188,6 +222,54 @@ class AdaLayerNormContinuous(nn.Module):
return x
class LuminaLayerNormContinuous(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
out_dim: Optional[int] = None,
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
else:
raise ValueError(f"unknown norm_type {norm_type}")
# linear_2
if out_dim is not None:
self.linear_2 = nn.Linear(
embedding_dim,
out_dim,
bias=bias,
)
def forward(
self,
x: torch.Tensor,
conditioning_embedding: torch.Tensor,
) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
scale = emb
x = self.norm(x) * (1 + scale)[:, None, :]
if self.linear_2 is not None:
x = self.linear_2(x)
return x
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
......
......@@ -5,6 +5,7 @@ if is_torch_available():
from .dit_transformer_2d import DiTTransformer2DModel
from .dual_transformer_2d import DualTransformer2DModel
from .hunyuan_transformer_2d import HunyuanDiT2DModel
from .lumina_nextdit2d import LuminaNextDiT2DModel
from .pixart_transformer_2d import PixArtTransformer2DModel
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
......
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention, LuminaAttnProcessor2_0
from ..embeddings import (
LuminaCombinedTimestepCaptionEmbedding,
LuminaPatchEmbed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LuminaNextDiTBlock(nn.Module):
"""
A LuminaNextDiTBlock for LuminaNextDiT2DModel.
Parameters:
dim (`int`): Embedding dimension of the input features.
num_attention_heads (`int`): Number of attention heads.
num_kv_heads (`int`):
Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
multiple_of (`int`): The number of multiple of ffn layer.
ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension.
norm_eps (`float`): The eps for norm layer.
qk_norm (`bool`): normalization for query and key.
cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
norm_elementwise_affine (`bool`, *optional*, defaults to True),
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
num_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
cross_attention_dim: int,
norm_elementwise_affine: bool = True,
) -> None:
super().__init__()
self.head_dim = dim // num_attention_heads
self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
# Self-attention
self.attn1 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
qk_norm="layer_norm_across_heads" if qk_norm else None,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
out_bias=False,
processor=LuminaAttnProcessor2_0(),
)
self.attn1.to_out = nn.Identity()
# Cross-attention
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
dim_head=dim // num_attention_heads,
qk_norm="layer_norm_across_heads" if qk_norm else None,
heads=num_attention_heads,
kv_heads=num_kv_heads,
eps=1e-5,
bias=False,
out_bias=False,
processor=LuminaAttnProcessor2_0(),
)
self.feed_forward = LuminaFeedForward(
dim=dim,
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
self.norm1 = LuminaRMSNormZero(
embedding_dim=dim,
norm_eps=norm_eps,
norm_elementwise_affine=norm_elementwise_affine,
)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_mask: torch.Tensor,
temb: torch.Tensor,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Perform a forward pass through the LuminaNextDiTBlock.
Parameters:
hidden_states (`torch.Tensor`): The input of hidden_states for LuminaNextDiTBlock.
attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask.
image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder.
encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask.
temb (`torch.Tensor`): Timestep embedding with text prompt embedding.
cross_attention_kwargs (`Dict[str, Any]`): kwargs for cross attention.
"""
residual = hidden_states
# Self-attention
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
self_attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_hidden_states,
attention_mask=attention_mask,
query_rotary_emb=image_rotary_emb,
key_rotary_emb=image_rotary_emb,
**cross_attention_kwargs,
)
# Cross-attention
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
cross_attn_output = self.attn2(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=encoder_mask,
query_rotary_emb=image_rotary_emb,
key_rotary_emb=None,
**cross_attention_kwargs,
)
cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
mixed_attn_output = self_attn_output + cross_attn_output
mixed_attn_output = mixed_attn_output.flatten(-2)
# linear proj
hidden_states = self.attn2.to_out[0](mixed_attn_output)
hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
return hidden_states
class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
"""
LuminaNextDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Parameters:
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
in_channels (`int`, *optional*, defaults to 4):
The number of input channels for the model. Typically, this matches the number of channels in the input
images.
hidden_size (`int`, *optional*, defaults to 4096):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
num_layers (`int`, *optional*, default to 32):
The number of layers in the model. This defines the depth of the neural network.
num_attention_heads (`int`, *optional*, defaults to 32):
The number of attention heads in each attention layer. This parameter specifies how many separate attention
mechanisms are used.
num_kv_heads (`int`, *optional*, defaults to 8):
The number of key-value heads in the attention mechanism, if different from the number of attention heads.
If None, it defaults to num_attention_heads.
multiple_of (`int`, *optional*, defaults to 256):
A factor that the hidden size should be a multiple of. This can help optimize certain hardware
configurations.
ffn_dim_multiplier (`float`, *optional*):
A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
the model configuration.
norm_eps (`float`, *optional*, defaults to 1e-5):
A small value added to the denominator for numerical stability in normalization layers.
learn_sigma (`bool`, *optional*, defaults to True):
Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
predictions.
qk_norm (`bool`, *optional*, defaults to True):
Indicates if the queries and keys in the attention mechanism should be normalized.
cross_attention_dim (`int`, *optional*, defaults to 2048):
The dimensionality of the text embeddings. This parameter defines the size of the text representations used
in the model.
scaling_factor (`float`, *optional*, defaults to 1.0):
A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
overall scale of the model's operations.
"""
@register_to_config
def __init__(
self,
sample_size: int = 128,
patch_size: Optional[int] = 2,
in_channels: Optional[int] = 4,
hidden_size: Optional[int] = 2304,
num_layers: Optional[int] = 32,
num_attention_heads: Optional[int] = 32,
num_kv_heads: Optional[int] = None,
multiple_of: Optional[int] = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: Optional[float] = 1e-5,
learn_sigma: Optional[bool] = True,
qk_norm: Optional[bool] = True,
cross_attention_dim: Optional[int] = 2048,
scaling_factor: Optional[float] = 1.0,
) -> None:
super().__init__()
self.sample_size = sample_size
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.head_dim = hidden_size // num_attention_heads
self.scaling_factor = scaling_factor
self.patch_embedder = LuminaPatchEmbed(
patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
)
self.pad_token = nn.Parameter(torch.empty(hidden_size))
self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
)
self.layers = nn.ModuleList(
[
LuminaNextDiTBlock(
hidden_size,
num_attention_heads,
num_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
cross_attention_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = LuminaLayerNormContinuous(
embedding_dim=hidden_size,
conditioning_embedding_dim=min(hidden_size, 1024),
elementwise_affine=False,
eps=1e-6,
bias=True,
out_dim=patch_size * patch_size * self.out_channels,
)
# self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_mask: torch.Tensor,
image_rotary_emb: torch.Tensor,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict=True,
) -> torch.Tensor:
"""
Forward pass of LuminaNextDiT.
Parameters:
hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
"""
hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
image_rotary_emb = image_rotary_emb.to(hidden_states.device)
temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
encoder_mask = encoder_mask.bool()
for layer in self.layers:
hidden_states = layer(
hidden_states,
mask,
image_rotary_emb,
encoder_hidden_states,
encoder_mask,
temb=temb,
cross_attention_kwargs=cross_attention_kwargs,
)
hidden_states = self.norm_out(hidden_states, temb)
# unpatchify
height_tokens = width_tokens = self.patch_size
height, width = img_size[0]
batch_size = hidden_states.size(0)
sequence_length = (height // height_tokens) * (width // width_tokens)
hidden_states = hidden_states[:, :sequence_length].view(
batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
)
output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
......@@ -209,6 +209,7 @@ else:
"LEditsPPPipelineStableDiffusionXL",
]
)
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
......@@ -486,6 +487,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .lumina import LuminaText2ImgPipeline
from .marigold import (
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_lumina"] = ["LuminaText2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_lumina import LuminaText2ImgPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
This diff is collapsed.
......@@ -57,6 +57,7 @@ else:
_import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"]
_import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"]
_import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
......@@ -153,6 +154,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
from .scheduling_heun_discrete import HeunDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
......
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Heun scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
"""
_compatibles = []
order = 2
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_noise(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sigma * noise + (1.0 - sigma) * sample
return sample
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
self.timesteps = timesteps.to(device=device)
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
# empty dt and derivative
self.prev_derivative = None
self.dt = None
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
@property
def state_in_first_order(self):
return self.dt is None
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
if self.state_in_first_order:
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
else:
# 2nd order / Heun's method
sigma = self.sigmas[self.step_index - 1]
sigma_next = self.sigmas[self.step_index]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
if self.state_in_first_order:
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
denoised = sample - model_output * sigma
# 2. convert to an ODE derivative for 1st order
derivative = (sample - denoised) / sigma_hat
# 3. Delta timestep
dt = sigma_next - sigma_hat
# store for 2nd order step
self.prev_derivative = derivative
self.dt = dt
self.sample = sample
else:
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
denoised = sample - model_output * sigma_next
# 2. 2nd order / Heun's method
derivative = (sample - denoised) / sigma_next
derivative = 0.5 * (self.prev_derivative + derivative)
# 3. take prev timestep & sample
dt = self.dt
sample = self.sample
# free dt and derivative
# Note, this puts the scheduler in "first order mode"
self.prev_derivative = None
self.dt = None
self.sample = None
prev_sample = sample + derivative * dt
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
......@@ -197,6 +197,21 @@ class Kandinsky3UNet(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class LuminaNextDiT2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]
......@@ -1095,6 +1110,21 @@ class FlowMatchEulerDiscreteScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FlowMatchHeunDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HeunDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -722,6 +722,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LuminaText2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class MarigoldDepthPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment