Unverified commit 9f91305f, authored by Aryan, committed by GitHub

Cosmos Predict2 (#11695)

* support text-to-image

* update example

* make fix-copies

* support use_flow_sigmas in EDM scheduler instead of maintaining a Cosmos-specific scheduler

* support video-to-world

* update

* rename text2image pipeline

* make fix-copies

* add t2i test

* add test for v2w pipeline

* support edm dpmsolver multistep

* update

* update

* update

* update tests

* fix tests

* safety checker

* make conversion script work without guardrail
parent 368958df
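For orientation before the diff, a minimal usage sketch of the new text-to-image pipeline this commit adds. The hub repo id and the sampling values below are illustrative assumptions, not taken from this commit:

```python
import torch
from diffusers import Cosmos2TextToImagePipeline

# Assumed checkpoint id for illustration; substitute your converted checkpoint.
pipe = Cosmos2TextToImagePipeline.from_pretrained(
    "nvidia/Cosmos-Predict2-2B-Text2Image", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# guidance_scale / num_inference_steps are plausible values, not verified defaults.
image = pipe(
    prompt="A red vintage car parked by the sea at golden hour",
    negative_prompt="low quality, blurry",
    guidance_scale=7.0,
    num_inference_steps=35,
).images[0]
image.save("cosmos2_t2i.png")
```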
...@@ -36,6 +36,22 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)

- all
- __call__
## Cosmos2TextToImagePipeline
[[autodoc]] Cosmos2TextToImagePipeline
- all
- __call__
## Cosmos2VideoToWorldPipeline
[[autodoc]] Cosmos2VideoToWorldPipeline
- all
- __call__
## CosmosPipelineOutput

[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
## CosmosImagePipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosImagePipelineOutput
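A corresponding hedged sketch for the new video-to-world pipeline documented above; the repo id, frame count, and input URL are placeholders/assumptions:

```python
import torch
from diffusers import Cosmos2VideoToWorldPipeline
from diffusers.utils import export_to_video, load_image

# Assumed checkpoint id for illustration.
pipe = Cosmos2VideoToWorldPipeline.from_pretrained(
    "nvidia/Cosmos-Predict2-2B-Video2World", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("https://example.com/first_frame.png")  # placeholder URL
video = pipe(
    image=image,
    prompt="The car drives along the coastal road at sunset",
    num_frames=93,       # assumed frame count
    guidance_scale=7.0,  # assumed value
).frames[0]
export_to_video(video, "cosmos2_v2w.mp4", fps=16)
```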
...@@ -7,7 +7,17 @@ from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast

-from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
from diffusers import (
AutoencoderKLCosmos,
AutoencoderKLWan,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
CosmosTransformer3DModel,
CosmosVideoToWorldPipeline,
EDMEulerScheduler,
FlowMatchEulerDiscreteScheduler,
)
def remove_keys_(key: str, state_dict: Dict[str, Any]):

...@@ -29,7 +39,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
state_dict[new_key] = state_dict.pop(key)

-TRANSFORMER_KEYS_RENAME_DICT = {
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
"t_embedder.1": "time_embed.t_embedder",
"affline_norm": "time_embed.norm",
".blocks.0.block.attn": ".attn1",
...@@ -56,7 +66,7 @@ TRANSFORMER_KEYS_RENAME_DICT = {
"final_layer.linear": "proj_out",
}

-TRANSFORMER_SPECIAL_KEYS_REMAP = {
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
"blocks.block": rename_transformer_blocks_,
"logvar.0.freqs": remove_keys_,
"logvar.0.phases": remove_keys_,

...@@ -64,6 +74,45 @@ TRANSFORMER_SPECIAL_KEYS_REMAP = {
"pos_embedder.seq": remove_keys_,
}
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"t_embedder.1": "time_embed.t_embedder",
"t_embedding_norm": "time_embed.norm",
"blocks": "transformer_blocks",
"adaln_modulation_self_attn.1": "norm1.linear_1",
"adaln_modulation_self_attn.2": "norm1.linear_2",
"adaln_modulation_cross_attn.1": "norm2.linear_1",
"adaln_modulation_cross_attn.2": "norm2.linear_2",
"adaln_modulation_mlp.1": "norm3.linear_1",
"adaln_modulation_mlp.2": "norm3.linear_2",
"self_attn": "attn1",
"cross_attn": "attn2",
"q_proj": "to_q",
"k_proj": "to_k",
"v_proj": "to_v",
"output_proj": "to_out.0",
"q_norm": "norm_q",
"k_norm": "norm_k",
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
# "extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
"accum_video_sample_counter": remove_keys_,
"accum_image_sample_counter": remove_keys_,
"accum_iteration": remove_keys_,
"accum_train_in_hours": remove_keys_,
"pos_embedder.seq": remove_keys_,
"pos_embedder.dim_spatial_range": remove_keys_,
"pos_embedder.dim_temporal_range": remove_keys_,
"_extra_state": remove_keys_,
}
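For context, rename tables like the ones above are consumed by a loop of roughly the following shape (a hedged sketch mirroring the pattern in `convert_transformer` below: plain substring renames first, then pattern-triggered handlers such as `remove_keys_`; the function name is illustrative):

```python
from typing import Any, Callable, Dict

def apply_rename_tables(
    state_dict: Dict[str, Any],
    rename_dict: Dict[str, str],
    special_keys_remap: Dict[str, Callable[[str, Dict[str, Any]], None]],
    prefix: str = "net.",  # matches PREFIX_KEY in the script
) -> Dict[str, Any]:
    # 1. Strip the checkpoint prefix and apply plain substring renames.
    for key in list(state_dict.keys()):
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        for old, new in rename_dict.items():
            new_key = new_key.replace(old, new)
        state_dict[new_key] = state_dict.pop(key)
    # 2. Run special handlers (e.g. remove_keys_) on keys matching a pattern.
    for key in list(state_dict.keys()):
        for pattern, handler in special_keys_remap.items():
            if pattern in key and key in state_dict:
                handler(key, state_dict)
    return state_dict
```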
TRANSFORMER_CONFIGS = {
"Cosmos-1.0-Diffusion-7B-Text2World": {
"in_channels": 16,

...@@ -125,6 +174,66 @@ TRANSFORMER_CONFIGS = {
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
"Cosmos-2.0-Diffusion-2B-Text2Image": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 4.0, 4.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.0-Diffusion-14B-Text2Image": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 4.0, 4.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.0-Diffusion-2B-Video2World": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.0-Diffusion-14B-Video2World": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (20 / 24, 2.0, 2.0),
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
}

VAE_KEYS_RENAME_DICT = {

...@@ -216,9 +325,18 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
return state_dict

-def convert_transformer(transformer_type: str, ckpt_path: str):
def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
PREFIX_KEY = "net."
-original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
if "Cosmos-1.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
elif "Cosmos-2.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
else:
assert False
with init_empty_weights():
config = TRANSFORMER_CONFIGS[transformer_type]

...@@ -281,13 +399,61 @@ def convert_vae(vae_type: str):
return vae
def save_pipeline_cosmos_1_0(args, transformer, vae):
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
# The original code initializes the EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
# So, the sigma_min value that is actually used is the default value of 0.002.
scheduler = EDMEulerScheduler(
sigma_min=0.002,
sigma_max=80,
sigma_data=0.5,
sigma_schedule="karras",
num_train_timesteps=1000,
prediction_type="epsilon",
rho=7.0,
final_sigmas_type="sigma_min",
)
pipe_cls = CosmosTextToWorldPipeline if "Text2World" in args.transformer_type else CosmosVideoToWorldPipeline
pipe = pipe_cls(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos_2_0(args, transformer, vae):
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
pipe_cls = Cosmos2TextToImagePipeline if "Text2Image" in args.transformer_type else Cosmos2VideoToWorldPipeline
pipe = pipe_cls(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
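A note on the scheduler swap above (cf. the `use_flow_sigmas` commit message): EDM sigmas on (0, inf) and flow-matching sigmas on (0, 1) are interchangeable up to a change of variables. A hedged sketch of one common mapping, not necessarily the exact code inside `FlowMatchEulerDiscreteScheduler`:

```python
import torch

# Hedged sketch: one common mapping between EDM-style sigmas in (0, inf) and
# flow-matching sigmas/timesteps in (0, 1); illustrative, not the library's code.
def edm_to_flow_sigma(sigma_edm: torch.Tensor) -> torch.Tensor:
    return sigma_edm / (1.0 + sigma_edm)

def flow_to_edm_sigma(sigma_flow: torch.Tensor) -> torch.Tensor:
    return sigma_flow / (1.0 - sigma_flow)
```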
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
-parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
parser.add_argument(
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
)
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--save_pipeline", action="store_true")
...@@ -316,37 +482,26 @@ if __name__ == "__main__":
assert args.tokenizer_path is not None

if args.transformer_ckpt_path is not None:
-transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
weights_only = "Cosmos-1.0" in args.transformer_type
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.vae_type is not None:
-vae = convert_vae(args.vae_type)
if "Cosmos-1.0" in args.transformer_type:
vae = convert_vae(args.vae_type)
else:
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.save_pipeline:
-text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
-tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
-# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
-# So, the sigma_min values that is used is the default value of 0.002.
-scheduler = EDMEulerScheduler(
-sigma_min=0.002,
-sigma_max=80,
-sigma_data=0.5,
-sigma_schedule="karras",
-num_train_timesteps=1000,
-prediction_type="epsilon",
-rho=7.0,
-final_sigmas_type="sigma_min",
-)
-pipe = CosmosTextToWorldPipeline(
-text_encoder=text_encoder,
-tokenizer=tokenizer,
-transformer=transformer,
-vae=vae,
-scheduler=scheduler,
-)
-pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if "Cosmos-1.0" in args.transformer_type:
save_pipeline_cosmos_1_0(args, transformer, vae)
elif "Cosmos-2.0" in args.transformer_type:
save_pipeline_cosmos_2_0(args, transformer, vae)
else:
assert False
...@@ -361,6 +361,8 @@ else:
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
"CycleDiffusionPipeline",
...@@ -949,6 +951,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
CosmosVideoToWorldPipeline,
CycleDiffusionPipeline,
...
...@@ -100,11 +100,15 @@ class CosmosAdaLayerNorm(nn.Module):
embedded_timestep = self.linear_2(embedded_timestep)

if temb is not None:
-embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim]

-shift, scale = embedded_timestep.chunk(2, dim=1)
shift, scale = embedded_timestep.chunk(2, dim=-1)
hidden_states = self.norm(hidden_states)
-hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

if embedded_timestep.ndim == 2:
shift, scale = (x.unsqueeze(1) for x in (shift, scale))

hidden_states = hidden_states * (1 + scale) + shift

return hidden_states
...@@ -135,9 +139,13 @@ class CosmosAdaLayerNormZero(nn.Module):
if temb is not None:
embedded_timestep = embedded_timestep + temb

-shift, scale, gate = embedded_timestep.chunk(3, dim=1)
shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
hidden_states = self.norm(hidden_states)
-hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

if embedded_timestep.ndim == 2:
shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))

hidden_states = hidden_states * (1 + scale) + shift

return hidden_states, gate
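The two norm changes above implement the same broadcast rule: modulation parameters are chunked on the last dim, and a sequence axis is added only when the embedded timestep is per-batch ([B, C]) rather than per-token ([B, S, C]). A standalone shape sketch under those assumptions (norms omitted; names are illustrative):

```python
import torch

# Standalone sketch of the broadcast rule introduced above: chunk on the last
# dim, and only unsqueeze when the timestep embedding is per-batch.
def adaln_modulate(hidden_states: torch.Tensor, embedded_timestep: torch.Tensor):
    shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
    if embedded_timestep.ndim == 2:  # [B, 3C] -> broadcast over the sequence axis
        shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
    # else: already [B, S, 3C], aligned with hidden_states [B, S, C]
    return hidden_states * (1 + scale) + shift, gate

hs = torch.randn(2, 8, 16)
per_batch = torch.randn(2, 48)     # [B, 3C]
per_token = torch.randn(2, 8, 48)  # [B, S, 3C]
out1, _ = adaln_modulate(hs, per_batch)
out2, _ = adaln_modulate(hs, per_token)
assert out1.shape == out2.shape == hs.shape
```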
...@@ -255,19 +263,19 @@ class CosmosTransformerBlock(nn.Module):
# 1. Self Attention
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
-hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
hidden_states = hidden_states + gate * attn_output

# 2. Cross Attention
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
attn_output = self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
-hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
hidden_states = hidden_states + gate * attn_output

# 3. Feed Forward
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
ff_output = self.ff(norm_hidden_states)
-hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
hidden_states = hidden_states + gate * ff_output

return hidden_states
...@@ -513,7 +521,23 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states.flatten(1, 3)  # [B, T, H, W, C] -> [B, THW, C]

# 4. Timestep embeddings
-temb, embedded_timestep = self.time_embed(hidden_states, timestep)
if timestep.ndim == 1:
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
elif timestep.ndim == 5:
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
)
timestep = timestep.flatten()
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
# We can do this because num_frames == post_patch_num_frames, as p_t is 1
temb, embedded_timestep = (
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
.expand(-1, -1, post_patch_height, post_patch_width, -1)
.flatten(1, 3)
for x in (temb, embedded_timestep)
) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
else:
assert False
# 5. Transformer blocks
for block in self.transformer_blocks:
...@@ -544,8 +568,8 @@ class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
-# Please just kill me at this point. What even is this permutation order and why is it different from the patching order?
-# Another few hours of sanity lost to the void.
# NOTE: The permutation order here is not the inverse of the patching order, as one might usually expect.
# It may be a source of confusion to the reader, but it is correct.
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
...
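The new `timestep.ndim == 5` branch above enables per-frame conditioning timesteps: a [B, 1, T, 1, 1] tensor is embedded per frame, then broadcast over the spatial patch grid into the [B, THW, C] token layout. A hedged shape-level sketch, assuming p_t == 1 so num_frames equals post_patch_num_frames:

```python
import torch

# Shape-level sketch of the per-frame timestep broadcast above (illustrative).
B, T, H, W, C = 2, 4, 3, 3, 16

temb = torch.randn(B * T, C)  # time_embed output for the flattened [B*T] timesteps
temb = (
    temb.view(B, T, 1, 1, C)
    .expand(-1, -1, H, W, -1)
    .flatten(1, 3)
)  # [B*T, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
assert temb.shape == (B, T * H * W, C)
```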
...@@ -157,7 +157,12 @@ else:
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
-_import_structure["cosmos"] = ["CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline"]
_import_structure["cosmos"] = [
"Cosmos2TextToImagePipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
"Cosmos2VideoToWorldPipeline",
]
_import_structure["controlnet"].extend( _import_structure["controlnet"].extend(
[ [
"BlipDiffusionControlNetPipeline", "BlipDiffusionControlNetPipeline",
...@@ -559,7 +564,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
)
-from .cosmos import CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
from .cosmos import (
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
CosmosVideoToWorldPipeline,
)
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
...
...@@ -22,6 +22,8 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
_import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
...@@ -33,6 +35,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
...
This diff is collapsed.
This diff is collapsed.
...@@ -131,7 +131,7 @@ def retrieve_timesteps(
class CosmosTextToWorldPipeline(DiffusionPipeline):
r"""
-Pipeline for text-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
Pipeline for text-to-world generation using [Cosmos Predict1](https://github.com/nvidia-cosmos/cosmos-predict1).

This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
...@@ -426,12 +426,12 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
-num_frames (`int`, defaults to `129`):
num_frames (`int`, defaults to `121`):
The number of frames in the generated video.
-num_inference_steps (`int`, defaults to `50`):
num_inference_steps (`int`, defaults to `36`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
-guidance_scale (`float`, defaults to `6.0`):
guidance_scale (`float`, defaults to `7.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
...@@ -457,9 +457,6 @@ class CosmosTextToWorldPipeline(DiffusionPipeline):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
-clip_skip (`int`, *optional*):
-Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
-the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
...
...@@ -174,7 +174,8 @@ def retrieve_latents(
class CosmosVideoToWorldPipeline(DiffusionPipeline):
r"""
-Pipeline for image-to-video and video-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
Pipeline for image-to-world and video-to-world generation using [Cosmos
Predict-1](https://github.com/nvidia-cosmos/cosmos-predict1).

This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
...@@ -541,12 +542,12 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
-num_frames (`int`, defaults to `129`):
num_frames (`int`, defaults to `121`):
The number of frames in the generated video.
-num_inference_steps (`int`, defaults to `50`):
num_inference_steps (`int`, defaults to `36`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
-guidance_scale (`float`, defaults to `6.0`):
guidance_scale (`float`, defaults to `7.0`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
...@@ -572,9 +573,6 @@ class CosmosVideoToWorldPipeline(DiffusionPipeline):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
-clip_skip (`int`, *optional*):
-Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
-the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
...
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image
import torch

-from diffusers.utils import BaseOutput
from diffusers.utils import BaseOutput, get_logger
logger = get_logger(__name__)
@dataclass
class CosmosPipelineOutput(BaseOutput):
r"""
-Output class for Cosmos pipelines.
Output class for Cosmos any-to-world/video pipelines.

Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):

...@@ -18,3 +24,17 @@ class CosmosPipelineOutput(BaseOutput):
"""

frames: torch.Tensor
@dataclass
class CosmosImagePipelineOutput(BaseOutput):
"""
Output class for Cosmos any-to-image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`):
List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
num_channels)`, containing the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
...@@ -407,6 +407,36 @@ class ConsisIDPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2TextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Cosmos2VideoToWorldPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CosmosTextToWorldPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
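For readers unfamiliar with the pattern above: when torch/transformers are not installed, diffusers exposes these dummies so that importing the class succeeds but using it raises a helpful error. A simplified, illustrative sketch (the real `DummyObject`/`requires_backends` live in `diffusers.utils`; details here are assumptions):

```python
# Simplified sketch of the dummy-object pattern; illustrative, not the
# actual diffusers implementation.
def requires_backends(obj, backends):
    name = getattr(obj, "__name__", obj.__class__.__name__)
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")

class DummyObject(type):
    # Any attribute access on the class itself fails with the backend hint.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)

class Cosmos2TextToImagePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])
```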
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLWan,
Cosmos2TextToImagePipeline,
CosmosTransformer3DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
enable_full_determinism()
class Cosmos2TextToImagePipelineWrapper(Cosmos2TextToImagePipeline):
@staticmethod
def from_pretrained(*args, **kwargs):
kwargs["safety_checker"] = DummyCosmosSafetyChecker()
return Cosmos2TextToImagePipeline.from_pretrained(*args, **kwargs)
class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2TextToImagePipelineWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = CosmosTransformer3DModel(
in_channels=16,
out_channels=16,
num_attention_heads=2,
attention_head_dim=16,
num_layers=2,
mlp_ratio=2,
text_embed_dim=32,
adaln_lora_dim=4,
max_size=(4, 32, 32),
patch_size=(1, 2, 2),
rope_scale=(2.0, 1.0, 1.0),
concat_padding_mask=True,
extra_pos_embed_type="learnable",
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
# We cannot run the Cosmos Guardrail for fast tests due to the large model size
"safety_checker": DummyCosmosSafetyChecker(),
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_video = torch.randn(3, 32, 32)
max_diff = np.abs(generated_image - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
self.pipeline_class._optional_components.remove("safety_checker")
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
self.pipeline_class._optional_components.append("safety_checker")
def test_serialization_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
model_components = [
component_name
for component_name, component in pipe.components.items()
if isinstance(component, torch.nn.Module)
]
model_components.remove("safety_checker")
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
with open(f"{tmpdir}/model_index.json", "r") as f:
config = json.load(f)
for subfolder in os.listdir(tmpdir):
if not os.path.isfile(subfolder) and subfolder in model_components:
folder_path = os.path.join(tmpdir, subfolder)
is_folder = os.path.isdir(folder_path) and subfolder in config
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
self.skipTest("No dummy components defined.")
pipe = self.pipeline_class(**components)
specified_key = next(iter(components.keys()))
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
pipe.save_pretrained(tmpdirname, safe_serialization=False)
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
)
for name, component in loaded_pipe.components.items():
if name == "safety_checker":
continue
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
self.assertEqual(
component.dtype,
expected_dtype,
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
@unittest.skip(
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
"too large and slow to run on CI."
)
def test_encode_prompt_works_in_isolation(self):
pass
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import tempfile
import unittest
import numpy as np
import PIL.Image
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLWan,
Cosmos2VideoToWorldPipeline,
CosmosTransformer3DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
enable_full_determinism()
class Cosmos2VideoToWorldPipelineWrapper(Cosmos2VideoToWorldPipeline):
@staticmethod
def from_pretrained(*args, **kwargs):
kwargs["safety_checker"] = DummyCosmosSafetyChecker()
return Cosmos2VideoToWorldPipeline.from_pretrained(*args, **kwargs)
class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Cosmos2VideoToWorldPipelineWrapper
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image", "video"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = CosmosTransformer3DModel(
in_channels=16 + 1,
out_channels=16,
num_attention_heads=2,
attention_head_dim=16,
num_layers=2,
mlp_ratio=2,
text_embed_dim=32,
adaln_lora_dim=4,
max_size=(4, 32, 32),
patch_size=(1, 2, 2),
rope_scale=(2.0, 1.0, 1.0),
concat_padding_mask=True,
extra_pos_embed_type="learnable",
)
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
# We cannot run the Cosmos Guardrail for fast tests due to the large model size
"safety_checker": DummyCosmosSafetyChecker(),
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 32
image_width = 32
image = PIL.Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": image_height,
"width": image_width,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
expected_video = torch.randn(9, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_components_function(self):
init_components = self.get_dummy_components()
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
self.pipeline_class._optional_components.remove("safety_checker")
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
self.pipeline_class._optional_components.append("safety_checker")
def test_serialization_with_variants(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
model_components = [
component_name
for component_name, component in pipe.components.items()
if isinstance(component, torch.nn.Module)
]
model_components.remove("safety_checker")
variant = "fp16"
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
with open(f"{tmpdir}/model_index.json", "r") as f:
config = json.load(f)
for subfolder in os.listdir(tmpdir):
if not os.path.isfile(subfolder) and subfolder in model_components:
folder_path = os.path.join(tmpdir, subfolder)
is_folder = os.path.isdir(folder_path) and subfolder in config
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
self.skipTest("No dummy components defined.")
pipe = self.pipeline_class(**components)
specified_key = next(iter(components.keys()))
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
pipe.save_pretrained(tmpdirname, safe_serialization=False)
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
loaded_pipe = self.pipeline_class.from_pretrained(
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
)
for name, component in loaded_pipe.components.items():
if name == "safety_checker":
continue
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
self.assertEqual(
component.dtype,
expected_dtype,
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
@unittest.skip(
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
"too large and slow to run on CI."
)
def test_encode_prompt_works_in_isolation(self):
pass