"graphbolt/vscode:/vscode.git/clone" did not exist on "969276eb1da99d40ddb18daf075c9576212ba2eb"
Unverified commit 4b557132 authored by Aryan, committed by GitHub

[core] LTX Video 0.9.1 (#10330)

* update

* make style

* update

* update

* update

* make style

* single file related changes

* update

* fix

* update single file urls and docs

* update

* fix
parent 851dfa30
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX
# LTX Video
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide models for both text-to-video and image + text-to-video use cases.
......@@ -22,14 +22,24 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
</Tip>
Available models:
| Model name | Recommended dtype |
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be `torch.float32`, `torch.bfloat16`, or `torch.float16`, but the recommended dtype is `torch.bfloat16`, as used in the original repository.
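For example, here is a minimal, hedged sketch of loading the transformer in the recommended `torch.bfloat16` while the other components are loaded explicitly (the repo id and subfolder names assume the standard diffusers layout and are not taken from this PR):

```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXPipeline, LTXVideoTransformer3DModel
from transformers import T5EncoderModel

# Sketch only: the transformer uses the recommended torch.bfloat16; the VAE and
# text encoder could also be loaded in float16/float32 if desired.
repo_id = "Lightricks/LTX-Video"  # assumed diffusers-format repository
transformer = LTXVideoTransformer3DModel.from_pretrained(
    repo_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
vae = AutoencoderKLLTXVideo.from_pretrained(repo_id, subfolder="vae", torch_dtype=torch.bfloat16)
text_encoder = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
pipe = LTXPipeline.from_pretrained(
    repo_id, transformer=transformer, vae=vae, text_encoder=text_encoder, torch_dtype=torch.bfloat16
)
```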
## Loading Single Files
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the single-file format in the future.
```python
import torch
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
transformer = LTXVideoTransformer3DModel.from_single_file(
single_file_url, torch_dtype=torch.bfloat16
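)
# Hedged completion of the truncated snippet above: the VAE can be loaded from the same
# single-file checkpoint, and the remaining components from the hub repo. The repo id used
# for the non-single-file components is an assumption.
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
pipe = LTXImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
)
```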
......@@ -99,6 +109,34 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24)
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
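As a rough, hedged sketch of what GGUF-quantized loading can look like (the quantized checkpoint URL below is an assumption; substitute a real GGUF file):

```python
import torch
from diffusers import GGUFQuantizationConfig, LTXPipeline, LTXVideoTransformer3DModel

# Assumed community GGUF checkpoint, for illustration only.
ckpt_url = "https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
transformer = LTXVideoTransformer3DModel.from_single_file(
    ckpt_url,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, torch_dtype=torch.bfloat16
)
```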
<!-- TODO(aryan): Update this when official weights are supported -->
Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video
pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
width=768,
height=512,
num_frames=161,
decode_timestep=0.03,
decode_noise_scale=0.025,
num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
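As a quick, hedged sketch (method availability and actual savings depend on your diffusers version), CPU offloading and VAE tiling are the usual first steps:

```python
import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)

# Keep each component on the GPU only while it is in use.
pipe.enable_model_cpu_offload()
# Decode the video in spatial tiles to lower peak VRAM during VAE decoding.
pipe.vae.enable_tiling()
```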
## LTXPipeline
......
import argparse
from pathlib import Path
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
......@@ -21,7 +23,9 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"k_norm": "norm_k",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"vae": remove_keys_,
}
VAE_KEYS_RENAME_DICT = {
# decoder
......@@ -54,10 +58,31 @@ VAE_KEYS_RENAME_DICT = {
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_091_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"model.diffusion_model": remove_keys_,
}
VAE_091_SPECIAL_KEYS_REMAP = {
"timestep_scale_multiplier": remove_keys_,
}
......@@ -80,13 +105,16 @@ def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
):
PREFIX_KEY = ""
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
with init_empty_weights():
transformer = LTXVideoTransformer3DModel()
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
......@@ -97,16 +125,21 @@ def convert_transformer(
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae(ckpt_path: str, dtype: torch.dtype):
def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
PREFIX_KEY = "vae."
original_state_dict = get_state_dict(load_file(ckpt_path))
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
with init_empty_weights():
vae = AutoencoderKLLTXVideo(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = key[len(PREFIX_KEY) :]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
......@@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype):
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_vae_config(version: str) -> Dict[str, Any]:
if version == "0.9.0":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"decoder_block_out_channels": (128, 256, 512, 512),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (4, 3, 3, 3, 4),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"timestep_conditioning": False,
}
elif version == "0.9.1":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (5, 6, 7, 8),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
return config
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
......@@ -139,6 +222,9 @@ def get_args():
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
)
return parser.parse_args()
......@@ -161,6 +247,7 @@ if __name__ == "__main__":
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
output_path = Path(args.output_path)
if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
......@@ -169,13 +256,14 @@ if __name__ == "__main__":
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
)
if args.vae_ckpt_path is not None:
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
config = get_vae_config(args.version)
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
......
......@@ -157,7 +157,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
......@@ -605,7 +606,10 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-schnell"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
model_type = "ltx-video"
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
model_type = "ltx-video-0.9.1"
else:
model_type = "ltx-video"
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
encoder_key = "encoder.project_in.conv.conv.bias"
......@@ -2338,12 +2342,32 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
"per_channel_statistics.std-of-means": "latents_std",
}
VAE_091_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"timestep_scale_multiplier": remove_keys_,
}
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
......
......@@ -22,6 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
......@@ -109,7 +110,9 @@ class LTXVideoResnetBlock3d(nn.Module):
elementwise_affine: bool = False,
non_linearity: str = "swish",
is_causal: bool = True,
):
inject_noise: bool = False,
timestep_conditioning: bool = False,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
......@@ -135,18 +138,54 @@ class LTXVideoResnetBlock3d(nn.Module):
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
self.per_channel_scale1 = None
self.per_channel_scale2 = None
if inject_noise:
self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1))
self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1))
self.scale_shift_table = None
if timestep_conditioning:
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
def forward(
self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
hidden_states = inputs
hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.scale_shift_table is not None:
temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None]
shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1)
hidden_states = hidden_states * (1 + scale_1) + shift_1
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.per_channel_scale1 is not None:
spatial_shape = hidden_states.shape[-2:]
spatial_noise = torch.randn(
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
)[None]
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...]
hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.scale_shift_table is not None:
hidden_states = hidden_states * (1 + scale_2) + shift_2
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.per_channel_scale2 is not None:
spatial_shape = hidden_states.shape[-2:]
spatial_noise = torch.randn(
spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype
)[None]
hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...]
if self.norm3 is not None:
inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1)
......@@ -163,12 +202,16 @@ class LTXVideoUpsampler3d(nn.Module):
in_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
self.residual = residual
self.upscale_factor = upscale_factor
out_channels = in_channels * stride[0] * stride[1] * stride[2]
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
......@@ -181,6 +224,15 @@ class LTXVideoUpsampler3d(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.residual:
residual = hidden_states.reshape(
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
)
residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor
residual = residual.repeat(1, repeats, 1, 1, 1)
residual = residual[:, :, self.stride[0] - 1 :]
hidden_states = self.conv(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
......@@ -188,6 +240,9 @@ class LTXVideoUpsampler3d(nn.Module):
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3)
hidden_states = hidden_states[:, :, self.stride[0] - 1 :]
if self.residual:
hidden_states = hidden_states + residual
return hidden_states
......@@ -273,7 +328,12 @@ class LTXVideoDownBlock3D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
......@@ -285,16 +345,18 @@ class LTXVideoDownBlock3D(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
if self.conv_out is not None:
hidden_states = self.conv_out(hidden_states)
hidden_states = self.conv_out(hidden_states, temb, generator)
return hidden_states
......@@ -329,9 +391,15 @@ class LTXVideoMidBlock3d(nn.Module):
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
) -> None:
super().__init__()
self.time_embedder = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
resnets = []
for _ in range(num_layers):
resnets.append(
......@@ -342,15 +410,32 @@ class LTXVideoMidBlock3d(nn.Module):
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXMidBlock3D` class."""
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
......@@ -360,9 +445,11 @@ class LTXVideoMidBlock3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
return hidden_states
......@@ -403,11 +490,19 @@ class LTXVideoUpBlock3d(nn.Module):
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
upsample_residual: bool = False,
upscale_factor: int = 1,
):
super().__init__()
out_channels = out_channels or in_channels
self.time_embedder = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0)
self.conv_in = None
if in_channels != out_channels:
self.conv_in = LTXVideoResnetBlock3d(
......@@ -417,11 +512,23 @@ class LTXVideoUpBlock3d(nn.Module):
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
self.upsamplers = nn.ModuleList(
[
LTXVideoUpsampler3d(
out_channels * upscale_factor,
stride=(2, 2, 2),
is_causal=is_causal,
residual=upsample_residual,
upscale_factor=upscale_factor,
)
]
)
resnets = []
for _ in range(num_layers):
......@@ -433,15 +540,32 @@ class LTXVideoUpBlock3d(nn.Module):
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
)
)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
if self.conv_in is not None:
hidden_states = self.conv_in(hidden_states)
hidden_states = self.conv_in(hidden_states, temb, generator)
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
......@@ -456,9 +580,11 @@ class LTXVideoUpBlock3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, generator
)
else:
hidden_states = resnet(hidden_states)
hidden_states = resnet(hidden_states, temb, generator)
return hidden_states
......@@ -623,6 +749,8 @@ class LTXVideoDecoder3d(nn.Module):
Epsilon value for ResNet normalization layers.
is_causal (`bool`, defaults to `False`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
timestep_conditioning (`bool`, defaults to `False`):
Whether to condition the model on timesteps.
"""
def __init__(
......@@ -636,6 +764,10 @@ class LTXVideoDecoder3d(nn.Module):
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
inject_noise: Tuple[bool, ...] = (False, False, False, False),
timestep_conditioning: bool = False,
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
) -> None:
super().__init__()
......@@ -646,6 +778,9 @@ class LTXVideoDecoder3d(nn.Module):
block_out_channels = tuple(reversed(block_out_channels))
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
layers_per_block = tuple(reversed(layers_per_block))
inject_noise = tuple(reversed(inject_noise))
upsample_residual = tuple(reversed(upsample_residual))
upsample_factor = tuple(reversed(upsample_factor))
output_channel = block_out_channels[0]
self.conv_in = LTXVideoCausalConv3d(
......@@ -653,15 +788,20 @@ class LTXVideoDecoder3d(nn.Module):
)
self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
in_channels=output_channel,
num_layers=layers_per_block[0],
resnet_eps=resnet_norm_eps,
is_causal=is_causal,
inject_noise=inject_noise[0],
timestep_conditioning=timestep_conditioning,
)
# up blocks
num_block_out_channels = len(block_out_channels)
self.up_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
output_channel = block_out_channels[i]
input_channel = output_channel // upsample_factor[i]
output_channel = block_out_channels[i] // upsample_factor[i]
up_block = LTXVideoUpBlock3d(
in_channels=input_channel,
......@@ -670,6 +810,10 @@ class LTXVideoDecoder3d(nn.Module):
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
upscale_factor=upsample_factor[i],
)
self.up_blocks.append(up_block)
......@@ -681,9 +825,16 @@ class LTXVideoDecoder3d(nn.Module):
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
)
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
if timestep_conditioning:
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
......@@ -694,17 +845,33 @@ class LTXVideoDecoder3d(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb
)
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb)
else:
hidden_states = self.mid_block(hidden_states)
hidden_states = self.mid_block(hidden_states, temb)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
hidden_states = up_block(hidden_states, temb)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
if self.time_embedder is not None:
temb = self.time_embedder(
timestep=temb.flatten(),
resolution=None,
aspect_ratio=None,
batch_size=hidden_states.size(0),
hidden_dtype=hidden_states.dtype,
)
temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1))
temb = temb + self.scale_shift_table[None, ..., None, None, None]
shift, scale = temb.unbind(dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
......@@ -767,8 +934,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
......@@ -792,13 +966,17 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.decoder = LTXVideoDecoder3d(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
block_out_channels=decoder_block_out_channels,
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
layers_per_block=decoder_layers_per_block,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
is_causal=decoder_causal,
timestep_conditioning=timestep_conditioning,
inject_noise=decoder_inject_noise,
upsample_residual=upsample_residual,
upsample_factor=upsample_factor,
)
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
......@@ -937,13 +1115,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def _decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
return self.tiled_decode(z, temb, return_dict=return_dict)
if self.use_framewise_decoding:
# TODO(aryan): requires investigation
......@@ -953,7 +1133,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
dec = self.decoder(z)
dec = self.decoder(z, temb)
if not return_dict:
return (dec,)
......@@ -961,7 +1141,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
......@@ -976,10 +1158,15 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
if temb is not None:
decoded_slices = [
self._decode(z_slice, t_slice).sample for z_slice, t_slice in zip(z.split(1), temb.split(1))
]
else:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
decoded = self._decode(z, temb).sample
if not return_dict:
return (decoded,)
......@@ -1061,7 +1248,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
def tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
......@@ -1102,7 +1291,9 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls."
)
else:
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width])
time = self.decoder(
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb
)
row.append(time)
rows.append(row)
......@@ -1130,6 +1321,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def forward(
self,
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
......@@ -1140,7 +1332,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
dec = self.decode(z, temb)
if not return_dict:
return (dec,)
return dec
......@@ -511,6 +511,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
......@@ -563,6 +565,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -753,7 +759,25 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
video = self.vae.decode(latents, return_dict=False)[0]
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
......
......@@ -571,6 +571,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
......@@ -625,6 +627,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
......@@ -849,7 +855,25 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
video = self.vae.decode(latents, return_dict=False)[0]
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
......
......@@ -52,10 +52,19 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
transformer_cls = LTXVideoTransformer3DModel
vae_kwargs = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_block_out_channels": (8, 8, 8, 8),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_spatio_temporal_scaling": (True, True, False, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import AutoencoderKLLTXVideo
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (8, 8, 8, 8),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_spatio_temporal_scaling": (True, True, False, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTXVideoEncoder3d",
"LTXVideoDecoder3d",
"LTXVideoDownBlock3D",
"LTXVideoMidBlock3d",
"LTXVideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (16, 32, 64),
"layers_per_block": (1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
timestep = torch.tensor([0.05] * batch_size, device=torch_device)
return {"sample": image, "temb": timestep}
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTXVideoEncoder3d",
"LTXVideoDecoder3d",
"LTXVideoDownBlock3D",
"LTXVideoMidBlock3d",
"LTXVideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
......@@ -63,10 +63,19 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
torch.manual_seed(0)
vae = AutoencoderKLLTXVideo(
in_channels=3,
out_channels=3,
latent_channels=8,
block_out_channels=(8, 8, 8, 8),
spatio_temporal_scaling=(True, True, False, False),
decoder_block_out_channels=(8, 8, 8, 8),
layers_per_block=(1, 1, 1, 1, 1),
decoder_layers_per_block=(1, 1, 1, 1, 1),
spatio_temporal_scaling=(True, True, False, False),
decoder_spatio_temporal_scaling=(True, True, False, False),
decoder_inject_noise=(False, False, False, False, False),
upsample_residual=(False, False, False, False),
upsample_factor=(1, 1, 1, 1),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
......
......@@ -68,10 +68,19 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
torch.manual_seed(0)
vae = AutoencoderKLLTXVideo(
in_channels=3,
out_channels=3,
latent_channels=8,
block_out_channels=(8, 8, 8, 8),
spatio_temporal_scaling=(True, True, False, False),
decoder_block_out_channels=(8, 8, 8, 8),
layers_per_block=(1, 1, 1, 1, 1),
decoder_layers_per_block=(1, 1, 1, 1, 1),
spatio_temporal_scaling=(True, True, False, False),
decoder_spatio_temporal_scaling=(True, True, False, False),
decoder_inject_noise=(False, False, False, False, False),
upsample_residual=(False, False, False, False),
upsample_factor=(1, 1, 1, 1),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
......