Unverified Commit e3bc4aab authored by Aryan, committed by GitHub

SkyReels Hunyuan T2V & I2V (#10837)



* update

* make fix-copies

* update

* tests

* update

* update

* add co-author
Co-Authored-By: Langdx <82783347+Langdx@users.noreply.github.com>

* add co-author
Co-Authored-By: howe <howezhang2018@gmail.com>

* update

---------
Co-authored-by: Langdx <82783347+Langdx@users.noreply.github.com>
Co-authored-by: howe <howezhang2018@gmail.com>
parent f0707751
@@ -32,6 +32,21 @@ Recommendations for inference:
- For smaller resolution videos, try lower values of `shift` (between `2.0` and `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo; a snippet showing how to override it follows this list.
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
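A minimal sketch of the `shift` override mentioned above, rebuilding the scheduler from its existing config (the values are illustrative, not a tested recommendation):

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
)
# Lower shift for smaller-resolution videos; raise it for larger resolutions.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=3.0)
```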
## Available models
The following models are available for the [`HunyuanVideoPipeline`](text-to-video) pipeline:
| Model name | Description |
|:---|:---|
| [`hunyuanvideo-community/HunyuanVideo`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | Official HunyuanVideo (guidance-distilled). Works well across multiple resolutions and frame counts. Performs best with `guidance_scale=6.0`, `true_cfg_scale=1.0`, and no negative prompt. |
| [`Skywork/SkyReels-V1-Hunyuan-T2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-T2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best at a `97x544x960` resolution (`num_frames x height x width`), with `guidance_scale=1.0`, `true_cfg_scale=6.0`, and a negative prompt. |
The following models are available for the image-to-video pipeline:
| Model name | Description |
|:---|:---|
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best at a `97x544x960` resolution (`num_frames x height x width`), with `guidance_scale=1.0`, `true_cfg_scale=6.0`, and a negative prompt (see the usage sketch below). |
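Both SkyReels checkpoints are de-distilled, so they rely on true classifier-free guidance (`true_cfg_scale` plus a negative prompt) rather than the base model's embedded guidance. Below is a minimal, untested sketch of running the image-to-video checkpoint; the repository id comes from the table above, but the exact weight layout and argument defaults should be confirmed against the model card.

```python
import torch
from diffusers import HunyuanSkyreelsImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Assumes the Skywork repository hosts diffusers-format weights.
pipe = HunyuanSkyreelsImageToVideoPipeline.from_pretrained(
    "Skywork/SkyReels-V1-Hunyuan-I2V", torch_dtype=torch.bfloat16
)
pipe.vae.enable_tiling()
pipe.to("cuda")

image = load_image("first_frame.png")  # any conditioning image
video = pipe(
    image=image,
    prompt="A cat walks on the grass, realistic style",
    negative_prompt="worst quality, blurry, distorted, static",
    height=544,
    width=960,
    num_frames=97,               # 97x544x960, as recommended in the table
    guidance_scale=1.0,          # de-distilled: keep embedded guidance at 1.0
    true_cfg_scale=6.0,          # true classifier-free guidance
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(video, "output.mp4", fps=15)
```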
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
...
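As a rough illustration of the note above, a sketch of loading only the transformer (the largest component) in 8-bit with bitsandbytes before assembling the pipeline might look like the following; treat it as an outline rather than a tested recipe, and swap in whichever checkpoint you are using.

```python
import torch
from diffusers import BitsAndBytesConfig, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel

# 8-bit weights for the transformer only; text encoders and VAE stay in bf16.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.vae.enable_tiling()
pipe.enable_model_cpu_offload()
```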
@@ -305,6 +305,7 @@ else:
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoPipeline", "HunyuanVideoPipeline",
"I2VGenXLPipeline", "I2VGenXLPipeline",
"IFImg2ImgPipeline", "IFImg2ImgPipeline",
...@@ -804,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -804,6 +805,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTControlNetPipeline, HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline, HunyuanDiTPAGPipeline,
HunyuanDiTPipeline, HunyuanDiTPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
...
@@ -217,7 +217,7 @@ else:
"IFSuperResolutionPipeline",
]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = ["HunyuanVideoPipeline"] _import_structure["hunyuan_video"] = ["HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["kandinsky"] = [ _import_structure["kandinsky"] = [
"KandinskyCombinedPipeline", "KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline", "KandinskyImg2ImgCombinedPipeline",
...@@ -558,7 +558,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -558,7 +558,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline, FluxPriorReduxPipeline,
ReduxImageEncoder, ReduxImageEncoder,
) )
from .hunyuan_video import HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoPipeline
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
...
@@ -22,6 +22,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"] _import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
...@@ -32,6 +33,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -32,6 +33,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * from ...utils.dummy_torch_and_transformers_objects import *
else: else:
from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
from .pipeline_hunyuan_video import HunyuanVideoPipeline
else:
...
@@ -325,7 +325,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
)
if pooled_prompt_embeds is None:
if prompt_2 is None:
prompt_2 = prompt
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt,
@@ -470,11 +470,14 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
self,
prompt: Union[str, List[str]] = None,
prompt_2: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Union[str, List[str]] = None,
height: int = 720,
width: int = 1280,
num_frames: int = 129,
num_inference_steps: int = 50,
sigmas: List[float] = None,
true_cfg_scale: float = 1.0,
guidance_scale: float = 6.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -482,6 +485,9 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -502,6 +508,13 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
will be used instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
height (`int`, defaults to `720`):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
@@ -515,6 +528,8 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When greater than `1.0` and a `negative_prompt` is provided, true classifier-free guidance is enabled.
guidance_scale (`float`, defaults to `6.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of [Imagen
@@ -535,6 +550,17 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -579,6 +605,11 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
prompt_template,
)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
@@ -595,6 +626,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
transformer_dtype = self.transformer.dtype
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
@@ -606,21 +638,29 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
device=device,
max_sequence_length=max_sequence_length,
)
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
if pooled_prompt_embeds is not None:
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
if do_true_cfg:
negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_template=prompt_template,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
prompt_attention_mask=negative_prompt_attention_mask,
device=device,
max_sequence_length=max_sequence_length,
)
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
@@ -664,6 +704,19 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
return_dict=False,
)[0]
if do_true_cfg:
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
...
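Note on the true-CFG block added above: the combination is the standard classifier-free guidance update, computed from two separate transformer passes (positive and negative prompts). A toy, self-contained illustration of the arithmetic, with random tensors standing in for the model outputs:

```python
import torch

true_cfg_scale = 6.0
noise_pred = torch.randn(1, 4, 3, 16, 16)      # stand-in for the conditional prediction
neg_noise_pred = torch.randn_like(noise_pred)  # stand-in for the negative-prompt prediction

# Same update as in the pipeline: start at the negative prediction and move
# toward the conditional one, amplified by true_cfg_scale.
guided = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

# With true_cfg_scale == 1.0 the update reduces to the conditional prediction,
# which is why the pipeline only runs the extra pass when the scale is > 1.
assert torch.allclose(neg_noise_pred + 1.0 * (noise_pred - neg_noise_pred), noise_pred)
```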
@@ -617,6 +617,21 @@ class HunyuanDiTPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
@@ -87,3 +87,70 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 8
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
}
@property
def input_shape(self):
return (8, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 8,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 10,
"num_layers": 1,
"num_single_layers": 1,
"num_refiner_layers": 1,
"patch_size": 1,
"patch_size_t": 1,
"guidance_embeds": True,
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, LlamaConfig, LlamaModel, LlamaTokenizer
from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
enable_full_determinism()
class HunyuanSkyreelsImageToVideoPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, unittest.TestCase
):
pipeline_class = HunyuanSkyreelsImageToVideoPipeline
params = frozenset(
["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
)
batch_params = frozenset(["prompt", "image"])
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
# there is no xformers processor for HunyuanVideo
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = HunyuanVideoTransformer3DModel(
in_channels=8,
out_channels=4,
num_attention_heads=2,
attention_head_dim=10,
num_layers=num_layers,
num_single_layers=num_single_layers,
num_refiner_layers=1,
patch_size=1,
patch_size_t=1,
guidance_embeds=True,
text_embed_dim=16,
pooled_projection_dim=8,
rope_axes_dim=(2, 4, 4),
)
torch.manual_seed(0)
vae = AutoencoderKLHunyuanVideo(
in_channels=3,
out_channels=3,
latent_channels=4,
down_block_types=(
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
"HunyuanVideoDownBlock3D",
),
up_block_types=(
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
block_out_channels=(8, 8, 8, 8),
layers_per_block=1,
act_fn="silu",
norm_num_groups=4,
scaling_factor=0.476986,
spatial_compression_ratio=8,
temporal_compression_ratio=4,
mid_block_add_attention=True,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
llama_text_encoder_config = LlamaConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=16,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = LlamaModel(llama_text_encoder_config)
tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
torch.manual_seed(0)
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image_height = 16
image_width = 16
image = Image.new("RGB", (image_width, image_height))
inputs = {
"image": image,
"prompt": "dance monkey",
"prompt_template": {
"template": "{}",
"crop_start": 0,
},
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.5,
"height": 16,
"width": 16,
# 4 * k + 1 is the recommendation
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
# Seems to require higher tolerance than the other tests
expected_diff_max = 0.6
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
# TODO(aryan): Create a dummy gemma model with smol vocab size
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_consistent(self):
pass
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_single_identical(self):
pass