Unverified Commit 96c376a5 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

[core] LTX Video (#10021)



* transformer

* make style & make fix-copies

* transformer

* add transformer tests

* 80% vae

* make style

* make fix-copies

* fix

* undo cogvideox changes

* update

* update

* match vae

* add docs

* t2v pipeline working; scheduler needs to be checked

* docs

* add pipeline test

* update

* update

* make fix-copies

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* copy t2v to i2v pipeline

* update

* apply review suggestions

* update

* make style

* remove framewise encoding/decoding

* pack/unpack latents

* image2video

* update

* make fix-copies

* update

* update

* rope scale fix

* debug layerwise code

* remove debug

* Apply suggestions from code review
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* propagate precision changes to i2v pipeline

* remove downcast

* address review comments

* fix comment

* address review comments

* [Single File] LTX support for loading original weights (#10135)

* from original file mixin for ltx

* undo config mapping fn changes

* update

* add single file to pipelines

* update docs

* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

* Update src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

* rename classes based on ltx review

* point to original repository for inference

* make style

* resolve conflicts correctly

---------
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 8170dc36
......@@ -274,6 +274,8 @@
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/pixart_transformer2d
......@@ -312,6 +314,8 @@
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/asymmetricautoencoderkl
......@@ -408,6 +412,8 @@
title: Latte
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
title: LTX
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
......
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTXVideo
The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTXVideo
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTXVideo
[[autodoc]] AutoencoderKLLTXVideo
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# LTXVideoTransformer3DModel
A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
from diffusers import LTXVideoTransformer3DModel
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTXVideoTransformer3DModel
[[autodoc]] LTXVideoTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## Loading Single Files
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
# ... inference code ...
```
Alternatively, the pipeline can be used to load the weights with [~FromSingleFileMixin.from_single_file`].
```python
import torch
from diffusers import LTXImageToVideoPipeline
from transformers import T5EncoderModel, T5Tokenizer
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16)
tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
```
## LTXPipeline
[[autodoc]] LTXPipeline
- all
- __call__
## LTXImageToVideoPipeline
[[autodoc]] LTXImageToVideoPipeline
- all
- __call__
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
import argparse
from typing import Any, Dict
import torch
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
def remove_keys_(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
TOKENIZER_MAX_LENGTH = 128
TRANSFORMER_KEYS_RENAME_DICT = {
"patchify_proj": "proj_in",
"adaln_single": "time_embed",
"q_norm": "norm_q",
"k_norm": "norm_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VAE_KEYS_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0",
"up_blocks.2": "up_blocks.1.upsamplers.0",
"up_blocks.3": "up_blocks.1",
"up_blocks.4": "up_blocks.2.conv_in",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.conv_in",
"up_blocks.8": "up_blocks.3.upsamplers.0",
"up_blocks.9": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.0.conv_out",
"down_blocks.3": "down_blocks.1",
"down_blocks.4": "down_blocks.1.downsamplers.0",
"down_blocks.5": "down_blocks.1.conv_out",
"down_blocks.6": "down_blocks.2",
"down_blocks.7": "down_blocks.2.downsamplers.0",
"down_blocks.8": "down_blocks.3",
"down_blocks.9": "mid_block",
# common
"conv_shortcut": "conv_shortcut.conv",
"res_blocks": "resnets",
"norm3.norm": "norm3",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
):
PREFIX_KEY = ""
original_state_dict = get_state_dict(load_file(ckpt_path))
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
return transformer
def convert_vae(ckpt_path: str, dtype: torch.dtype):
original_state_dict = get_state_dict(load_file(ckpt_path))
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
parser.add_argument(
"--typecast_text_encoder",
action="store_true",
default=False,
help="Whether or not to apply fp16/bf16 precision to text_encoder",
)
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
args = get_args()
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
if args.transformer_ckpt_path is not None:
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
)
if args.vae_ckpt_path is not None:
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
if args.typecast_text_encoder:
text_encoder = text_encoder.to(dtype=dtype)
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTXPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
......@@ -84,6 +84,7 @@ else:
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLLTXVideo",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
......@@ -104,6 +105,7 @@ else:
"I2VGenXLUNet",
"Kandinsky3UNet",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
......@@ -317,6 +319,8 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXImageToVideoPipeline",
"LTXPipeline",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
......@@ -582,6 +586,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLLTXVideo,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
......@@ -602,6 +607,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
I2VGenXLUNet,
Kandinsky3UNet,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
......@@ -794,6 +800,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXImageToVideoPipeline,
LTXPipeline,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
......
......@@ -28,6 +28,8 @@ from .single_file_utils import (
convert_flux_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
......@@ -83,6 +85,14 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"LTXVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTXVideo": {
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
"default_subfolder": "vae",
},
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
}
......
......@@ -92,6 +92,12 @@ CHECKPOINT_KEY_NAMES = {
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
"ltx-video": [
(
"model.diffusion_model.patchify_proj.weight",
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
),
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
}
......@@ -140,6 +146,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
......@@ -571,6 +578,9 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "flux-schnell"
elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
model_type = "ltx-video"
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
encoder_key = "encoder.project_in.conv.conv.bias"
decoder_key = "decoder.project_in.main.conv.weight"
......@@ -2223,6 +2233,96 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
return converted_state_dict
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key
}
TRANSFORMER_KEYS_RENAME_DICT = {
"model.diffusion_model.": "",
"patchify_proj": "proj_in",
"adaln_single": "time_embed",
"q_norm": "norm_q",
"k_norm": "norm_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
def remove_keys_(key: str, state_dict):
state_dict.pop(key)
VAE_KEYS_RENAME_DICT = {
# common
"vae.": "",
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0",
"up_blocks.2": "up_blocks.1.upsamplers.0",
"up_blocks.3": "up_blocks.1",
"up_blocks.4": "up_blocks.2.conv_in",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.conv_in",
"up_blocks.8": "up_blocks.3.upsamplers.0",
"up_blocks.9": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.0.conv_out",
"down_blocks.3": "down_blocks.1",
"down_blocks.4": "down_blocks.1.downsamplers.0",
"down_blocks.5": "down_blocks.1.conv_out",
"down_blocks.6": "down_blocks.2",
"down_blocks.7": "down_blocks.2.downsamplers.0",
"down_blocks.8": "down_blocks.3",
"down_blocks.9": "mid_block",
# common
"conv_shortcut": "conv_shortcut.conv",
"res_blocks": "resnets",
"norm3.norm": "norm3",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
}
for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
......
......@@ -31,6 +31,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
......@@ -65,6 +66,7 @@ if is_torch_available():
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
......@@ -94,6 +96,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLLTXVideo,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
......@@ -127,6 +130,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel,
HunyuanDiT2DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
......
......@@ -199,12 +199,16 @@ class Attention(nn.Module):
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "layer_norm_across_heads":
# Lumina applys qk norm across all heads
# Lumina applies qk norm across all heads
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "l2":
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
......
......@@ -3,6 +3,7 @@ from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
......
# Copyright 2024 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .vae import DecoderOutput, DiagonalGaussianDistribution
class LTXCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
padding_mode: str = "zeros",
is_causal: bool = True,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.is_causal = is_causal
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
height_pad = self.kernel_size[1] // 2
width_pad = self.kernel_size[2] // 2
padding = (0, height_pad, width_pad)
self.conv = nn.Conv3d(
in_channels,
out_channels,
self.kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
padding=padding,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
time_kernel_size = self.kernel_size[0]
if self.is_causal:
pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1))
hidden_states = torch.concatenate([pad_left, hidden_states], dim=2)
else:
pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1))
hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2)
hidden_states = self.conv(hidden_states)
return hidden_states
class LTXResnetBlock3d(nn.Module):
r"""
A 3D ResNet block used in the LTX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
dropout (`float`, defaults to `0.0`):
Dropout rate.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
elementwise_affine (`bool`, defaults to `False`):
Whether to enable elementwise affinity in the normalization layers.
non_linearity (`str`, defaults to `"swish"`):
Activation function to use.
conv_shortcut (bool, defaults to `False`):
Whether or not to use a convolution shortcut.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
eps: float = 1e-6,
elementwise_affine: bool = False,
non_linearity: str = "swish",
is_causal: bool = True,
):
super().__init__()
out_channels = out_channels or in_channels
self.nonlinearity = get_activation(non_linearity)
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.conv1 = LTXCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.dropout = nn.Dropout(dropout)
self.conv2 = LTXCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)
self.norm3 = None
self.conv_shortcut = None
if in_channels != out_channels:
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
self.conv_shortcut = LTXCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
hidden_states = inputs
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.norm3 is not None:
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
if self.conv_shortcut is not None:
inputs = self.conv_shortcut(inputs)
hidden_states = hidden_states + inputs
return hidden_states
class LTXUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
out_channels = in_channels * stride[0] * stride[1] * stride[2]
self.conv = LTXCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
is_causal=is_causal,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
)
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
return hidden_states
class LTXDownBlock3D(nn.Module):
r"""
Down block used in the LTX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
spatio_temporal_scale (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
Whether or not to downsample across temporal dimension.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
):
super().__init__()
out_channels = out_channels or in_channels
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
)
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList(
[
LTXCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=(2, 2, 2),
is_causal=is_causal,
)
]
)
self.conv_out = None
if in_channels != out_channels:
self.conv_out = LTXResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
else:
hidden_states = resnet(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
if self.conv_out is not None:
hidden_states = self.conv_out(hidden_states)
return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXMidBlock3d(nn.Module):
r"""
A middle block used in the LTX model.
Args:
in_channels (`int`):
Number of input channels.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
num_layers: int = 1,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
is_causal: bool = True,
) -> None:
super().__init__()
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""Forward method of the `LTXMidBlock3D` class."""
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
else:
hidden_states = resnet(hidden_states)
return hidden_states
class LTXUpBlock3d(nn.Module):
r"""
Up block used in the LTX model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
spatio_temporal_scale (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
Whether or not to downsample across temporal dimension.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
):
super().__init__()
out_channels = out_channels or in_channels
self.conv_in = None
if in_channels != out_channels:
self.conv_in = LTXResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
resnets = []
for _ in range(num_layers):
resnets.append(
LTXResnetBlock3d(
in_channels=out_channels,
out_channels=out_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.conv_in is not None:
hidden_states = self.conv_in(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
else:
hidden_states = resnet(hidden_states)
return hidden_states
class LTXEncoder3d(nn.Module):
r"""
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
representation.
Args:
in_channels (`int`, defaults to 3):
Number of input channels.
out_channels (`int`, defaults to 128):
Number of latent channels.
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
The number of output channels for each block.
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
Whether a block should contain spatio-temporal downscaling layers or not.
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
The number of layers per block.
patch_size (`int`, defaults to `4`):
The size of spatial patches.
patch_size_t (`int`, defaults to `1`):
The size of temporal patches.
resnet_norm_eps (`float`, defaults to `1e-6`):
Epsilon value for ResNet normalization layers.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.in_channels = in_channels * patch_size**2
output_channel = block_out_channels[0]
self.conv_in = LTXCausalConv3d(
in_channels=self.in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
is_causal=is_causal,
)
# down blocks
num_block_out_channels = len(block_out_channels)
self.down_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
down_block = LTXDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
self.down_blocks.append(down_block)
# mid block
self.mid_block = LTXMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[-1],
resnet_eps=resnet_norm_eps,
is_causal=is_causal,
)
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXCausalConv3d(
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `LTXEncoder3D` class."""
p = self.patch_size
p_t = self.patch_size_t
batch_size, num_channels, num_frames, height, width = hidden_states.shape
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
hidden_states = hidden_states.reshape(
batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
)
# Thanks for driving me insane with the weird patching order :(
hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
for down_block in self.down_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = self.mid_block(hidden_states)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
last_channel = hidden_states[:, -1:]
last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
hidden_states = torch.cat([hidden_states, last_channel], dim=1)
return hidden_states
class LTXDecoder3d(nn.Module):
r"""
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
in_channels (`int`, defaults to 128):
Number of latent channels.
out_channels (`int`, defaults to 3):
Number of output channels.
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
The number of output channels for each block.
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
Whether a block should contain spatio-temporal upscaling layers or not.
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
The number of layers per block.
patch_size (`int`, defaults to `4`):
The size of spatial patches.
patch_size_t (`int`, defaults to `1`):
The size of temporal patches.
resnet_norm_eps (`float`, defaults to `1e-6`):
Epsilon value for ResNet normalization layers.
is_causal (`bool`, defaults to `False`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
def __init__(
self,
in_channels: int = 128,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
) -> None:
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.out_channels = out_channels * patch_size**2
block_out_channels = tuple(reversed(block_out_channels))
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
layers_per_block = tuple(reversed(layers_per_block))
output_channel = block_out_channels[0]
self.conv_in = LTXCausalConv3d(
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
)
self.mid_block = LTXMidBlock3d(
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
)
# up blocks
num_block_out_channels = len(block_out_channels)
self.up_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
output_channel = block_out_channels[i]
up_block = LTXUpBlock3d(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
self.up_blocks.append(up_block)
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXCausalConv3d(
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
else:
hidden_states = self.mid_block(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
p = self.patch_size
p_t = self.patch_size_t
batch_size, num_channels, num_frames, height, width = hidden_states.shape
hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width)
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return hidden_states
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video).
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Args:
in_channels (`int`, defaults to `3`):
Number of input channels.
out_channels (`int`, defaults to `3`):
Number of output channels.
latent_channels (`int`, defaults to `128`):
Number of latent channels.
block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
The number of output channels for each block.
spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`:
Whether a block should contain spatio-temporal downscaling or not.
layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`):
The number of layers per block.
patch_size (`int`, defaults to `4`):
The size of spatial patches.
patch_size_t (`int`, defaults to `1`):
The size of temporal patches.
resnet_norm_eps (`float`, defaults to `1e-6`):
Epsilon value for ResNet normalization layers.
scaling_factor (`float`, *optional*, defaults to `1.0`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
encoder_causal (`bool`, defaults to `True`):
Whether the encoder should behave causally (future frames depend only on past frames) or not.
decoder_causal (`bool`, defaults to `False`):
Whether the decoder should behave causally (future frames depend only on past frames) or not.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
) -> None:
super().__init__()
self.encoder = LTXEncoder3d(
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
is_causal=encoder_causal,
)
self.decoder = LTXDecoder3d(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
is_causal=decoder_causal,
)
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
latents_std = torch.ones((latent_channels,), requires_grad=False)
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
# at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
self.use_framewise_encoding = False
self.use_framewise_decoding = False
# This can be configured based on the amount of GPU memory available.
# `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
# Setting it to higher values results in higher memory usage.
self.num_sample_frames_batch_size = 16
self.num_latent_frames_batch_size = 2
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 512
self.tile_sample_min_width = 512
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
module.gradient_checkpointing = value
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
batch_size, num_channels, num_frames, height, width = x.shape
latent_height = height // self.spatial_compression_ratio
latent_width = width // self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = tile_latent_min_height - tile_latent_stride_height
blend_width = tile_latent_min_width - tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, self.tile_sample_stride_height):
row = []
for j in range(0, width, self.tile_sample_stride_width):
if self.use_framewise_encoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
)
row.append(time)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
batch_size, num_channels, num_frames, height, width = z.shape
sample_height = height * self.spatial_compression_ratio
sample_width = width * self.spatial_compression_ratio
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
raise NotImplementedError(
"Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to "
"quality issues caused by splitting inference across frame dimension. If you believe this "
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
row.append(time)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[torch.Tensor, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
if not return_dict:
return (dec,)
return dec
......@@ -17,6 +17,7 @@ if is_torch_available():
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
# Copyright 2024 The Genmo team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LTXAttentionProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LTXRotaryPosEmbed(nn.Module):
def __init__(
self,
dim: int,
base_num_frames: int = 20,
base_height: int = 2048,
base_width: int = 2048,
patch_size: int = 1,
patch_size_t: int = 1,
theta: float = 10000.0,
) -> None:
super().__init__()
self.dim = dim
self.base_num_frames = base_num_frames
self.base_height = base_height
self.base_width = base_width
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.theta = theta
def forward(
self,
hidden_states: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
# Always compute rope in fp32
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if rope_interpolation_scale is not None:
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
grid = grid.flatten(2, 4).transpose(1, 2)
start = 1.0
end = self.theta
freqs = self.theta ** torch.linspace(
math.log(start, self.theta),
math.log(end, self.theta),
self.dim // 6,
device=hidden_states.device,
dtype=torch.float32,
)
freqs = freqs * math.pi / 2.0
freqs = freqs * (grid.unsqueeze(-1) * 2 - 1)
freqs = freqs.transpose(-1, -2).flatten(2)
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % 6 != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6])
sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
return cos_freqs, sin_freqs
@maybe_allow_in_graph
class LTXTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
qk_norm (`str`, defaults to `"rms_norm"`):
The normalization layer to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: int,
qk_norm: str = "rms_norm_across_heads",
activation_fn: str = "gelu-approximate",
attention_bias: bool = True,
attention_out_bias: bool = True,
eps: float = 1e-6,
elementwise_affine: bool = False,
):
super().__init__()
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
)
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
processor=LTXAttentionProcessor2_0(),
)
self.ff = FeedForward(dim, activation_fn=activation_fn)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size = hidden_states.size(0)
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa
attn_hidden_states = self.attn2(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp
return hidden_states
@maybe_allow_in_graph
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
Args:
in_channels (`int`, defaults to `128`):
The number of channels in the input.
out_channels (`int`, defaults to `128`):
The number of channels in the output.
patch_size (`int`, defaults to `1`):
The size of the spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of the tmeporal patches to use in the patch embedding layer.
num_attention_heads (`int`, defaults to `32`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
cross_attention_dim (`int`, defaults to `2048 `):
The number of channels for cross attention heads.
num_layers (`int`, defaults to `28`):
The number of layers of Transformer blocks to use.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
The normalization layer to use.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 128,
out_channels: int = 128,
patch_size: int = 1,
patch_size_t: int = 1,
num_attention_heads: int = 32,
attention_head_dim: int = 64,
cross_attention_dim: int = 2048,
num_layers: int = 28,
activation_fn: str = "gelu-approximate",
qk_norm: str = "rms_norm_across_heads",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = 4096,
attention_bias: bool = True,
attention_out_bias: bool = True,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
inner_dim = num_attention_heads * attention_head_dim
self.proj_in = nn.Linear(in_channels, inner_dim)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.rope = LTXRotaryPosEmbed(
dim=inner_dim,
base_num_frames=20,
base_height=2048,
base_width=2048,
patch_size=patch_size,
patch_size_t=patch_size_t,
theta=10000.0,
)
self.transformer_blocks = nn.ModuleList(
[
LTXTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
qk_norm=qk_norm,
activation_fn=activation_fn,
attention_bias=attention_bias,
attention_out_bias=attention_out_bias,
eps=norm_eps,
elementwise_affine=norm_elementwise_affine,
)
for _ in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
return_dict: bool = True,
) -> torch.Tensor:
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
batch_size = hidden_states.size(0)
hidden_states = self.proj_in(hidden_states)
temb, embedded_timestep = self.time_embed(
timestep.flatten(),
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(batch_size, -1, temb.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def apply_rotary_emb(x, freqs):
cos, sin = freqs
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
......@@ -250,6 +250,7 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
_import_structure["marigold"].extend(
[
......@@ -585,6 +586,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaText2ImgPipeline
from .marigold import (
MarigoldDepthPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKLLTXVideo
from ...models.transformers import LTXVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import LTXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import LTXPipeline
>>> from diffusers.utils import export_to_video
>>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> video = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=704,
... height=480,
... num_frames=161,
... num_inference_steps=50,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class LTXPipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
Pipeline for text-to-video generation.
Reference: https://github.com/Lightricks/LTX-Video
Args:
transformer ([`LTXVideoTransformer3DModel`]):
Conditional Transformer architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLLTXVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLLTXVideo,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: LTXVideoTransformer3DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32
self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8
self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1
self.transformer_temporal_patch_size = (
self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128
)
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
def check_inputs(
self,
prompt,
height,
width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
@staticmethod
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
batch_size, num_channels, num_frames, height, width = latents.shape
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size
latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
@staticmethod
def _unpack_latents(
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
# what happens in the `_pack_latents` method.
batch_size = latents.size(0)
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
@staticmethod
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
def prepare_latents(
self,
batch_size: int = 1,
num_channels_latents: int = 128,
height: int = 512,
width: int = 704,
num_frames: int = 161,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
height = height // self.vae_spatial_compression_ratio
width = width // self.vae_spatial_compression_ratio
num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
shape = (batch_size, num_channels_latents, num_frames, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 704,
num_frames: int = 161,
frame_rate: int = 25,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 128,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `512`):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, defaults to `704`):
The width in pixels of the generated image. This is set to 848 by default for the best results.
num_frames (`int`, defaults to `161`):
The number of video frames to generate
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `3 `):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to `128 `):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Prepare text embeddings
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
video = latents
else:
latents = self._unpack_latents(
latents,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKLLTXVideo
from ...models.transformers import LTXVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import LTXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import LTXImageToVideoPipeline
>>> from diffusers.utils import export_to_video, load_image
>>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> image = load_image(
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
... )
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> video = pipe(
... image=image,
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=704,
... height=480,
... num_frames=161,
... num_inference_steps=50,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
Pipeline for image-to-video generation.
Reference: https://github.com/Lightricks/LTX-Video
Args:
transformer ([`LTXVideoTransformer3DModel`]):
Conditional Transformer architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLLTXVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLLTXVideo,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: LTXVideoTransformer3DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32
self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8
self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1
self.transformer_temporal_patch_size = (
self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128
)
self.default_height = 512
self.default_width = 704
self.default_frames = 121
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
batch_size, num_channels, num_frames, height, width = latents.shape
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size
latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
def _unpack_latents(
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
# what happens in the `_pack_latents` method.
batch_size = latents.size(0)
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
def prepare_latents(
self,
image: Optional[torch.Tensor] = None,
batch_size: int = 1,
num_channels_latents: int = 128,
height: int = 512,
width: int = 704,
num_frames: int = 161,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
height = height // self.vae_spatial_compression_ratio
width = width // self.vae_spatial_compression_ratio
num_frames = (
(num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
)
shape = (batch_size, num_channels_latents, num_frames, height, width)
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(shape)
conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
return latents.to(device=device, dtype=dtype), conditioning_mask
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i])
for i in range(batch_size)
]
else:
init_latents = [
retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image
]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
conditioning_mask[:, :, 0] = 1.0
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
return latents, conditioning_mask
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput = None,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 704,
num_frames: int = 161,
frame_rate: int = 25,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 128,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PipelineImageInput`):
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `512`):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, defaults to `704`):
The width in pixels of the generated image. This is set to 848 by default for the best results.
num_frames (`int`, defaults to `161`):
The number of video frames to generate
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `3 `):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to `128 `):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Prepare text embeddings
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare latent variables
if latents is None:
image = self.video_processor.preprocess(image, height=height, width=width)
image = image.to(device=device, dtype=prompt_embeds.dtype)
num_channels_latents = self.transformer.config.in_channels
latents, conditioning_mask = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
if self.do_classifier_free_guidance:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
timestep, _ = timestep.chunk(2)
# compute the previous noisy sample x_t -> x_t-1
noise_pred = self._unpack_latents(
noise_pred,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._unpack_latents(
latents,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
noise_pred = noise_pred[:, :, 1:]
noise_latents = latents[:, :, 1:]
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
video = latents
else:
latents = self._unpack_latents(
latents,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class LTXPipelineOutput(BaseOutput):
r"""
Output class for LTX pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
......@@ -75,6 +75,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
base_image_seq_len: Optional[int] = 256,
max_image_seq_len: Optional[int] = 4096,
invert_sigmas: bool = False,
shift_terminal: Optional[float] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
......@@ -181,6 +182,27 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
r"""
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
value.
Reference:
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
Args:
t (`torch.Tensor`):
A tensor of timesteps to be stretched and shifted.
Returns:
`torch.Tensor`:
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
"""
one_minus_z = 1 - t
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
stretched_t = 1 - (one_minus_z / scale_factor)
return stretched_t
def set_timesteps(
self,
num_inference_steps: int = None,
......@@ -216,6 +238,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
else:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment