Unverified Commit 7b904941 authored by Aryan, committed by GitHub

Cosmos (#10660)



* begin transformer conversion

* refactor

* refactor

* refactor

* refactor

* refactor

* refactor

* update

* add conversion script

* add pipeline

* make fix-copies

* remove einops

* update docs

* gradient checkpointing

* add transformer test

* update

* debug

* remove prints

* match sigmas

* add vae pt. 1

* finish CV* vae

* update

* update

* update

* update

* update

* update

* make fix-copies

* update

* make fix-copies

* fix

* update

* update

* make fix-copies

* update

* update tests

* handle device and dtype for safety checker; required in latest diffusers

* remove enable_gqa and use repeat_interleave instead

* enforce safety checker; use dummy checker in fast tests

* add review suggestion for ONNX export
Co-Authored-By: Asfiya Baig <asfiyab@nvidia.com>

* fix safety_checker issues when not passed explicitly

We could either do what's done in this commit, or update the Cosmos examples to explicitly pass the safety checker

* use cosmos guardrail package

* auto format docs

* update conversion script to support 14B models

* update name CosmosPipeline -> CosmosTextToWorldPipeline

* update docs

* fix docs

* fix group offload test failing for vae

---------
Co-authored-by: Asfiya Baig <asfiyab@nvidia.com>
parent fb29132b
@@ -295,6 +295,8 @@
title: CogView4Transformer2DModel
- local: api/models/consisid_transformer3d
title: ConsisIDTransformer3DModel
- local: api/models/cosmos_transformer3d
title: CosmosTransformer3DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
@@ -363,6 +365,8 @@
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
@@ -433,6 +437,8 @@
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/controlnet_union
title: ControlNetUnion
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLCosmos
The 3D variational autoencoder (VAE) used in Cosmos, released by NVIDIA as part of [Cosmos Tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer).
Supported models:
- [nvidia/Cosmos-1.0-Tokenizer-CV8x8x8](https://huggingface.co/nvidia/Cosmos-1.0-Tokenizer-CV8x8x8)
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLCosmos
vae = AutoencoderKLCosmos.from_pretrained("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8", subfolder="vae")
```
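Encoding and decoding follow the standard diffusers VAE API (`encode(...).latent_dist`, `decode(...).sample`). A minimal round-trip sketch, assuming the `[batch, channels, frames, height, width]` layout and the 8x8x8 compression of the CV8x8x8 checkpoint:
```python
import torch
from diffusers import AutoencoderKLCosmos

vae = AutoencoderKLCosmos.from_pretrained("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8", subfolder="vae")

# Dummy video in [B, C, T, H, W] layout; 9 frames (1 + 8) works with the causal 8x temporal compression.
video = torch.randn(1, 3, 9, 128, 128)

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()  # roughly [1, 16, 2, 16, 16] for this config
    reconstruction = vae.decode(latents).sample
```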
## AutoencoderKLCosmos
[[autodoc]] AutoencoderKLCosmos
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# CosmosTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import CosmosTransformer3DModel

transformer = CosmosTransformer3DModel.from_pretrained("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## CosmosTransformer3DModel
[[autodoc]] CosmosTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Cosmos
[Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
*Physical AI needs to be trained digitally first. It needs a digital twin of itself, the policy model, and a digital twin of the world, the world model. In this paper, we present the Cosmos World Foundation Model Platform to help developers build customized world models for their Physical AI setups. We position a world foundation model as a general-purpose world model that can be fine-tuned into customized world models for downstream applications. Our platform covers a video curation pipeline, pre-trained world foundation models, examples of post-training of pre-trained world foundation models, and video tokenizers. To help Physical AI builders solve the most critical problems of our society, we make our platform open-source and our models open-weight with permissive licenses available via https://github.com/NVIDIA/Cosmos.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
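A minimal text-to-world sketch. The checkpoint id matches the Text2World weights referenced above, and the `.frames` / `export_to_video` handling follows the usual diffusers video-pipeline pattern; the enforced safety checker additionally requires the separately installed Cosmos guardrail package, and the exact fps value is an assumption:
```python
import torch
from diffusers import CosmosTextToWorldPipeline
from diffusers.utils import export_to_video

pipe = CosmosTextToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

prompt = "A robot arm carefully stacks wooden blocks on a workbench."
video = pipe(prompt=prompt).frames[0]
export_to_video(video, "output.mp4", fps=30)
```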
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
- all
- __call__
## CosmosVideoToWorldPipeline
[[autodoc]] CosmosVideoToWorldPipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
import argparse
import pathlib
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
def remove_keys_(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
block_index = int(key.split(".")[1].removeprefix("block"))
new_key = key
old_prefix = f"blocks.block{block_index}"
new_prefix = f"transformer_blocks.{block_index}"
new_key = new_prefix + new_key.removeprefix(old_prefix)
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"t_embedder.1": "time_embed.t_embedder",
"affline_norm": "time_embed.norm",
".blocks.0.block.attn": ".attn1",
".blocks.1.block.attn": ".attn2",
".blocks.2.block": ".ff",
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
"to_q.0": "to_q",
"to_q.1": "norm_q",
"to_k.0": "to_k",
"to_k.1": "norm_k",
"to_v.0": "to_v",
"layer1": "net.0.proj",
"layer2": "net.2",
"proj.1": "proj",
"x_embedder": "patch_embed",
"extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"blocks.block": rename_transformer_blocks_,
"logvar.0.freqs": remove_keys_,
"logvar.0.phases": remove_keys_,
"logvar.1.weight": remove_keys_,
"pos_embedder.seq": remove_keys_,
}
TRANSFORMER_CONFIGS = {
"Cosmos-1.0-Diffusion-7B-Text2World": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 32,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (2.0, 1.0, 1.0),
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
"Cosmos-1.0-Diffusion-7B-Video2World": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 32,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (2.0, 1.0, 1.0),
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
"Cosmos-1.0-Diffusion-14B-Text2World": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (2.0, 2.0, 2.0),
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
"Cosmos-1.0-Diffusion-14B-Video2World": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (2.0, 2.0, 2.0),
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
}
VAE_KEYS_RENAME_DICT = {
"down.0": "down_blocks.0",
"down.1": "down_blocks.1",
"down.2": "down_blocks.2",
"up.0": "up_blocks.2",
"up.1": "up_blocks.1",
"up.2": "up_blocks.0",
".block.": ".resnets.",
"downsample": "downsamplers.0",
"upsample": "upsamplers.0",
"mid.block_1": "mid_block.resnets.0",
"mid.attn_1.0": "mid_block.attentions.0",
"mid.attn_1.1": "mid_block.temp_attentions.0",
"mid.block_2": "mid_block.resnets.1",
".q.conv3d": ".to_q",
".k.conv3d": ".to_k",
".v.conv3d": ".to_v",
".proj_out.conv3d": ".to_out.0",
".0.conv3d": ".conv_s",
".1.conv3d": ".conv_t",
"conv1.conv3d": "conv1",
"conv2.conv3d": "conv2",
"conv3.conv3d": "conv3",
"nin_shortcut.conv3d": "conv_shortcut",
"quant_conv.conv3d": "quant_conv",
"post_quant_conv.conv3d": "post_quant_conv",
}
VAE_SPECIAL_KEYS_REMAP = {
"wavelets": remove_keys_,
"_arange": remove_keys_,
"patch_size_buffer": remove_keys_,
}
VAE_CONFIGS = {
"CV8x8x8-0.1": {
"name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 16,
"encoder_block_out_channels": (128, 256, 512, 512),
"decode_block_out_channels": (256, 512, 512, 512),
"attention_resolutions": (32,),
"resolution": 1024,
"num_layers": 2,
"patch_size": 4,
"patch_type": "haar",
"scaling_factor": 1.0,
"spatial_compression_ratio": 8,
"temporal_compression_ratio": 8,
"latents_mean": None,
"latents_std": None,
},
},
"CV8x8x8-1.0": {
"name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 16,
"encoder_block_out_channels": (128, 256, 512, 512),
"decode_block_out_channels": (256, 512, 512, 512),
"attention_resolutions": (32,),
"resolution": 1024,
"num_layers": 2,
"patch_size": 4,
"patch_type": "haar",
"scaling_factor": 1.0,
"spatial_compression_ratio": 8,
"temporal_compression_ratio": 8,
"latents_mean": None,
"latents_std": None,
},
},
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def convert_transformer(transformer_type: str, ckpt_path: str):
PREFIX_KEY = "net."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
with init_empty_weights():
config = TRANSFORMER_CONFIGS[transformer_type]
transformer = CosmosTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae(vae_type: str):
model_name = VAE_CONFIGS[vae_type]["name"]
snapshot_directory = snapshot_download(model_name, repo_type="model")
directory = pathlib.Path(snapshot_directory)
autoencoder_file = directory / "autoencoder.jit"
mean_std_file = directory / "mean_std.pt"
original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
if mean_std_file.exists():
mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
else:
mean_std = (None, None)
config = VAE_CONFIGS[vae_type]["diffusers_config"]
    # mean_std may be (None, None) when the checkpoint does not ship a mean_std.pt file.
    if mean_std[0] is not None and mean_std[1] is not None:
        config.update(
            {
                "latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
                "latents_std": mean_std[1].detach().cpu().numpy().tolist(),
            }
        )
vae = AutoencoderKLCosmos(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
args = get_args()
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.vae_type is not None:
vae = convert_vae(args.vae_type)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
        # The original code initializes the EDM config with sigma_min=0.0002, but does not use that value anywhere directly.
        # So the sigma_min that is actually used is the default value of 0.002.
scheduler = EDMEulerScheduler(
sigma_min=0.002,
sigma_max=80,
sigma_data=0.5,
sigma_schedule="karras",
num_train_timesteps=1000,
prediction_type="epsilon",
rho=7.0,
final_sigmas_type="sigma_min",
)
pipe = CosmosTextToWorldPipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
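For reference, a minimal sketch of calling the conversion helpers above directly instead of through the CLI flags defined in `get_args` (the checkpoint path and output directories are placeholders):
```python
import torch

transformer = convert_transformer("Cosmos-1.0-Diffusion-7B-Text2World", "/path/to/original_checkpoint.pt")
transformer.to(dtype=torch.bfloat16).save_pretrained(
    "converted/transformer", safe_serialization=True, max_shard_size="5GB"
)

vae = convert_vae("CV8x8x8-1.0")
vae.save_pretrained("converted/vae", safe_serialization=True, max_shard_size="5GB")
```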
@@ -148,6 +148,7 @@ else:
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
@@ -166,6 +167,7 @@ else:
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
"FluxControlNetModel",
@@ -357,6 +359,9 @@ else:
"CogView3PlusPipeline",
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
"CycleDiffusionPipeline",
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
@@ -745,6 +750,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -763,6 +769,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxControlNetModel,
@@ -933,6 +940,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView3PlusPipeline,
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
CosmosTextToWorldPipeline,
CosmosVideoToWorldPipeline,
CycleDiffusionPipeline,
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
@@ -75,6 +76,7 @@ if is_torch_available():
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
@@ -114,6 +116,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -151,6 +154,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
CosmosTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
@@ -203,8 +203,8 @@ class Attention(nn.Module):
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "rms_norm":
-            self.norm_q = RMSNorm(dim_head, eps=eps)
-            self.norm_k = RMSNorm(dim_head, eps=eps)
+            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
@@ -3,6 +3,7 @@ from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import get_logger
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, IdentityDistribution
logger = get_logger(__name__)
# fmt: off
# These latent means and standard deviations are from CV8x8x8-1.0. Each checkpoint has different values, but since this is the main VAE used,
# we will default to these values.
LATENTS_MEAN = [0.11362758, -0.0171717, 0.03071163, 0.02046862, 0.01931456, 0.02138567, 0.01999342, 0.02189187, 0.02011935, 0.01872694, 0.02168613, 0.02207148, 0.01986941, 0.01770413, 0.02067643, 0.02028245, 0.19125476, 0.04556972, 0.0595558, 0.05315534, 0.05496629, 0.05356264, 0.04856596, 0.05327453, 0.05410472, 0.05597149, 0.05524866, 0.05181874, 0.05071663, 0.05204537, 0.0564108, 0.05518042, 0.01306714, 0.03341161, 0.03847246, 0.02810185, 0.02790166, 0.02920026, 0.02823597, 0.02631033, 0.0278531, 0.02880507, 0.02977769, 0.03145441, 0.02888389, 0.03280773, 0.03484927, 0.03049198, -0.00197727, 0.07534957, 0.04963879, 0.05530893, 0.05410828, 0.05252541, 0.05029899, 0.05321025, 0.05149245, 0.0511921, 0.04643495, 0.04604527, 0.04631618, 0.04404101, 0.04403536, 0.04499495, -0.02994183, -0.04787003, -0.01064558, -0.01779824, -0.01490502, -0.02157517, -0.0204778, -0.02180816, -0.01945375, -0.02062863, -0.02192209, -0.02520639, -0.02246656, -0.02427533, -0.02683363, -0.02762006, 0.08019473, -0.13005368, -0.07568636, -0.06082374, -0.06036175, -0.05875364, -0.05921887, -0.05869788, -0.05273941, -0.052565, -0.05346428, -0.05456541, -0.053657, -0.05656897, -0.05728589, -0.05321847, 0.16718403, -0.00390146, 0.0379406, 0.0356561, 0.03554131, 0.03924074, 0.03873615, 0.04187329, 0.04226924, 0.04378717, 0.04684274, 0.05117614, 0.04547792, 0.05251586, 0.05048339, 0.04950784, 0.09564418, 0.0547128, 0.08183969, 0.07978633, 0.08076023, 0.08108605, 0.08011818, 0.07965573, 0.08187773, 0.08350263, 0.08101469, 0.0786941, 0.0774442, 0.07724521, 0.07830418, 0.07599796, -0.04987567, 0.05923908, -0.01058746, -0.01177603, -0.01116162, -0.01364149, -0.01546014, -0.0117213, -0.01780043, -0.01648314, -0.02100247, -0.02104417, -0.02482123, -0.02611689, -0.02561143, -0.02597336, -0.05364667, 0.08211684, 0.04686937, 0.04605641, 0.04304186, 0.0397355, 0.03686767, 0.04087112, 0.03704741, 0.03706401, 0.03120073, 0.03349091, 0.03319963, 0.03205781, 0.03195127, 0.03180481, 0.16427967, -0.11048453, -0.04595276, -0.04982893, -0.05213465, -0.04809378, -0.05080318, -0.04992863, -0.04493337, -0.0467619, -0.04884703, -0.04627892, -0.04913311, -0.04955709, -0.04533982, -0.04570218, -0.10612928, -0.05121198, -0.06761009, -0.07251801, -0.07265285, -0.07417855, -0.07202412, -0.07499027, -0.07625481, -0.07535747, -0.07638787, -0.07920305, -0.07596069, -0.07959418, -0.08265036, -0.07955471, -0.16888915, 0.0753242, 0.04062594, 0.03375093, 0.03337452, 0.03699376, 0.03651138, 0.03611023, 0.03555622, 0.03378554, 0.0300498, 0.03395559, 0.02941847, 0.03156432, 0.03431173, 0.03016853, -0.03415358, -0.01699573, -0.04029295, -0.04912157, -0.0498858, -0.04917918, -0.04918056, -0.0525189, -0.05325506, -0.05341973, -0.04983329, -0.04883146, -0.04985548, -0.04736718, -0.0462027, -0.04836091, 0.02055675, 0.03419799, -0.02907669, -0.04350509, -0.04156144, -0.04234421, -0.04446109, -0.04461774, -0.04882839, -0.04822346, -0.04502493, -0.0506244, -0.05146913, -0.04655267, -0.04862994, -0.04841615, 0.20312774, -0.07208502, -0.03635615, -0.03556088, -0.04246174, -0.04195838, -0.04293778, -0.04071276, -0.04240569, -0.04125213, -0.04395144, -0.03959096, -0.04044993, -0.04015875, -0.04088107, -0.03885176]
LATENTS_STD = [0.56700271, 0.65488982, 0.65589428, 0.66524369, 0.66619784, 0.6666382, 0.6720838, 0.66955978, 0.66928875, 0.67108786, 0.67092526, 0.67397463, 0.67894882, 0.67668313, 0.67769569, 0.67479557, 0.85245121, 0.8688373, 0.87348086, 0.88459337, 0.89135885, 0.8910504, 0.89714909, 0.89947474, 0.90201765, 0.90411824, 0.90692616, 0.90847772, 0.90648711, 0.91006982, 0.91033435, 0.90541548, 0.84960359, 0.85863352, 0.86895317, 0.88460612, 0.89245003, 0.89451706, 0.89931005, 0.90647358, 0.90338236, 0.90510076, 0.91008312, 0.90961218, 0.9123717, 0.91313171, 0.91435546, 0.91565102, 0.91877103, 0.85155135, 0.857804, 0.86998034, 0.87365264, 0.88161767, 0.88151032, 0.88758916, 0.89015514, 0.89245576, 0.89276224, 0.89450496, 0.90054202, 0.89994133, 0.90136105, 0.90114892, 0.77755755, 0.81456852, 0.81911844, 0.83137071, 0.83820474, 0.83890373, 0.84401101, 0.84425181, 0.84739357, 0.84798753, 0.85249585, 0.85114998, 0.85160935, 0.85626358, 0.85677862, 0.85641026, 0.69903517, 0.71697885, 0.71696913, 0.72583169, 0.72931731, 0.73254126, 0.73586977, 0.73734969, 0.73664582, 0.74084908, 0.74399322, 0.74471819, 0.74493188, 0.74824578, 0.75024873, 0.75274801, 0.8187142, 0.82251883, 0.82616025, 0.83164483, 0.84072375, 0.8396467, 0.84143305, 0.84880769, 0.8503468, 0.85196948, 0.85211051, 0.85386664, 0.85410017, 0.85439342, 0.85847849, 0.85385275, 0.67583984, 0.68259847, 0.69198853, 0.69928843, 0.70194328, 0.70467001, 0.70755547, 0.70917857, 0.71007699, 0.70963502, 0.71064079, 0.71027333, 0.71291167, 0.71537536, 0.71902508, 0.71604162, 0.72450989, 0.71979928, 0.72057378, 0.73035461, 0.73329622, 0.73660028, 0.73891461, 0.74279994, 0.74105692, 0.74002433, 0.74257588, 0.74416119, 0.74543899, 0.74694443, 0.74747062, 0.74586403, 0.90176988, 0.90990674, 0.91106802, 0.92163783, 0.92390233, 0.93056196, 0.93482202, 0.93642414, 0.93858379, 0.94064975, 0.94078934, 0.94325715, 0.94955301, 0.94814706, 0.95144123, 0.94923073, 0.49853548, 0.64968109, 0.6427654, 0.64966393, 0.6487664, 0.65203559, 0.6584242, 0.65351611, 0.65464371, 0.6574859, 0.65626335, 0.66123748, 0.66121179, 0.66077942, 0.66040152, 0.66474909, 0.61986589, 0.69138134, 0.6884557, 0.6955843, 0.69765401, 0.70015347, 0.70529598, 0.70468754, 0.70399523, 0.70479989, 0.70887572, 0.71126866, 0.7097227, 0.71249932, 0.71231949, 0.71175605, 0.35586974, 0.68723857, 0.68973219, 0.69958478, 0.6943453, 0.6995818, 0.70980215, 0.69899458, 0.70271689, 0.70095056, 0.69912851, 0.70522696, 0.70392174, 0.70916915, 0.70585734, 0.70373541, 0.98101336, 0.89024764, 0.89607251, 0.90678179, 0.91308665, 0.91812348, 0.91980827, 0.92480654, 0.92635667, 0.92887944, 0.93338072, 0.93468094, 0.93619436, 0.93906063, 0.94191772, 0.94471723, 0.83202779, 0.84106231, 0.84463632, 0.85829508, 0.86319661, 0.86751342, 0.86914337, 0.87085921, 0.87286359, 0.87537396, 0.87931138, 0.88054478, 0.8811838, 0.88872558, 0.88942474, 0.88934827, 0.44025335, 0.63061613, 0.63110614, 0.63601959, 0.6395812, 0.64104342, 0.65019929, 0.6502797, 0.64355946, 0.64657205, 0.64847094, 0.64728117, 0.64972943, 0.65162975, 0.65328044, 0.64914775]
_WAVELETS = {
"haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
"rearrange": torch.tensor([1.0, 1.0]),
}
# fmt: on
class CosmosCausalConv3d(nn.Conv3d):
def __init__(
self,
in_channels: int = 1,
out_channels: int = 1,
kernel_size: Union[int, Tuple[int, int, int]] = (3, 3, 3),
dilation: Union[int, Tuple[int, int, int]] = (1, 1, 1),
stride: Union[int, Tuple[int, int, int]] = (1, 1, 1),
padding: int = 1,
pad_mode: str = "constant",
) -> None:
kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
dilation = (dilation, dilation, dilation) if isinstance(dilation, int) else dilation
stride = (stride, stride, stride) if isinstance(stride, int) else stride
_, height_kernel_size, width_kernel_size = kernel_size
assert height_kernel_size % 2 == 1 and width_kernel_size % 2 == 1
super().__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
)
self.pad_mode = pad_mode
self.temporal_pad = dilation[0] * (kernel_size[0] - 1) + (1 - stride[0])
self.spatial_pad = (padding, padding, padding, padding)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states_prev = hidden_states[:, :, :1, ...].repeat(1, 1, self.temporal_pad, 1, 1)
hidden_states = torch.cat([hidden_states_prev, hidden_states], dim=2)
hidden_states = F.pad(hidden_states, (*self.spatial_pad, 0, 0), mode=self.pad_mode, value=0.0)
return super().forward(hidden_states)
class CosmosCausalGroupNorm(torch.nn.Module):
def __init__(self, in_channels: int, num_groups: int = 1):
super().__init__()
self.norm = nn.GroupNorm(
num_groups=num_groups,
num_channels=in_channels,
eps=1e-6,
affine=True,
)
self.num_groups = num_groups
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.num_groups == 1:
batch_size = hidden_states.size(0)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, T, H, W] -> [B * T, C, H, W]
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
0, 2, 1, 3, 4
) # [B * T, C, H, W] -> [B, C, T, H, W]
else:
hidden_states = self.norm(hidden_states)
return hidden_states
class CosmosPatchEmbed3d(nn.Module):
def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
dtype = hidden_states.dtype
wavelets = self.wavelets
n = wavelets.shape[0]
g = hidden_states.shape[1]
hl = wavelets.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
hh = (wavelets * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
hh = hh.to(dtype=dtype)
hl = hl.to(dtype=dtype)
# Handles temporal axis
hidden_states = F.pad(hidden_states, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(
dtype
)
xl = F.conv3d(hidden_states, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
xh = F.conv3d(hidden_states, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
# Handles spatial axes
xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
hidden_states = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
if rescale:
hidden_states = hidden_states / 8**0.5
return hidden_states
def _haar(self, hidden_states: torch.Tensor) -> torch.Tensor:
xi, xv = torch.split(hidden_states, [1, hidden_states.shape[2] - 1], dim=2)
hidden_states = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
for _ in range(int(math.log2(self.patch_size))):
hidden_states = self._dwt(hidden_states, rescale=True)
return hidden_states
def _arrange(self, hidden_states: torch.Tensor) -> torch.Tensor:
xi, xv = torch.split(hidden_states, [1, hidden_states.shape[2] - 1], dim=2)
hidden_states = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p = self.patch_size
        hidden_states = hidden_states.reshape(batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p)
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4).contiguous()
return hidden_states
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.patch_method == "haar":
return self._haar(hidden_states)
elif self.patch_method == "rearrange":
return self._arrange(hidden_states)
else:
raise ValueError(f"Unsupported patch method: {self.patch_method}")
class CosmosUnpatcher3d(nn.Module):
def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
super().__init__()
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=False,
)
def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
device = hidden_states.device
dtype = hidden_states.dtype
h = self.wavelets.to(device)
        g = hidden_states.shape[1] // 8  # split into 8 spatio-temporal filtered tensors.
hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
hh = (h * ((-1) ** self._arange.to(device))).reshape(1, 1, -1).repeat(g, 1, 1)
hl = hl.to(dtype=dtype)
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(hidden_states, 8, dim=1)
# Handle height transposed convolutions
xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xll = F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xll
xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xlh = F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlh
xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhl = F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhl
xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
xhh = F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhh
# Handles width transposed convolutions
xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xl = F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xl
xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
xh = F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh
# Handles time axis transposed convolutions
hidden_states = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
hidden_states = (
F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + hidden_states
)
if rescale:
hidden_states = hidden_states * 8**0.5
return hidden_states
def _ihaar(self, hidden_states: torch.Tensor) -> torch.Tensor:
for _ in range(int(math.log2(self.patch_size))):
hidden_states = self._idwt(hidden_states, rescale=True)
hidden_states = hidden_states[:, :, self.patch_size - 1 :, ...]
return hidden_states
def _irearrange(self, hidden_states: torch.Tensor) -> torch.Tensor:
p = self.patch_size
hidden_states = hidden_states.unflatten(1, (-1, p, p, p))
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
hidden_states = hidden_states[:, :, p - 1 :, ...]
return hidden_states
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.patch_method == "haar":
return self._ihaar(hidden_states)
elif self.patch_method == "rearrange":
return self._irearrange(hidden_states)
else:
raise ValueError("Unknown patch method: " + self.patch_method)
class CosmosConvProjection3d(nn.Module):
def __init__(self, in_channels: int, out_channels: int) -> None:
super().__init__()
self.conv_s = CosmosCausalConv3d(in_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1)
self.conv_t = CosmosCausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_s(hidden_states)
hidden_states = self.conv_t(hidden_states)
return hidden_states
class CosmosResnetBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_groups: int = 1,
) -> None:
super().__init__()
out_channels = out_channels or in_channels
self.norm1 = CosmosCausalGroupNorm(in_channels, num_groups)
self.conv1 = CosmosConvProjection3d(in_channels, out_channels)
self.norm2 = CosmosCausalGroupNorm(out_channels, num_groups)
self.dropout = nn.Dropout(dropout)
self.conv2 = CosmosConvProjection3d(out_channels, out_channels)
if in_channels != out_channels:
self.conv_shortcut = CosmosCausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = nn.Identity()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
residual = self.conv_shortcut(residual)
hidden_states = self.norm1(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
return hidden_states + residual
class CosmosDownsample3d(nn.Module):
def __init__(
self,
in_channels: int,
spatial_downsample: bool = True,
temporal_downsample: bool = True,
) -> None:
super().__init__()
self.spatial_downsample = spatial_downsample
self.temporal_downsample = temporal_downsample
self.conv1 = nn.Identity()
self.conv2 = nn.Identity()
self.conv3 = nn.Identity()
if spatial_downsample:
self.conv1 = CosmosCausalConv3d(
in_channels, in_channels, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=0
)
if temporal_downsample:
self.conv2 = CosmosCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=0
)
if spatial_downsample or temporal_downsample:
self.conv3 = CosmosCausalConv3d(
in_channels, in_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=0
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if not self.spatial_downsample and not self.temporal_downsample:
return hidden_states
if self.spatial_downsample:
pad = (0, 1, 0, 1, 0, 0)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
conv_out = self.conv1(hidden_states)
pool_out = F.avg_pool3d(hidden_states, kernel_size=(1, 2, 2), stride=(1, 2, 2))
hidden_states = conv_out + pool_out
if self.temporal_downsample:
hidden_states = torch.cat([hidden_states[:, :, :1, ...], hidden_states], dim=2)
conv_out = self.conv2(hidden_states)
pool_out = F.avg_pool3d(hidden_states, kernel_size=(2, 1, 1), stride=(2, 1, 1))
hidden_states = conv_out + pool_out
hidden_states = self.conv3(hidden_states)
return hidden_states
class CosmosUpsample3d(nn.Module):
def __init__(
self,
in_channels: int,
spatial_upsample: bool = True,
temporal_upsample: bool = True,
) -> None:
super().__init__()
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.conv1 = nn.Identity()
self.conv2 = nn.Identity()
self.conv3 = nn.Identity()
if temporal_upsample:
self.conv1 = CosmosCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=0
)
if spatial_upsample:
self.conv2 = CosmosCausalConv3d(
in_channels, in_channels, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=1
)
if spatial_upsample or temporal_upsample:
self.conv3 = CosmosCausalConv3d(
in_channels, in_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=0
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if not self.spatial_upsample and not self.temporal_upsample:
return hidden_states
if self.temporal_upsample:
num_frames = hidden_states.size(2)
time_factor = int(1.0 + 1.0 * (num_frames > 1))
hidden_states = hidden_states.repeat_interleave(int(time_factor), dim=2)
hidden_states = hidden_states[..., time_factor - 1 :, :, :]
hidden_states = self.conv1(hidden_states) + hidden_states
if self.spatial_upsample:
hidden_states = hidden_states.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
hidden_states = self.conv2(hidden_states) + hidden_states
hidden_states = self.conv3(hidden_states)
return hidden_states
class CosmosCausalAttention(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
num_groups: int = 1,
dropout: float = 0.0,
processor: Union["CosmosSpatialAttentionProcessor2_0", "CosmosTemporalAttentionProcessor2_0"] = None,
) -> None:
super().__init__()
self.num_attention_heads = num_attention_heads
self.norm = CosmosCausalGroupNorm(attention_head_dim, num_groups=num_groups)
self.to_q = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
self.to_k = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
self.to_v = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
self.to_out = nn.ModuleList([])
self.to_out.append(
CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
)
self.to_out.append(nn.Dropout(dropout))
self.processor = processor
if self.processor is None:
raise ValueError("CosmosCausalAttention requires a processor.")
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
return self.processor(self, hidden_states=hidden_states, attention_mask=attention_mask)
class CosmosSpatialAttentionProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"CosmosSpatialAttentionProcessor2_0 requires PyTorch 2.0 or higher. To use it, please upgrade PyTorch."
)
def __call__(
self, attn: CosmosCausalAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
residual = hidden_states
hidden_states = attn.norm(hidden_states)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# [B, C, T, H, W] -> [B * T, H * W, C]
query = query.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)
key = key.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)
value = value.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)
# [B * T, H * W, C] -> [B * T, N, H * W, C // N]
query = query.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
hidden_states = hidden_states.unflatten(1, (height, width)).unflatten(0, (batch_size, num_frames))
hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states + residual
class CosmosTemporalAttentionProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"CosmosSpatialAttentionProcessor2_0 requires PyTorch 2.0 or higher. To use it, please upgrade PyTorch."
)
def __call__(
self, attn: CosmosCausalAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
residual = hidden_states
hidden_states = attn.norm(hidden_states)
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
        # [B, C, T, H, W] -> [B * H * W, T, C]
query = query.permute(0, 3, 4, 2, 1).flatten(0, 2)
key = key.permute(0, 3, 4, 2, 1).flatten(0, 2)
value = value.permute(0, 3, 4, 2, 1).flatten(0, 2)
        # [B * H * W, T, C] -> [B * H * W, N, T, C // N]
query = query.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
hidden_states = hidden_states.unflatten(0, (batch_size, height, width))
hidden_states = hidden_states.permute(0, 4, 3, 1, 2)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states + residual
class CosmosDownBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int,
dropout: float,
use_attention: bool,
use_downsample: bool,
spatial_downsample: bool,
temporal_downsample: bool,
) -> None:
super().__init__()
resnets, attentions, temp_attentions = [], [], []
in_channel, out_channel = in_channels, out_channels
for _ in range(num_layers):
resnets.append(CosmosResnetBlock3d(in_channel, out_channel, dropout, num_groups=1))
in_channel = out_channel
if use_attention:
attentions.append(
CosmosCausalAttention(
num_attention_heads=1,
attention_head_dim=out_channel,
num_groups=1,
dropout=dropout,
processor=CosmosSpatialAttentionProcessor2_0(),
)
)
temp_attentions.append(
CosmosCausalAttention(
num_attention_heads=1,
attention_head_dim=out_channel,
num_groups=1,
dropout=dropout,
processor=CosmosTemporalAttentionProcessor2_0(),
)
)
else:
attentions.append(None)
temp_attentions.append(None)
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
self.temp_attentions = nn.ModuleList(temp_attentions)
self.downsamplers = None
if use_downsample:
self.downsamplers = nn.ModuleList([])
self.downsamplers.append(CosmosDownsample3d(out_channel, spatial_downsample, temporal_downsample))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet, attention, temp_attention in zip(self.resnets, self.attentions, self.temp_attentions):
hidden_states = resnet(hidden_states)
if attention is not None:
hidden_states = attention(hidden_states)
if temp_attention is not None:
num_frames = hidden_states.size(2)
attention_mask = torch.tril(hidden_states.new_ones(num_frames, num_frames)).bool()
hidden_states = temp_attention(hidden_states, attention_mask)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
class CosmosMidBlock3d(nn.Module):
def __init__(self, in_channels: int, num_layers: int, dropout: float, num_groups: int = 1) -> None:
super().__init__()
resnets, attentions, temp_attentions = [], [], []
resnets.append(CosmosResnetBlock3d(in_channels, in_channels, dropout, num_groups))
for _ in range(num_layers):
attentions.append(
CosmosCausalAttention(
num_attention_heads=1,
attention_head_dim=in_channels,
num_groups=num_groups,
dropout=dropout,
processor=CosmosSpatialAttentionProcessor2_0(),
)
)
temp_attentions.append(
CosmosCausalAttention(
num_attention_heads=1,
attention_head_dim=in_channels,
num_groups=num_groups,
dropout=dropout,
processor=CosmosTemporalAttentionProcessor2_0(),
)
)
resnets.append(CosmosResnetBlock3d(in_channels, in_channels, dropout, num_groups))
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attentions)
self.temp_attentions = nn.ModuleList(temp_attentions)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.resnets[0](hidden_states)
for attention, temp_attention, resnet in zip(self.attentions, self.temp_attentions, self.resnets[1:]):
num_frames = hidden_states.size(2)
attention_mask = torch.tril(hidden_states.new_ones(num_frames, num_frames)).bool()
hidden_states = attention(hidden_states)
hidden_states = temp_attention(hidden_states, attention_mask)
hidden_states = resnet(hidden_states)
return hidden_states
class CosmosUpBlock3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int,
dropout: float,
use_attention: bool,
use_upsample: bool,
spatial_upsample: bool,
temporal_upsample: bool,
) -> None:
super().__init__()
resnets, attention, temp_attentions = [], [], []
in_channel, out_channel = in_channels, out_channels
for _ in range(num_layers):
resnets.append(CosmosResnetBlock3d(in_channel, out_channel, dropout, num_groups=1))
in_channel = out_channel
if use_attention:
attention.append(
CosmosCausalAttention(
num_attention_heads=1,
attention_head_dim=out_channel,
num_groups=1,
dropout=dropout,
processor=CosmosSpatialAttentionProcessor2_0(),
)
)
temp_attentions.append(
CosmosCausalAttention(
num_attention_heads=1,
attention_head_dim=out_channel,
num_groups=1,
dropout=dropout,
processor=CosmosTemporalAttentionProcessor2_0(),
)
)
else:
attention.append(None)
temp_attentions.append(None)
self.resnets = nn.ModuleList(resnets)
self.attentions = nn.ModuleList(attention)
self.temp_attentions = nn.ModuleList(temp_attentions)
self.upsamplers = None
if use_upsample:
self.upsamplers = nn.ModuleList([])
self.upsamplers.append(CosmosUpsample3d(out_channel, spatial_upsample, temporal_upsample))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for resnet, attention, temp_attention in zip(self.resnets, self.attentions, self.temp_attentions):
hidden_states = resnet(hidden_states)
if attention is not None:
hidden_states = attention(hidden_states)
if temp_attention is not None:
num_frames = hidden_states.size(2)
attention_mask = torch.tril(hidden_states.new_ones(num_frames, num_frames)).bool()
hidden_states = temp_attention(hidden_states, attention_mask)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class CosmosEncoder3d(nn.Module):
def __init__(
self,
in_channels: int = 3,
out_channels: int = 16,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
num_resnet_blocks: int = 2,
attention_resolutions: Tuple[int, ...] = (32,),
resolution: int = 1024,
patch_size: int = 4,
patch_type: str = "haar",
dropout: float = 0.0,
spatial_compression_ratio: int = 8,
temporal_compression_ratio: int = 8,
) -> None:
super().__init__()
inner_dim = in_channels * patch_size**3
num_spatial_layers = int(math.log2(spatial_compression_ratio)) - int(math.log2(patch_size))
num_temporal_layers = int(math.log2(temporal_compression_ratio)) - int(math.log2(patch_size))
# 1. Input patching & projection
self.patch_embed = CosmosPatchEmbed3d(patch_size, patch_type)
self.conv_in = CosmosConvProjection3d(inner_dim, block_out_channels[0])
# 2. Down blocks
current_resolution = resolution // patch_size
down_blocks = []
for i in range(len(block_out_channels) - 1):
in_channel = block_out_channels[i]
out_channel = block_out_channels[i + 1]
use_attention = current_resolution in attention_resolutions
spatial_downsample = temporal_downsample = False
if i < len(block_out_channels) - 2:
use_downsample = True
spatial_downsample = i < num_spatial_layers
temporal_downsample = i < num_temporal_layers
current_resolution = current_resolution // 2
else:
use_downsample = False
down_blocks.append(
CosmosDownBlock3d(
in_channel,
out_channel,
num_resnet_blocks,
dropout,
use_attention,
use_downsample,
spatial_downsample,
temporal_downsample,
)
)
self.down_blocks = nn.ModuleList(down_blocks)
# 3. Mid block
self.mid_block = CosmosMidBlock3d(block_out_channels[-1], num_layers=1, dropout=dropout, num_groups=1)
# 4. Output norm & projection
self.norm_out = CosmosCausalGroupNorm(block_out_channels[-1], num_groups=1)
self.conv_out = CosmosConvProjection3d(block_out_channels[-1], out_channels)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.patch_embed(hidden_states)
hidden_states = self.conv_in(hidden_states)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(block, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
else:
for block in self.down_blocks:
hidden_states = block(hidden_states)
hidden_states = self.mid_block(hidden_states)
hidden_states = self.norm_out(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class CosmosDecoder3d(nn.Module):
def __init__(
self,
in_channels: int = 16,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
num_resnet_blocks: int = 2,
attention_resolutions: Tuple[int, ...] = (32,),
resolution: int = 1024,
patch_size: int = 4,
patch_type: str = "haar",
dropout: float = 0.0,
spatial_compression_ratio: int = 8,
temporal_compression_ratio: int = 8,
) -> None:
super().__init__()
inner_dim = out_channels * patch_size**3
num_spatial_layers = int(math.log2(spatial_compression_ratio)) - int(math.log2(patch_size))
num_temporal_layers = int(math.log2(temporal_compression_ratio)) - int(math.log2(patch_size))
reversed_block_out_channels = list(reversed(block_out_channels))
# 1. Input projection
self.conv_in = CosmosConvProjection3d(in_channels, reversed_block_out_channels[0])
# 2. Mid block
self.mid_block = CosmosMidBlock3d(reversed_block_out_channels[0], num_layers=1, dropout=dropout, num_groups=1)
# 3. Up blocks
current_resolution = (resolution // patch_size) // 2 ** (len(block_out_channels) - 2)
up_blocks = []
for i in range(len(block_out_channels) - 1):
in_channel = reversed_block_out_channels[i]
out_channel = reversed_block_out_channels[i + 1]
use_attention = current_resolution in attention_resolutions
spatial_upsample = temporal_upsample = False
if i < len(block_out_channels) - 2:
use_upsample = True
temporal_upsample = 0 < i < num_temporal_layers + 1
spatial_upsample = temporal_upsample or (
i < num_spatial_layers and num_spatial_layers > num_temporal_layers
)
current_resolution = current_resolution * 2
else:
use_upsample = False
up_blocks.append(
CosmosUpBlock3d(
in_channel,
out_channel,
num_resnet_blocks + 1,
dropout,
use_attention,
use_upsample,
spatial_upsample,
temporal_upsample,
)
)
self.up_blocks = nn.ModuleList(up_blocks)
# 4. Output norm & projection & unpatching
self.norm_out = CosmosCausalGroupNorm(reversed_block_out_channels[-1], num_groups=1)
self.conv_out = CosmosConvProjection3d(reversed_block_out_channels[-1], inner_dim)
self.unpatch_embed = CosmosUnpatcher3d(patch_size, patch_type)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
hidden_states = self.mid_block(hidden_states)
for block in self.up_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(block, hidden_states)
else:
hidden_states = block(hidden_states)
hidden_states = self.norm_out(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = self.unpatch_embed(hidden_states)
return hidden_states
class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
r"""
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
Args:
in_channels (`int`, defaults to `3`):
Number of input channels.
out_channels (`int`, defaults to `3`):
Number of output channels.
latent_channels (`int`, defaults to `16`):
Number of latent channels.
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
Number of output channels for each encoder down block.
decode_block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
Number of output channels for each decoder up block.
attention_resolutions (`Tuple[int, ...]`, defaults to `(32,)`):
List of image/video resolutions at which to apply attention.
resolution (`int`, defaults to `1024`):
Base image/video resolution used for computing whether a block should have attention layers.
num_layers (`int`, defaults to `2`):
Number of resnet blocks in each encoder/decoder block.
patch_size (`int`, defaults to `4`):
Patch size used for patching the input image/video.
patch_type (`str`, defaults to `haar`):
Patch type used for patching the input image/video. Can be either `haar` or `rearrange`.
scaling_factor (`float`, defaults to `1.0`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. Not applicable in Cosmos,
but we default to 1.0 for consistency.
spatial_compression_ratio (`int`, defaults to `8`):
The spatial compression ratio to apply in the VAE. The number of downsample blocks is determined using
this.
temporal_compression_ratio (`int`, defaults to `8`):
The temporal compression ratio to apply in the VAE. The number of downsample blocks is determined using
this.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
latent_channels: int = 16,
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
decode_block_out_channels: Tuple[int, ...] = (256, 512, 512, 512),
attention_resolutions: Tuple[int, ...] = (32,),
resolution: int = 1024,
num_layers: int = 2,
patch_size: int = 4,
patch_type: str = "haar",
scaling_factor: float = 1.0,
spatial_compression_ratio: int = 8,
temporal_compression_ratio: int = 8,
latents_mean: Optional[List[float]] = LATENTS_MEAN,
latents_std: Optional[List[float]] = LATENTS_STD,
) -> None:
super().__init__()
self.encoder = CosmosEncoder3d(
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=encoder_block_out_channels,
num_resnet_blocks=num_layers,
attention_resolutions=attention_resolutions,
resolution=resolution,
patch_size=patch_size,
patch_type=patch_type,
spatial_compression_ratio=spatial_compression_ratio,
temporal_compression_ratio=temporal_compression_ratio,
)
self.decoder = CosmosDecoder3d(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=decode_block_out_channels,
num_resnet_blocks=num_layers,
attention_resolutions=attention_resolutions,
resolution=resolution,
patch_size=patch_size,
patch_type=patch_type,
spatial_compression_ratio=spatial_compression_ratio,
temporal_compression_ratio=temporal_compression_ratio,
)
self.quant_conv = CosmosCausalConv3d(latent_channels, latent_channels, kernel_size=1, padding=0)
self.post_quant_conv = CosmosCausalConv3d(latent_channels, latent_channels, kernel_size=1, padding=0)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
# at a fixed frame batch size (based on `self.num_latent_frames_batch_size`), the memory requirement can be lowered.
self.use_framewise_encoding = False
self.use_framewise_decoding = False
# This can be configured based on the amount of GPU memory available.
# `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
# Setting it to higher values results in higher memory usage.
self.num_sample_frames_batch_size = 16
self.num_latent_frames_batch_size = 2
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 512
self.tile_sample_min_width = 512
self.tile_sample_min_num_frames = 16
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 448
self.tile_sample_stride_width = 448
self.tile_sample_stride_num_frames = 8
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_min_num_frames: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
tile_sample_stride_num_frames: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
enc = self.quant_conv(x)
return enc
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = IdentityDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[Tuple[torch.Tensor], DecoderOutput]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
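# A rough usage sketch for this autoencoder, kept as a comment so that nothing runs at import time. The
# checkpoint id and `subfolder` below are assumptions based on how the Cosmos pipelines in this PR load
# their VAE; adjust them to whatever layout you actually have:
#
#     import torch
#     from diffusers import AutoencoderKLCosmos
#
#     vae = AutoencoderKLCosmos.from_pretrained(
#         "nvidia/Cosmos-1.0-Diffusion-7B-Text2World", subfolder="vae", torch_dtype=torch.float32
#     )
#     vae.enable_tiling()  # optional; see the tiling flags set in `__init__` above
#     with torch.no_grad():
#         video = torch.randn(1, 3, 9, 256, 256)  # [B, C, T, H, W] in pixel space, T = 8k + 1 frames
#         latents = vae.encode(video).latent_dist.mode()
#         reconstruction = vae.decode(latents).sample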
......@@ -744,6 +744,17 @@ class DiagonalGaussianDistribution(object):
return self.mean
class IdentityDistribution(object):
def __init__(self, parameters: torch.Tensor):
self.parameters = parameters
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
return self.parameters
def mode(self) -> torch.Tensor:
return self.parameters
class EncoderTiny(nn.Module):
r"""
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
......
......@@ -1204,7 +1204,7 @@ def apply_rotary_emb(
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen and CogView4
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
......
......@@ -19,6 +19,7 @@ if is_torch_available():
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_cosmos import CosmosTransformer3DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
......
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from ...configuration_utils import ConfigMixin, register_to_config
from ..attention import FeedForward
from ..attention_processor import Attention
from ..embeddings import Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
class CosmosPatchEmbed(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Linear(in_channels * patch_size[0] * patch_size[1] * patch_size[2], out_channels, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
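# Fold each (p_t, p_h, p_w) patch into the channel dimension:
# [B, C, T, H, W] -> [B, T // p_t, H // p_h, W // p_w, C * p_t * p_h * p_w] -> linear projection to `out_channels`.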
hidden_states = hidden_states.reshape(
batch_size, num_channels, num_frames // p_t, p_t, height // p_h, p_h, width // p_w, p_w
)
hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7)
hidden_states = self.proj(hidden_states)
return hidden_states
class CosmosTimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int) -> None:
super().__init__()
self.linear_1 = nn.Linear(in_features, out_features, bias=False)
self.activation = nn.SiLU()
self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(timesteps)
emb = self.activation(emb)
emb = self.linear_2(emb)
return emb
class CosmosEmbedding(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int) -> None:
super().__init__()
self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim)
self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True)
def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
temb = self.t_embedder(timesteps_proj)
embedded_timestep = self.norm(timesteps_proj)
return temb, embedded_timestep
class CosmosAdaLayerNorm(nn.Module):
def __init__(self, in_features: int, hidden_features: int) -> None:
super().__init__()
self.embedding_dim = in_features
self.activation = nn.SiLU()
self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
self.linear_2 = nn.Linear(hidden_features, 2 * in_features, bias=False)
def forward(
self, hidden_states: torch.Tensor, embedded_timestep: torch.Tensor, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
embedded_timestep = self.activation(embedded_timestep)
embedded_timestep = self.linear_1(embedded_timestep)
embedded_timestep = self.linear_2(embedded_timestep)
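# `temb` (from CosmosTimestepEmbedding) carries 3 * hidden_size features laid out as (shift, scale, gate);
# this final AdaLN only needs shift and scale, so only the first 2 * hidden_size features are added.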
if temb is not None:
embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
shift, scale = embedded_timestep.chunk(2, dim=1)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return hidden_states
class CosmosAdaLayerNormZero(nn.Module):
def __init__(self, in_features: int, hidden_features: Optional[int] = None) -> None:
super().__init__()
self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
self.activation = nn.SiLU()
if hidden_features is None:
self.linear_1 = nn.Identity()
else:
self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
embedded_timestep: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
embedded_timestep = self.activation(embedded_timestep)
embedded_timestep = self.linear_1(embedded_timestep)
embedded_timestep = self.linear_2(embedded_timestep)
if temb is not None:
embedded_timestep = embedded_timestep + temb
shift, scale, gate = embedded_timestep.chunk(3, dim=1)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return hidden_states, gate
class CosmosAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 1. QKV projections
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
query = attn.norm_q(query)
key = attn.norm_k(key)
# 3. Apply RoPE
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA
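# If key/value carry fewer features than query (grouped-query attention), repeat them along the last
# dimension so that scaled_dot_product_attention sees matching shapes. With the projections defined in
# this block the ratio is 1, so this is effectively a no-op for the converted checkpoints.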
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
# 5. Attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
# 6. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CosmosTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: int,
mlp_ratio: float = 4.0,
adaln_lora_dim: int = 256,
qk_norm: str = "rms_norm",
out_bias: bool = False,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.attn1 = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_0(),
)
self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.attn2 = Attention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_0(),
)
self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
embedded_timestep: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
extra_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_pos_emb is not None:
hidden_states = hidden_states + extra_pos_emb
# 1. Self Attention
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
# 2. Cross Attention
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
attn_output = self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
# 3. Feed Forward
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
return hidden_states
class CosmosRotaryPosEmbed(nn.Module):
def __init__(
self,
hidden_size: int,
max_size: Tuple[int, int, int] = (128, 240, 240),
patch_size: Tuple[int, int, int] = (1, 2, 2),
base_fps: int = 24,
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
) -> None:
super().__init__()
self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
self.patch_size = patch_size
self.base_fps = base_fps
self.dim_h = hidden_size // 6 * 2
self.dim_w = hidden_size // 6 * 2
self.dim_t = hidden_size - self.dim_h - self.dim_w
self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))
def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
device = hidden_states.device
h_theta = 10000.0 * self.h_ntk_factor
w_theta = 10000.0 * self.w_ntk_factor
t_theta = 10000.0 * self.t_ntk_factor
seq = torch.arange(max(self.max_size), device=device, dtype=torch.float32)
dim_h_range = (
torch.arange(0, self.dim_h, 2, device=device, dtype=torch.float32)[: (self.dim_h // 2)] / self.dim_h
)
dim_w_range = (
torch.arange(0, self.dim_w, 2, device=device, dtype=torch.float32)[: (self.dim_w // 2)] / self.dim_w
)
dim_t_range = (
torch.arange(0, self.dim_t, 2, device=device, dtype=torch.float32)[: (self.dim_t // 2)] / self.dim_t
)
h_spatial_freqs = 1.0 / (h_theta**dim_h_range)
w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
temporal_freqs = 1.0 / (t_theta**dim_t_range)
emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].repeat(pe_size[0], 1, pe_size[2], 1)
emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].repeat(pe_size[0], pe_size[1], 1, 1)
# Apply sequence scaling in temporal dimension
if fps is None:
# Images
emb_t = torch.outer(seq[: pe_size[0]], temporal_freqs)
else:
# Videos
emb_t = torch.outer(seq[: pe_size[0]] / fps * self.base_fps, temporal_freqs)
emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1)
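# `emb_t`/`emb_h`/`emb_w` each have shape [T', H', W', dim_x // 2]. Concatenating the three (twice) and
# flattening the spatio-temporal axes yields per-position frequencies of shape [T' * H' * W', hidden_size],
# from which the cos/sin tables consumed by `apply_rotary_emb` in `CosmosAttnProcessor2_0` are built.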
freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return cos, sin
class CosmosLearnablePositionalEmbed(nn.Module):
def __init__(
self,
hidden_size: int,
max_size: Tuple[int, int, int],
patch_size: Tuple[int, int, int],
eps: float = 1e-6,
) -> None:
super().__init__()
self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
self.patch_size = patch_size
self.eps = eps
self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
emb = emb_t + emb_h + emb_w
emb = emb.flatten(1, 3)
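# RMS-style normalization over the channel dimension: the divisor computed below is
# eps + ||emb|| / sqrt(hidden_size), since `alpha` evaluates to 1 / sqrt(hidden_size).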
norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
return (emb / norm).type_as(hidden_states)
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
r"""
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
Args:
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `32`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each attention head.
num_layers (`int`, defaults to `28`):
The number of layers of transformer blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
The ratio of the hidden layer size to the input size in the feedforward network.
text_embed_dim (`int`, defaults to `1024`):
Input dimension of text embeddings from the text encoder.
adaln_lora_dim (`int`, defaults to `256`):
The hidden dimension of the Adaptive LayerNorm LoRA layer.
max_size (`Tuple[int, int, int]`, defaults to `(128, 240, 240)`):
The maximum size of the input latent tensors in the temporal, height, and width dimensions.
patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
The patch size to use for patchifying the input latent tensors in the temporal, height, and width
dimensions.
rope_scale (`Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)`):
The scaling factor to use for RoPE in the temporal, height, and width dimensions.
concat_padding_mask (`bool`, defaults to `True`):
Whether to concatenate the padding mask to the input latent tensors.
extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
_no_split_modules = ["CosmosTransformerBlock"]
_keep_in_fp32_modules = ["learnable_pos_embed"]
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
num_attention_heads: int = 32,
attention_head_dim: int = 128,
num_layers: int = 28,
mlp_ratio: float = 4.0,
text_embed_dim: int = 1024,
adaln_lora_dim: int = 256,
max_size: Tuple[int, int, int] = (128, 240, 240),
patch_size: Tuple[int, int, int] = (1, 2, 2),
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
concat_padding_mask: bool = True,
extra_pos_embed_type: Optional[str] = "learnable",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
# 1. Patch Embedding
patch_embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, patch_size, bias=False)
# 2. Positional Embedding
self.rope = CosmosRotaryPosEmbed(
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
)
self.learnable_pos_embed = None
if extra_pos_embed_type == "learnable":
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
hidden_size=hidden_size,
max_size=max_size,
patch_size=patch_size,
)
# 3. Time Embedding
self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
# 4. Transformer Blocks
self.transformer_blocks = nn.ModuleList(
[
CosmosTransformerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=text_embed_dim,
mlp_ratio=mlp_ratio,
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
)
for _ in range(num_layers)
]
)
# 5. Output norm & projection
self.norm_out = CosmosAdaLayerNorm(hidden_size, adaln_lora_dim)
self.proj_out = nn.Linear(
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
fps: Optional[int] = None,
condition_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
# 1. Concatenate padding mask if needed & prepare attention mask
if condition_mask is not None:
hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
if self.config.concat_padding_mask:
padding_mask = transforms.functional.resize(
padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
hidden_states = torch.cat(
[hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
)
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
# 2. Generate positional embeddings
image_rotary_emb = self.rope(hidden_states, fps=fps)
extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
# 3. Patchify input
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
# 4. Timestep embeddings
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
# 5. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
)
# 6. Output norm & projection & unpatchify
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
# The projection output packs each patch as (p_h, p_w, p_t, C), so permute to [B, C, T, p_t, H, p_h, W, p_w]
# before flattening back to [B, C, T * p_t, H * p_h, W * p_w].
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
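# A rough, illustrative shape check with a tiny untrained config (not a released checkpoint), kept as a
# comment so that nothing runs at import time; all sizes below are made up for the example:
#
#     import torch
#     from diffusers import CosmosTransformer3DModel
#
#     transformer = CosmosTransformer3DModel(
#         in_channels=4,
#         out_channels=4,
#         num_attention_heads=2,
#         attention_head_dim=24,
#         num_layers=2,
#         text_embed_dim=32,
#         max_size=(4, 16, 16),
#     )
#     hidden_states = torch.randn(1, 4, 2, 16, 16)   # [B, C, T, H, W] latents
#     timestep = torch.randint(0, 1000, (1,)).float()
#     encoder_hidden_states = torch.randn(1, 8, 32)  # [B, S, text_embed_dim]
#     padding_mask = torch.zeros(1, 1, 16, 16)       # required since concat_padding_mask=True by default
#     sample = transformer(
#         hidden_states, timestep, encoder_hidden_states, padding_mask=padding_mask
#     ).sample
#     assert sample.shape == (1, 4, 2, 16, 16)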
......@@ -156,6 +156,8 @@ else:
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["cosmos"] = ["CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
......@@ -546,6 +548,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
)
from .cosmos import CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
_import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
from ...schedulers import EDMEulerScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:
class CosmosSafetyChecker:
def __init__(self, *args, **kwargs):
raise ImportError(
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CosmosTextToWorldPipeline
>>> from diffusers.utils import export_to_video
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
>>> pipe = CosmosTextToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
>>> output = pipe(prompt=prompt).frames[0]
>>> export_to_video(output, "output.mp4", fps=30)
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CosmosTextToWorldPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Cosmos uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5TokenizerFast).
transformer ([`CosmosTransformer3DModel`]):
Conditional Transformer to denoise the encoded video latents.
scheduler ([`EDMEulerScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
vae ([`AutoencoderKLCosmos`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# `safety_checker` is listed as optional only so that tests can construct the pipeline without it; at
# inference time it is required, and `__call__` raises if it is missing.
_optional_components = ["safety_checker"]
def __init__(
self,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLCosmos,
scheduler: EDMEulerScheduler,
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
if safety_checker is None:
safety_checker = CosmosSafetyChecker()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
return_length=True,
return_offsets_mapping=False,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=prompt_attention_mask
).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
lengths = prompt_attention_mask.sum(dim=1).cpu()
for i, length in enumerate(lengths):
prompt_embeds[i, length:] = 0
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 704,
width: int = 1280,
num_frames: int = 121,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
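# Per the EDM formulation used by `EDMEulerScheduler`, the initial latents are pure Gaussian noise scaled
# to the maximum sigma of the noise schedule.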
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents * self.scheduler.config.sigma_max
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 704,
width: int = 1280,
num_frames: int = 121,
num_inference_steps: int = 36,
guidance_scale: float = 7.0,
fps: int = 30,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, defaults to `704`):
The height in pixels of the generated video.
width (`int`, defaults to `1280`):
The width in pixels of the generated video.
num_frames (`int`, defaults to `121`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `36`):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
guidance_scale (`float`, defaults to `7.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`.
fps (`int`, defaults to `30`):
The frames per second of the generated video.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `pil`, `np`, `pt`, or `latent`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during inference with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~CosmosPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated videos.
"""
if self.safety_checker is None:
raise ValueError(
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
f"Please ensure that you are compliant with the license agreement."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
device = self._execution_device
if self.safety_checker is not None:
self.safety_checker.to(device)
if prompt is not None:
prompt_list = [prompt] if isinstance(prompt, str) else prompt
for p in prompt_list:
if not self.safety_checker.check_text_safety(p):
raise ValueError(
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
f"prompt abides by the NVIDIA Open Model License Agreement."
)
self.safety_checker.to("cpu")
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
# 5. Prepare latent variables
transformer_dtype = self.transformer.dtype
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
latent_model_input = latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = latent_model_input.to(transformer_dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
fps=fps,
padding_mask=padding_mask,
return_dict=False,
)[0]
sample = latents
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
fps=fps,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
sample = torch.cat([sample, sample])
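# Classifier-free guidance is applied on the x0-prediction rather than on the raw model output: the first
# `scheduler.step` call below is used only to recover `pred_original_sample`, the internal step index is
# rewound, guidance is applied on the recovered x0, and the second `scheduler.step` call then consumes the
# guided x0 to produce the next sample.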
# pred_original_sample (x0)
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
self.scheduler._step_index -= 1
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
# pred_sample (eps)
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
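# Undo the latent normalization used during training: the diffusion model operates on latents scaled to
# the scheduler's `sigma_data`, so scale back by `latents_std / sigma_data` and re-add `latents_mean`
# (when available) before handing the latents to the VAE decoder.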
if self.vae.config.latents_mean is not None:
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
latents_mean = (
torch.tensor(latents_mean)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
.to(latents)
)
latents_std = (
torch.tensor(latents_std)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
.to(latents)
)
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
else:
latents = latents / self.scheduler.config.sigma_data
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
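# Run the Cosmos guardrail on the decoded frames: convert to uint8 numpy for the checker, then
# map the (possibly modified) frames back to [-1, 1] tensors for the requested `output_type`.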
if self.safety_checker is not None:
self.safety_checker.to(device)
video = self.video_processor.postprocess_video(video, output_type="np")
video = (video * 255).astype(np.uint8)
video_batch = []
for vid in video:
vid = self.safety_checker.check_video_safety(vid)
video_batch.append(vid)
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
self.safety_checker.to("cpu")
else:
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CosmosPipelineOutput(frames=video)
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
from ...schedulers import EDMEulerScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:
class CosmosSafetyChecker:
def __init__(self, *args, **kwargs):
raise ImportError(
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
Image conditioning:
```python
>>> import torch
>>> from diffusers import CosmosVideoToWorldPipeline
>>> from diffusers.utils import export_to_video, load_image
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day."
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
... )
>>> video = pipe(image=image, prompt=prompt).frames[0]
>>> export_to_video(video, "output.mp4", fps=30)
```
Video conditioning:
```python
>>> import torch
>>> from diffusers import CosmosVideoToWorldPipeline
>>> from diffusers.utils import export_to_video, load_video
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe.transformer = torch.compile(pipe.transformer)
>>> pipe.to("cuda")
>>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
>>> video = load_video(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
... )[
... :21
... ] # This example uses only the first 21 frames
>>> video = pipe(video=video, prompt=prompt).frames[0]
>>> export_to_video(video, "output.mp4", fps=30)
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
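# Illustrative usage sketch (assuming the scheduler's `set_timesteps` accepts a `sigmas` argument):
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, sigmas=[80.0, 20.0, 5.0, 0.5], device="cuda")
# Otherwise, pass `num_inference_steps` and let the scheduler build its own spacing:
#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=36, device="cuda")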
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class CosmosVideoToWorldPipeline(DiffusionPipeline):
r"""
Pipeline for image-to-video and video-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Cosmos uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CosmosTransformer3DModel`]):
Conditional Transformer to denoise the encoded video latents.
scheduler ([`EDMEulerScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
vae ([`AutoencoderKLCosmos`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker"]
def __init__(
self,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLCosmos,
scheduler: EDMEulerScheduler,
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
if safety_checker is None:
safety_checker = CosmosSafetyChecker()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
return_length=True,
return_offsets_mapping=False,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=prompt_attention_mask
).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
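# Zero out embeddings at padded positions so only real prompt tokens contribute downstream.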
lengths = prompt_attention_mask.sum(dim=1).cpu()
for i, length in enumerate(lengths):
prompt_embeds[i, length:] = 0
return prompt_embeds
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self,
video: torch.Tensor,
batch_size: int,
num_channels_latents: int = 16,
height: int = 704,
width: int = 1280,
num_frames: int = 121,
do_classifier_free_guidance: bool = True,
input_frames_guidance: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
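# If the conditioning clip is at least `num_frames` long, keep only its last `num_frames`;
# otherwise zero-pad it along the temporal axis and track how many latent frames are conditioned.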
num_cond_frames = video.size(2)
if num_cond_frames >= num_frames:
# Take the last `num_frames` frames for conditioning
num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
video = video[:, :, -num_frames:]
else:
num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
num_padding_frames = num_frames - num_cond_frames
padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4))
video = torch.cat([video, padding], dim=2)
if isinstance(generator, list):
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
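# Normalize the encoded conditioning latents to the EDM `sigma_data` scale, using per-channel
# mean/std when the VAE config provides them.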
if self.vae.config.latents_mean is not None:
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
latents_mean = (
torch.tensor(latents_mean)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
.to(init_latents)
)
latents_std = (
torch.tensor(latents_std)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
.to(init_latents)
)
init_latents = (init_latents - latents_mean) * self.scheduler.config.sigma_data / latents_std
else:
init_latents = init_latents * self.scheduler.config.sigma_data
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
latents = latents * self.scheduler.config.sigma_max
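# Build per-frame indicator masks: 1 for latent frames coming from the conditioning input,
# 0 for frames to be generated. By default the unconditional branch uses the same mask; with
# `input_frames_guidance=True` it is zeroed out so guidance also acts on the input frames.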
padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_indicator[:, :, :num_cond_latent_frames] = 1.0
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
uncond_indicator = uncond_mask = None
if do_classifier_free_guidance:
uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
uncond_mask = zeros_padding
if not input_frames_guidance:
uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
image=None,
video=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if image is None and video is None:
raise ValueError("Either `image` or `video` has to be provided.")
if image is not None and video is not None:
raise ValueError("Only one of `image` or `video` has to be provided.")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput = None,
video: List[PipelineImageInput] = None,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 704,
width: int = 1280,
num_frames: int = 121,
num_inference_steps: int = 36,
guidance_scale: float = 7.0,
input_frames_guidance: bool = False,
augment_sigma: float = 0.001,
fps: int = 30,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
The call function to the pipeline for generation.
Args:
image (`PipelineImageInput`, *optional*):
The input image to condition the generation on. Either `image` or `video` must be provided.
video (`List[PipelineImageInput]`, *optional*):
The input video to condition the generation on. Either `image` or `video` must be provided.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide video generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`).
height (`int`, defaults to `704`):
The height in pixels of the generated video.
width (`int`, defaults to `1280`):
The width in pixels of the generated video.
num_frames (`int`, defaults to `121`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `36`):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference.
guidance_scale (`float`, defaults to `7.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`.
input_frames_guidance (`bool`, defaults to `False`):
Whether classifier-free guidance should also be applied to the conditioning frames. When disabled,
the unconditional branch keeps the frame conditioning.
augment_sigma (`float`, defaults to `0.001`):
The noise level used to augment the conditioning latents during denoising.
fps (`int`, defaults to `30`):
The frames per second of the generated video.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `pil`, `np`, `pt`, or `latent`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during inference with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~CosmosPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned
where the first element is the generated video frames.
"""
if self.safety_checker is None:
raise ValueError(
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
f"Please ensure that you are compliant with the license agreement."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs, image, video)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
device = self._execution_device
if self.safety_checker is not None:
self.safety_checker.to(device)
if prompt is not None:
prompt_list = [prompt] if isinstance(prompt, str) else prompt
for p in prompt_list:
if not self.safety_checker.check_text_safety(p):
raise ValueError(
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
f"prompt abides by the NVIDIA Open Model License Agreement."
)
self.safety_checker.to("cpu")
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
# 5. Prepare latent variables
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
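# A single conditioning image is treated as a one-frame video (the unsqueeze adds the temporal axis).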
if image is not None:
video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
else:
video = self.video_processor.preprocess_video(video, height, width)
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
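# One transformer input channel is reserved for the condition mask, hence `in_channels - 1` latent channels.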
latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
video,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
self.do_classifier_free_guidance,
input_frames_guidance,
torch.float32,
device,
generator,
latents,
)
cond_mask = cond_mask.to(transformer_dtype)
if self.do_classifier_free_guidance:
uncond_mask = uncond_mask.to(transformer_dtype)
augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
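# Per step: add a small amount of noise (`augment_sigma`) to the clean conditioning latents,
# rescale by the ratio of the EDM input-preconditioning factors for the two sigmas, and splice
# the result into the noisy latents on the conditioning frames only. Once the current sigma
# drops below `augment_sigma`, the conditioning indicator is zeroed out and disabled.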
current_sigma = self.scheduler.sigmas[i]
is_augment_sigma_greater = augment_sigma >= current_sigma
c_in_augment = self.scheduler._get_conditioning_c_in(augment_sigma)
c_in_original = self.scheduler._get_conditioning_c_in(current_sigma)
current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
cond_latent = cond_latent * c_in_augment / c_in_original
cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
cond_latent = self.scheduler.scale_model_input(cond_latent, t)
cond_latent = cond_latent.to(transformer_dtype)
noise_pred = self.transformer(
hidden_states=cond_latent,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
fps=fps,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
sample = latents
if self.do_classifier_free_guidance:
current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
uncond_latent = uncond_latent * c_in_augment / c_in_original
uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
uncond_latent = uncond_latent.to(transformer_dtype)
noise_pred_uncond = self.transformer(
hidden_states=uncond_latent,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
fps=fps,
condition_mask=uncond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
sample = torch.cat([sample, sample])
# pred_original_sample (x0)
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
self.scheduler._step_index -= 1
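# Rewind the scheduler's step index (advanced by the x0 call above), then pin the x0 prediction
# back to the clean conditioning latents on the conditioning frames before applying guidance.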
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
noise_pred_uncond = (
current_uncond_indicator * conditioning_latents
+ (1 - current_uncond_indicator) * noise_pred_uncond
)
noise_pred_cond = (
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
)
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = (
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred
)
# pred_sample (eps)
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
if self.vae.config.latents_mean is not None:
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
latents_mean = (
torch.tensor(latents_mean)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
.to(latents)
)
latents_std = (
torch.tensor(latents_std)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
.to(latents)
)
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
else:
latents = latents / self.scheduler.config.sigma_data
video = self.vae.decode(latents.to(vae_dtype), return_dict=False)[0]
if self.safety_checker is not None:
self.safety_checker.to(device)
video = self.video_processor.postprocess_video(video, output_type="np")
video = (video * 255).astype(np.uint8)
video_batch = []
for vid in video:
vid = self.safety_checker.check_video_safety(vid)
video_batch.append(vid)
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
self.safety_checker.to("cpu")
else:
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CosmosPipelineOutput(frames=video)
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class CosmosPipelineOutput(BaseOutput):
r"""
Output class for Cosmos pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of
shape `(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
......@@ -144,7 +144,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
......@@ -568,5 +568,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
def _get_conditioning_c_in(self, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
return self.config.num_train_timesteps