Unverified Commit 8907a70a authored by Aryan's avatar Aryan Committed by GitHub
Browse files

New HunyuanVideo-I2V (#11066)

* update

* update

* update

* add tests

* update docs

* raise value error

* warning for true cfg and guidance scale

* fix test
parent 5dbe4f5d
...@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline: ...@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline:
| Model name | Description | | Model name | Description |
|:---|:---| |:---|:---|
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. | | [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) | | [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tencent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
## Quantization ## Quantization
......
...@@ -160,8 +160,9 @@ TRANSFORMER_CONFIGS = { ...@@ -160,8 +160,9 @@ TRANSFORMER_CONFIGS = {
"pooled_projection_dim": 768, "pooled_projection_dim": 768,
"rope_theta": 256.0, "rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56), "rope_axes_dim": (16, 56, 56),
"image_condition_type": None,
}, },
"HYVideo-T/2-I2V": { "HYVideo-T/2-I2V-33ch": {
"in_channels": 16 * 2 + 1, "in_channels": 16 * 2 + 1,
"out_channels": 16, "out_channels": 16,
"num_attention_heads": 24, "num_attention_heads": 24,
...@@ -178,6 +179,26 @@ TRANSFORMER_CONFIGS = { ...@@ -178,6 +179,26 @@ TRANSFORMER_CONFIGS = {
"pooled_projection_dim": 768, "pooled_projection_dim": 768,
"rope_theta": 256.0, "rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56), "rope_axes_dim": (16, 56, 56),
"image_condition_type": "latent_concat",
},
"HYVideo-T/2-I2V-16ch": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": True,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": "token_replace",
}, },
} }
......
...@@ -54,6 +54,7 @@ EXAMPLE_DOC_STRING = """ ...@@ -54,6 +54,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel >>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel
>>> from diffusers.utils import load_image, export_to_video >>> from diffusers.utils import load_image, export_to_video
>>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch
>>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V" >>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
...@@ -69,7 +70,12 @@ EXAMPLE_DOC_STRING = """ ...@@ -69,7 +70,12 @@ EXAMPLE_DOC_STRING = """
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png" ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
... ) ... )
>>> output = pipe(image=image, prompt=prompt).frames[0] >>> # If using hunyuanvideo-community/HunyuanVideo-I2V
>>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch
>>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0]
>>> export_to_video(output, "output.mp4", fps=15) >>> export_to_video(output, "output.mp4", fps=15)
``` ```
""" """
...@@ -399,7 +405,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -399,7 +405,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256, max_sequence_length: int = 256,
): image_embed_interleave: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if prompt_embeds is None: if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
image, image,
...@@ -409,6 +416,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -409,6 +416,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device=device, device=device,
dtype=dtype, dtype=dtype,
max_sequence_length=max_sequence_length, max_sequence_length=max_sequence_length,
image_embed_interleave=image_embed_interleave,
) )
if pooled_prompt_embeds is None: if pooled_prompt_embeds is None:
...@@ -433,6 +441,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -433,6 +441,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_embeds=None, prompt_embeds=None,
callback_on_step_end_tensor_inputs=None, callback_on_step_end_tensor_inputs=None,
prompt_template=None, prompt_template=None,
true_cfg_scale=1.0,
guidance_scale=1.0,
): ):
if height % 16 != 0 or width % 16 != 0: if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
...@@ -471,6 +481,13 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -471,6 +481,13 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
) )
if true_cfg_scale > 1.0 and guidance_scale > 1.0:
logger.warning(
"Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both "
"classifier-free guidance and embedded-guidance to be applied. This is not recommended "
"as it may lead to higher memory usage, slower inference and potentially worse results."
)
def prepare_latents( def prepare_latents(
self, self,
image: torch.Tensor, image: torch.Tensor,
...@@ -483,6 +500,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -483,6 +500,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None, latents: Optional[torch.Tensor] = None,
image_condition_type: str = "latent_concat",
) -> torch.Tensor: ) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size: if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError( raise ValueError(
...@@ -497,10 +515,11 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -497,10 +515,11 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
image = image.unsqueeze(2) # [B, C, 1, H, W] image = image.unsqueeze(2) # [B, C, 1, H, W]
if isinstance(generator, list): if isinstance(generator, list):
image_latents = [ image_latents = [
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size) retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
for i in range(batch_size)
] ]
else: else:
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image] image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1) image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
...@@ -513,6 +532,9 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -513,6 +532,9 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
t = torch.tensor([0.999]).to(device=device) t = torch.tensor([0.999]).to(device=device)
latents = latents * t + image_latents * (1 - t) latents = latents * t + image_latents * (1 - t)
if image_condition_type == "token_replace":
image_latents = image_latents[:, :, :1]
return latents, image_latents return latents, image_latents
def enable_vae_slicing(self): def enable_vae_slicing(self):
...@@ -598,6 +620,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -598,6 +620,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
max_sequence_length: int = 256, max_sequence_length: int = 256,
image_embed_interleave: Optional[int] = None,
): ):
r""" r"""
The call function to the pipeline for generation. The call function to the pipeline for generation.
...@@ -704,12 +727,22 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -704,12 +727,22 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_embeds, prompt_embeds,
callback_on_step_end_tensor_inputs, callback_on_step_end_tensor_inputs,
prompt_template, prompt_template,
true_cfg_scale,
guidance_scale,
) )
image_condition_type = self.transformer.config.image_condition_type
has_neg_prompt = negative_prompt is not None or ( has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
) )
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
image_embed_interleave = (
image_embed_interleave
if image_embed_interleave is not None
else (
2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1
)
)
self._guidance_scale = guidance_scale self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs self._attention_kwargs = attention_kwargs
...@@ -729,7 +762,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -729,7 +762,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
# 3. Prepare latent variables # 3. Prepare latent variables
vae_dtype = self.vae.dtype vae_dtype = self.vae.dtype
image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
if image_condition_type == "latent_concat":
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
elif image_condition_type == "token_replace":
num_channels_latents = self.transformer.config.in_channels
latents, image_latents = self.prepare_latents( latents, image_latents = self.prepare_latents(
image_tensor, image_tensor,
batch_size * num_videos_per_prompt, batch_size * num_videos_per_prompt,
...@@ -741,10 +779,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -741,10 +779,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device, device,
generator, generator,
latents, latents,
image_condition_type,
) )
image_latents[:, :, 1:] = 0 if image_condition_type == "latent_concat":
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:]) image_latents[:, :, 1:] = 0
mask[:, :, 1:] = 0 mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
mask[:, :, 1:] = 0
# 4. Encode input prompt # 4. Encode input prompt
transformer_dtype = self.transformer.dtype transformer_dtype = self.transformer.dtype
...@@ -759,6 +799,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -759,6 +799,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_attention_mask=prompt_attention_mask, prompt_attention_mask=prompt_attention_mask,
device=device, device=device,
max_sequence_length=max_sequence_length, max_sequence_length=max_sequence_length,
image_embed_interleave=image_embed_interleave,
) )
prompt_embeds = prompt_embeds.to(transformer_dtype) prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
...@@ -782,10 +823,17 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -782,10 +823,17 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps # 5. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
# 6. Prepare guidance condition
guidance = None
if self.transformer.config.guidance_embeds:
guidance = (
torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
)
# 7. Denoising loop # 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps) self._num_timesteps = len(timesteps)
...@@ -796,16 +844,21 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -796,16 +844,21 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
continue continue
self._current_timestep = t self._current_timestep = t
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype) timestep = t.expand(latents.shape[0]).to(latents.dtype)
if image_condition_type == "latent_concat":
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
elif image_condition_type == "token_replace":
latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
noise_pred = self.transformer( noise_pred = self.transformer(
hidden_states=latent_model_input, hidden_states=latent_model_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask, encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds, pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs, attention_kwargs=attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
...@@ -817,13 +870,20 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -817,13 +870,20 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
encoder_hidden_states=negative_prompt_embeds, encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask, encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds, pooled_projections=negative_pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs, attention_kwargs=attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if image_condition_type == "latent_concat":
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
elif image_condition_type == "token_replace":
latents = latents = self.scheduler.step(
noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
)[0]
latents = torch.cat([image_latents, latents], dim=2)
if callback_on_step_end is not None: if callback_on_step_end is not None:
callback_kwargs = {} callback_kwargs = {}
...@@ -844,12 +904,16 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader ...@@ -844,12 +904,16 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
self._current_timestep = None self._current_timestep = None
if not output_type == "latent": if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor latents = latents.to(self.vae.dtype) / self.vae_scaling_factor
video = self.vae.decode(latents, return_dict=False)[0] video = self.vae.decode(latents, return_dict=False)[0]
video = video[:, :, 4:, :, :] if image_condition_type == "latent_concat":
video = video[:, :, 4:, :, :]
video = self.video_processor.postprocess_video(video, output_type=output_type) video = self.video_processor.postprocess_video(video, output_type=output_type)
else: else:
video = latents[:, :, 1:, :, :] if image_condition_type == "latent_concat":
video = latents[:, :, 1:, :, :]
else:
video = latents
# Offload all models # Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
......
...@@ -80,6 +80,7 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): ...@@ -80,6 +80,7 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"text_embed_dim": 16, "text_embed_dim": 16,
"pooled_projection_dim": 8, "pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4), "rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
} }
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
...@@ -144,6 +145,7 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T ...@@ -144,6 +145,7 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T
"text_embed_dim": 16, "text_embed_dim": 16,
"pooled_projection_dim": 8, "pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4), "rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
} }
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
...@@ -209,6 +211,75 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test ...@@ -209,6 +211,75 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test
"text_embed_dim": 16, "text_embed_dim": 16,
"pooled_projection_dim": 8, "pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4), "rope_axes_dim": (2, 4, 4),
"image_condition_type": "latent_concat",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 2
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
}
@property
def input_shape(self):
return (8, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 2,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 10,
"num_layers": 1,
"num_single_layers": 1,
"num_refiner_layers": 1,
"patch_size": 1,
"patch_size_t": 1,
"guidance_embeds": True,
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
"image_condition_type": "token_replace",
} }
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
......
...@@ -83,6 +83,7 @@ class HunyuanVideoImageToVideoPipelineFastTests( ...@@ -83,6 +83,7 @@ class HunyuanVideoImageToVideoPipelineFastTests(
text_embed_dim=16, text_embed_dim=16,
pooled_projection_dim=8, pooled_projection_dim=8,
rope_axes_dim=(2, 4, 4), rope_axes_dim=(2, 4, 4),
image_condition_type="latent_concat",
) )
torch.manual_seed(0) torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment