Unverified commit 8cabd4a0 authored by Aryan, committed by GitHub

[pipeline] CogVideoX-Fun Control (#9671)



* cogvideox-fun control

* make style

* make fix-copies

* karras schedulers

* Update src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* apply suggestions from review

---------
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 5783286d
@@ -36,6 +36,10 @@ There are two models available that can be used with the text-to-video and video
There is one model available that can be used with the image-to-video CogVideoX pipeline:
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.
There are two models that support pose-controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team); a usage sketch follows the list:
- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
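
A minimal usage sketch for the new pipeline (the control video path and the prompt are placeholders; the `control_video` argument and `bf16` loading follow the pipeline added in this PR):

```python
import torch
from diffusers import CogVideoXFunControlPipeline
from diffusers.utils import export_to_video, load_video

pipe = CogVideoXFunControlPipeline.from_pretrained(
    "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# A pose video to condition generation on (placeholder path).
control_video = load_video("pose_sequence.mp4")

video = pipe(
    prompt="A person dancing in a sunlit studio",
    control_video=control_video,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```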
## Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
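
As a rough illustration of that suggestion (the compile settings are assumptions, not prescribed by this diff), the transformer, which dominates inference time, can be compiled once and reused:

```python
import torch
from diffusers import CogVideoXFunControlPipeline

pipe = CogVideoXFunControlPipeline.from_pretrained(
    "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16
).to("cuda")

# Compilation happens lazily: the first call pays the warm-up cost,
# later calls with the same shapes reuse the compiled graph.
pipe.transformer = torch.compile(pipe.transformer)
```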
@@ -118,6 +122,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc
- all
- __call__
## CogVideoXFunControlPipeline
[[autodoc]] CogVideoXFunControlPipeline
- all
- __call__
## CogVideoXPipelineOutput
[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput
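
Both the existing pipelines and the new control pipeline return this output class; continuing the loading sketch above, consuming it looks roughly like:

```python
# `pipe` and `control_video` as in the earlier sketch.
output = pipe(prompt="A person dancing", control_video=control_video)
# `frames` is batched; the element type (PIL images, numpy array, or tensor)
# depends on the `output_type` argument.
video = output.frames[0]
```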
@@ -256,6 +256,7 @@ else:
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
@@ -711,6 +712,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDMPipeline,
AuraFlowPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
...
@@ -144,6 +144,7 @@ else:
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
"CogVideoXFunControlPipeline",
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["controlnet"].extend(
@@ -470,7 +471,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
)
from .cogview3 import CogView3PlusPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
...
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
_import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
_import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"] _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"] _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
...@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import * from ...utils.dummy_torch_and_transformers_objects import *
else: else:
from .pipeline_cogvideox import CogVideoXPipeline from .pipeline_cogvideox import CogVideoXPipeline
from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
......
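
These `_import_structure` entries feed diffusers' lazy-import machinery: the new pipeline is registered as a string, and the defining submodule is only imported on first attribute access. A simplified sketch of the pattern (not the actual `_LazyModule` implementation, which also handles `TYPE_CHECKING` and dummy objects):

```python
import importlib

class LazyModule:
    def __init__(self, package_name, import_structure):
        self._package_name = package_name
        # Map each exported symbol to the submodule that defines it.
        self._symbol_to_submodule = {
            symbol: submodule
            for submodule, symbols in import_structure.items()
            for symbol in symbols
        }

    def __getattr__(self, symbol):
        if symbol not in self._symbol_to_submodule:
            raise AttributeError(symbol)
        # The submodule is imported only when the symbol is first accessed.
        submodule = self._symbol_to_submodule[symbol]
        module = importlib.import_module(f"{self._package_name}.{submodule}")
        return getattr(module, symbol)

# Usage mirroring the registration above:
pipelines = LazyModule(
    "diffusers.pipelines.cogvideo",
    {"pipeline_cogvideox_fun_control": ["CogVideoXFunControlPipeline"]},
)
```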
@@ -272,6 +272,21 @@ class CLIPImageProjection(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class CogVideoXFunControlPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CogVideoXImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
...
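
The dummy class above exists so that `from diffusers import CogVideoXFunControlPipeline` still succeeds when `torch`/`transformers` are not installed, failing only on use with a clear message. A rough sketch of what `requires_backends` does (simplified; not the actual diffusers implementation):

```python
import importlib.util

def requires_backends(obj, backends):
    # Check which of the required backends cannot be imported.
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        name = getattr(obj, "__name__", obj.__class__.__name__)
        raise ImportError(
            f"{name} requires the following backends: {', '.join(missing)}"
        )
```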
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
to_np,
)
enable_full_determinism()
class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = CogVideoXFunControlPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"control_video"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = CogVideoXTransformer3DModel(
# The product num_attention_heads * attention_head_dim must be divisible by 16
# for the 3D positional embeddings. Since we use tiny-random-t5 here, the internal
# dim of CogVideoXTransformer3DModel (num_attention_heads * attention_head_dim)
# must be 32.
num_attention_heads=4,
attention_head_dim=8,
in_channels=8,
out_channels=4,
time_embed_dim=2,
text_embed_dim=32, # Must match with tiny-random-t5
num_layers=1,
sample_width=2, # latent width: 2 -> final width: 16
sample_height=2, # latent height: 2 -> final height: 16
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
patch_size=2,
temporal_compression_ratio=4,
max_text_seq_length=16,
)
torch.manual_seed(0)
vae = AutoencoderKLCogVideoX(
in_channels=3,
out_channels=3,
down_block_types=(
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
up_block_types=(
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels=(8, 8, 8, 8),
latent_channels=4,
layers_per_block=1,
norm_num_groups=2,
temporal_compression_ratio=4,
)
torch.manual_seed(0)
scheduler = DDIMScheduler()
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 8):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
# Cannot reduce because convolution kernel becomes bigger than sample
height = 16
width = 16
control_video = [Image.new("RGB", (width, height))] * num_frames
inputs = {
"prompt": "dance monkey",
"negative_prompt": "",
"control_video": control_video,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"height": height,
"width": width,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (8, 3, 16, 16))
# The reference is random noise and the tolerance is effectively unbounded,
# so this comparison only smoke-tests the forward pass; the shape assertion
# above is the real check.
expected_video = torch.randn(8, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.5):
# NOTE(aryan): This requires a higher expected_max_diff than other CogVideoX pipelines
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames # [B, F, C, H, W]
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(
pipe.transformer
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_fused = frames[0, -2:, -1, -3:, -3:]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."