Unverified commit 2e5203be authored by Aryan, committed by GitHub

Hunyuan I2V (#10983)

* update

* update

* update

* add tests

* update

* add model tests

* update docs

* update

* update example

* fix defaults

* update
parent d55f4110
@@ -49,7 +49,8 @@ The following models are available for the image-to-video pipeline:
| Model name | Description |
|:---|:---|
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best at `97x544x960` resolution with `guidance_scale=1.0`, `true_cfg_scale=6.0`, and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, and 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
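The table's recommendations can be combined into a short usage sketch for the image-to-video pipeline added in this PR. It assumes the `hunyuanvideo-community/HunyuanVideo-I2V` checkpoint from the table; the input image path, prompt, resolution, frame count, and `shift=7.0` are illustrative placeholders chosen from the recommended ranges rather than values fixed by this change.

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-I2V", torch_dtype=torch.bfloat16
)
# A higher shift is recommended for this model (good values are between 7 and 20).
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=7.0)
pipe.vae.enable_tiling()
pipe.to("cuda")

image = load_image("path/to/first_frame.png")  # placeholder: any RGB image to condition on
video = pipe(image=image, prompt="A cat walks on the grass", height=720, width=1280, num_frames=61).frames[0]
export_to_video(video, "output.mp4", fps=15)
```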
## Quantization
...
@@ -3,11 +3,19 @@ from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from transformers import (
AutoModel,
AutoTokenizer,
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
LlavaForConditionalGeneration,
)
from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
@@ -134,6 +142,46 @@ VAE_KEYS_RENAME_DICT = {}
VAE_SPECIAL_KEYS_REMAP = {}
TRANSFORMER_CONFIGS = {
"HYVideo-T/2-cfgdistill": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": True,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
},
"HYVideo-T/2-I2V": {
"in_channels": 16 * 2 + 1,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": False,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
},
}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
@@ -149,11 +197,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
return state_dict
def convert_transformer(ckpt_path: str, transformer_type: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
config = TRANSFORMER_CONFIGS[transformer_type]
with init_empty_weights():
transformer = HunyuanVideoTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -205,6 +254,10 @@ def get_args():
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
parser.add_argument(
"--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
)
parser.add_argument("--flow_shift", type=float, default=7.0)
return parser.parse_args()
@@ -228,7 +281,7 @@ if __name__ == "__main__":
assert args.text_encoder_2_path is not None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
@@ -239,11 +292,12 @@ if __name__ == "__main__":
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
if args.transformer_type == "HYVideo-T/2-cfgdistill":
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
pipe = HunyuanVideoPipeline(
transformer=transformer,
@@ -255,3 +309,24 @@ if __name__ == "__main__":
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
else:
text_encoder = LlavaForConditionalGeneration.from_pretrained(
args.text_encoder_path, torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path)
pipe = HunyuanVideoImageToVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
image_processor=image_processor,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
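Once `--save_pipeline` is used, the converted image-to-video pipeline can be reloaded like any other diffusers pipeline. A minimal sketch, where the local directory is a placeholder for whatever was passed as `--output_path`:

```python
import torch
from diffusers import HunyuanVideoImageToVideoPipeline

# Placeholder path: the directory given to the conversion script via --output_path.
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained("./HunyuanVideo-I2V-converted", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```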
@@ -313,6 +313,7 @@ else:
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
@@ -823,6 +824,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
...
@@ -581,7 +581,11 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
if guidance_embeds:
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
else:
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
@@ -708,7 +712,11 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
if self.config.guidance_embeds:
temb = self.time_text_embed(timestep, guidance, pooled_projections)
else:
temb = self.time_text_embed(timestep, pooled_projections)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
...
@@ -222,7 +222,11 @@ else:
"EasyAnimateControlPipeline",
]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = [
"HunyuanVideoPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -570,7 +574,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
...
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
_import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
from .pipeline_hunyuan_video import HunyuanVideoPipeline
from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline
else:
import sys
...
@@ -677,6 +677,21 @@ class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
@@ -154,3 +154,68 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 2 * 4 + 1
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
}
@property
def input_shape(self):
return (8, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 2 * 4 + 1,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 10,
"num_layers": 1,
"num_single_layers": 1,
"num_refiner_layers": 1,
"patch_size": 1,
"patch_size_t": 1,
"guidance_embeds": False,
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
LlamaConfig,
LlamaModel,
LlamaTokenizer,
)
from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
enable_full_determinism()
class HunyuanVideoImageToVideoPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase
):
pipeline_class = HunyuanVideoImageToVideoPipeline
params = frozenset(
["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
)
batch_params = frozenset(["prompt", "image"])
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
# there is no xformers attention processor for HunyuanVideo
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = HunyuanVideoTransformer3DModel(
in_channels=2 * 4 + 1,
out_channels=4,
num_attention_heads=2,
attention_head_dim=10,
num_layers=num_layers,
num_single_layers=num_single_layers,
num_refiner_layers=1,
patch_size=1,
patch_size_t=1,
guidance_embeds=False,
text_embed_dim=16,
pooled_projection_dim=8,
rope_axes_dim=(2, 4, 4),
)
torch.manual_seed(0)
vae = AutoencoderKLHunyuanVideo(
in_channels=3,
out_channels=3,
latent_channels=4,
down_block_types=(
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
up_block_types=(
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels=(8, 8, 8, 8),
layers_per_block=1,
act_fn="silu",
norm_num_groups=4,
scaling_factor=0.476986,
spatial_compression_ratio=8,
temporal_compression_ratio=4,
mid_block_add_attention=True,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
llama_text_encoder_config = LlamaConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=16,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = LlamaModel(llama_text_encoder_config)
tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
torch.manual_seed(0)
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
image_processor = CLIPImageProcessor(
crop_size=336,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=336,
)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 16
image_width = 16
image = Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"prompt": "dance monkey",
"prompt_template": {
"template": "{}",
"crop_start": 0,
},
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.5,
"height": image_height,
"width": image_width,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
# NOTE: The expected video has 4 fewer frames because they are dropped in the pipeline
self.assertEqual(generated_video.shape, (5, 3, 16, 16))
expected_video = torch.randn(5, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
# Seems to require higher tolerance than the other tests
expected_diff_max = 0.6
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
# TODO(aryan): Create a dummy gemma model with smol vocab size
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_consistent(self):
pass
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_single_identical(self):
pass
@unittest.skip(
"Encode prompt currently does not work in isolation because of requiring image embeddings from image processor. The test does not handle this case, or we need to rewrite encode_prompt."
)
def test_encode_prompt_works_in_isolation(self):
pass