"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "bc7a4d4917456afd70913be85bd25c556c25862c"
Unverified Commit 13e48492 authored by Tolga Cangöz, committed by GitHub

[LTX0.9.5] Refactor `LTXConditionPipeline` for text-only conditioning (#11174)

* Refactor `LTXConditionPipeline` to add text-only conditioning

* style

* up

* Refactor `LTXConditionPipeline` to streamline condition handling and improve clarity

* Improve condition checks

* Simplify latents handling based on conditioning type

* Refactor rope_interpolation_scale preparation for clarity and efficiency

* Update LTXConditionPipeline docstring to clarify supported input types

* Add LTX Video 0.9.5 model to documentation

* Clarify documentation to indicate support for text-only conditioning without passing `conditions`

* refactor: comment out unused parameters in LTXConditionPipeline

* fix: restore previously commented parameters in LTXConditionPipeline

* fix: remove unused parameters from LTXConditionPipeline

* refactor: remove unnecessary lines in LTXConditionPipeline
parent 94f2c48d
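The net effect of these commits is that the pipeline can now be called with a prompt alone. A minimal sketch of that text-only path (the `Lightricks/LTX-Video-0.9.5` checkpoint id and the parameter values are illustrative assumptions, not taken from this commit):

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# Assumed diffusers-format checkpoint id for the 0.9.5 release.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A red panda walking through a bamboo forest"
generator = torch.Generator("cuda").manual_seed(0)

# Text-only conditioning: no `conditions`, `image`, or `video` is passed.
video = pipe(
    prompt=prompt,
    width=704,
    height=480,
    num_frames=161,  # must satisfy k * 8 + 1, see the frame-count check below
    num_inference_steps=40,
    generator=generator,
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```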
@@ -32,6 +32,7 @@ Available models:
 |:-------------:|:-----------------:|
 | [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
 | [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
+| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
 
 Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16`, but the recommended dtype is `torch.bfloat16` as used in the original repository.
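A hedged sketch of applying that recommendation (only the safetensors file above is named in this diff; the repository id and the per-component dtype overrides are assumptions):

```python
import torch
from diffusers import LTXConditionPipeline

# bfloat16 for the transformer, as recommended; from_pretrained applies it
# to all components by default.
pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5",  # assumed diffusers-format checkpoint id
    torch_dtype=torch.bfloat16,
)
# The VAE and text encoder also tolerate float32/float16, e.g.:
pipe.vae.to(torch.float32)
```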
@@ -14,7 +14,7 @@
 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import PIL.Image
 import torch
@@ -75,6 +75,7 @@ EXAMPLE_DOC_STRING = """
 >>> # Generate video
 >>> generator = torch.Generator("cuda").manual_seed(0)
+>>> # Text-only conditioning is also supported without the need to pass `conditions`
 >>> video = pipe(
 ...     conditions=[condition1, condition2],
 ...     prompt=prompt,
@@ -223,7 +224,7 @@ def retrieve_latents(
 class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     r"""
-    Pipeline for image-to-video generation.
+    Pipeline for text/image/video-to-video generation.
 
     Reference: https://github.com/Lightricks/LTX-Video
@@ -482,9 +483,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
         if conditions is not None and (image is not None or video is not None):
             raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
 
-        if conditions is None and (image is None and video is None):
-            raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
-
         if conditions is None:
             if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
                 raise ValueError(
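A standalone restatement of the validation after this change (an illustrative helper, not part of the diff): `conditions` remains mutually exclusive with `image`/`video`, but omitting all three is now legal, which is exactly the text-only path this PR adds.

```python
from typing import Any, Optional

def check_conditioning_inputs(
    conditions: Optional[Any] = None,
    image: Optional[Any] = None,
    video: Optional[Any] = None,
) -> None:
    # `conditions` may not be combined with the bare image/video shorthand.
    if conditions is not None and (image is not None or video is not None):
        raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
    # No longer an error: conditions, image, and video all None (text-only).

check_conditioning_inputs()  # passes after this commit; raised ValueError before
```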
@@ -642,9 +640,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
     def prepare_latents(
         self,
-        conditions: List[torch.Tensor],
-        condition_strength: List[float],
-        condition_frame_index: List[int],
+        conditions: Optional[List[torch.Tensor]] = None,
+        condition_strength: Optional[List[float]] = None,
+        condition_frame_index: Optional[List[int]] = None,
         batch_size: int = 1,
         num_channels_latents: int = 128,
         height: int = 512,
@@ -654,7 +652,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
         generator: Optional[torch.Generator] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-    ) -> None:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
         num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio
@@ -662,77 +660,80 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
-        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
-
-        extra_conditioning_latents = []
-        extra_conditioning_video_ids = []
-        extra_conditioning_mask = []
-        extra_conditioning_num_latents = 0
-        for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
-            condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
-            condition_latents = self._normalize_latents(
-                condition_latents, self.vae.latents_mean, self.vae.latents_std
-            ).to(device, dtype=dtype)
-
-            num_data_frames = data.size(2)
-            num_cond_frames = condition_latents.size(2)
-
-            if frame_index == 0:
-                latents[:, :, :num_cond_frames] = torch.lerp(
-                    latents[:, :, :num_cond_frames], condition_latents, strength
-                )
-                condition_latent_frames_mask[:, :num_cond_frames] = strength
-            else:
-                if num_data_frames > 1:
-                    if num_cond_frames < num_prefix_latent_frames:
-                        raise ValueError(
-                            f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
-                        )
-
-                    if num_cond_frames > num_prefix_latent_frames:
-                        start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
-                        end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
-                        latents[:, :, start_frame:end_frame] = torch.lerp(
-                            latents[:, :, start_frame:end_frame],
-                            condition_latents[:, :, num_prefix_latent_frames:],
-                            strength,
-                        )
-                        condition_latent_frames_mask[:, start_frame:end_frame] = strength
-                        condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
-
-                noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
-                condition_latents = torch.lerp(noise, condition_latents, strength)
-
-                condition_video_ids = self._prepare_video_ids(
-                    batch_size,
-                    condition_latents.size(2),
-                    latent_height,
-                    latent_width,
-                    patch_size=self.transformer_spatial_patch_size,
-                    patch_size_t=self.transformer_temporal_patch_size,
-                    device=device,
-                )
-                condition_video_ids = self._scale_video_ids(
-                    condition_video_ids,
-                    scale_factor=self.vae_spatial_compression_ratio,
-                    scale_factor_t=self.vae_temporal_compression_ratio,
-                    frame_index=frame_index,
-                    device=device,
-                )
-                condition_latents = self._pack_latents(
-                    condition_latents,
-                    self.transformer_spatial_patch_size,
-                    self.transformer_temporal_patch_size,
-                )
-                condition_conditioning_mask = torch.full(
-                    condition_latents.shape[:2], strength, device=device, dtype=dtype
-                )
-
-                extra_conditioning_latents.append(condition_latents)
-                extra_conditioning_video_ids.append(condition_video_ids)
-                extra_conditioning_mask.append(condition_conditioning_mask)
-                extra_conditioning_num_latents += condition_latents.size(1)
+        if len(conditions) > 0:
+            condition_latent_frames_mask = torch.zeros(
+                (batch_size, num_latent_frames), device=device, dtype=torch.float32
+            )
+
+            extra_conditioning_latents = []
+            extra_conditioning_video_ids = []
+            extra_conditioning_mask = []
+            extra_conditioning_num_latents = 0
+            for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
+                condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
+                condition_latents = self._normalize_latents(
+                    condition_latents, self.vae.latents_mean, self.vae.latents_std
+                ).to(device, dtype=dtype)
+
+                num_data_frames = data.size(2)
+                num_cond_frames = condition_latents.size(2)
+
+                if frame_index == 0:
+                    latents[:, :, :num_cond_frames] = torch.lerp(
+                        latents[:, :, :num_cond_frames], condition_latents, strength
+                    )
+                    condition_latent_frames_mask[:, :num_cond_frames] = strength
+                else:
+                    if num_data_frames > 1:
+                        if num_cond_frames < num_prefix_latent_frames:
+                            raise ValueError(
+                                f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
+                            )
+
+                        if num_cond_frames > num_prefix_latent_frames:
+                            start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
+                            end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
+                            latents[:, :, start_frame:end_frame] = torch.lerp(
+                                latents[:, :, start_frame:end_frame],
+                                condition_latents[:, :, num_prefix_latent_frames:],
+                                strength,
+                            )
+                            condition_latent_frames_mask[:, start_frame:end_frame] = strength
+                            condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
+
+                    noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
+                    condition_latents = torch.lerp(noise, condition_latents, strength)
+
+                    condition_video_ids = self._prepare_video_ids(
+                        batch_size,
+                        condition_latents.size(2),
+                        latent_height,
+                        latent_width,
+                        patch_size=self.transformer_spatial_patch_size,
+                        patch_size_t=self.transformer_temporal_patch_size,
+                        device=device,
+                    )
+                    condition_video_ids = self._scale_video_ids(
+                        condition_video_ids,
+                        scale_factor=self.vae_spatial_compression_ratio,
+                        scale_factor_t=self.vae_temporal_compression_ratio,
+                        frame_index=frame_index,
+                        device=device,
+                    )
+                    condition_latents = self._pack_latents(
+                        condition_latents,
+                        self.transformer_spatial_patch_size,
+                        self.transformer_temporal_patch_size,
+                    )
+                    condition_conditioning_mask = torch.full(
+                        condition_latents.shape[:2], strength, device=device, dtype=dtype
+                    )
+
+                    extra_conditioning_latents.append(condition_latents)
+                    extra_conditioning_video_ids.append(condition_video_ids)
+                    extra_conditioning_mask.append(condition_conditioning_mask)
+                    extra_conditioning_num_latents += condition_latents.size(1)
 
         video_ids = self._prepare_video_ids(
             batch_size,
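The core primitive in the hunk above is unchanged by this PR: conditioning is "soft", blending VAE-encoded condition latents into the noise by `strength` via `torch.lerp`. A standalone illustration with made-up tensor sizes:

```python
import torch

noise = torch.randn(1, 128, 2, 16, 16)              # (batch, channels, frames, h, w)
condition_latents = torch.randn(1, 128, 2, 16, 16)  # stands in for VAE-encoded frames
strength = 0.75

# torch.lerp(a, b, w) = a + w * (b - a): strength=1.0 replaces the noise
# entirely, strength=0.0 leaves it untouched.
mixed = torch.lerp(noise, condition_latents, strength)
assert torch.allclose(mixed, noise + strength * (condition_latents - noise))
```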
@@ -743,7 +744,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
             patch_size=self.transformer_spatial_patch_size,
             device=device,
         )
-        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        if len(conditions) > 0:
+            conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        else:
+            conditioning_mask, extra_conditioning_num_latents = None, 0
         video_ids = self._scale_video_ids(
             video_ids,
             scale_factor=self.vae_spatial_compression_ratio,
@@ -755,7 +759,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )
 
-        if len(extra_conditioning_latents) > 0:
+        if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
             video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
             conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
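Why the `extra_conditioning_num_latents` counter matters: conditions with a non-zero frame index are packed and prepended along the token dimension, and the same count is used near the end of the pipeline to slice them back off. A toy illustration with illustrative sizes:

```python
import torch

latents = torch.randn(1, 100, 2048)  # (batch, tokens, channels)
extra = [torch.randn(1, 8, 2048), torch.randn(1, 8, 2048)]

latents = torch.cat([*extra, latents], dim=1)  # 16 conditioning tokens come first
extra_conditioning_num_latents = sum(e.size(1) for e in extra)

# Later, the conditioning tokens are stripped before unpacking to pixels:
assert latents[:, extra_conditioning_num_latents:].shape == (1, 100, 2048)
```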
@@ -955,7 +959,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
             frame_index = [condition.frame_index for condition in conditions]
             image = [condition.image for condition in conditions]
             video = [condition.video for condition in conditions]
-        else:
+        elif image is not None or video is not None:
             if not isinstance(image, list):
                 image = [image]
                 num_conditions = 1
@@ -999,32 +1003,34 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
         vae_dtype = self.vae.dtype
 
         conditioning_tensors = []
-        for condition_image, condition_video, condition_frame_index, condition_strength in zip(
-            image, video, frame_index, strength
-        ):
-            if condition_image is not None:
-                condition_tensor = (
-                    self.video_processor.preprocess(condition_image, height, width)
-                    .unsqueeze(2)
-                    .to(device, dtype=vae_dtype)
-                )
-            elif condition_video is not None:
-                condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
-                num_frames_input = condition_tensor.size(2)
-                num_frames_output = self.trim_conditioning_sequence(
-                    condition_frame_index, num_frames_input, num_frames
-                )
-                condition_tensor = condition_tensor[:, :, :num_frames_output]
-                condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
-            else:
-                raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
-
-            if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
-                raise ValueError(
-                    f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
-                    f"but got {condition_tensor.size(2)} frames."
-                )
-            conditioning_tensors.append(condition_tensor)
+        is_conditioning_image_or_video = image is not None or video is not None
+        if is_conditioning_image_or_video:
+            for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+                image, video, frame_index, strength
+            ):
+                if condition_image is not None:
+                    condition_tensor = (
+                        self.video_processor.preprocess(condition_image, height, width)
+                        .unsqueeze(2)
+                        .to(device, dtype=vae_dtype)
+                    )
+                elif condition_video is not None:
+                    condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
+                    num_frames_input = condition_tensor.size(2)
+                    num_frames_output = self.trim_conditioning_sequence(
+                        condition_frame_index, num_frames_input, num_frames
+                    )
+                    condition_tensor = condition_tensor[:, :, :num_frames_output]
+                    condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
+                else:
+                    raise ValueError("Either `image` or `video` must be provided for conditioning.")
+
+                if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
+                    raise ValueError(
+                        f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
+                        f"but got {condition_tensor.size(2)} frames."
+                    )
+                conditioning_tensors.append(condition_tensor)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
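The `(k * ratio + 1)` constraint above follows from the causal temporal compression of the VAE. With the usual LTX temporal ratio of 8 (an assumption here, not stated in this hunk), valid conditioning clip lengths look like this:

```python
vae_temporal_compression_ratio = 8  # assumed LTX default

for n in (1, 9, 161, 160):
    ok = n % vae_temporal_compression_ratio == 1
    print(f"{n} frames: {'ok' if ok else 'invalid'}")
# 1 ok, 9 ok, 161 ok, 160 invalid
```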
@@ -1045,7 +1051,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
         video_coords = video_coords.float()
         video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
 
-        init_latents = latents.clone()
+        init_latents = latents.clone() if is_conditioning_image_or_video else None
 
         if self.do_classifier_free_guidance:
             video_coords = torch.cat([video_coords, video_coords], dim=0)
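The `clone()` exists because `latents` is rewritten on every denoising step while the re-noising of hard-conditioned tokens must restart from the pristine values; in the new text-only path there is nothing to preserve, so `init_latents` becomes `None`. The aliasing hazard in miniature:

```python
import torch

latents = torch.zeros(3)
init_latents = latents.clone()  # independent copy
latents += 1.0                  # in-place update during denoising

print(init_latents)  # tensor([0., 0., 0.]): still pristine
```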
@@ -1065,7 +1071,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
-        # 7. Denoising loop
+        # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -1073,7 +1079,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
                 self._current_timestep = t
 
-                if image_cond_noise_scale > 0:
+                if image_cond_noise_scale > 0 and init_latents is not None:
                     # Add timestep-dependent noise to the hard-conditioning latents
                     # This helps with motion continuity, especially when conditioned on a single frame
                     latents = self.add_noise_to_image_conditioning_latents(
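For intuition, a loose sketch of what the guarded call does; only the guard change is part of this diff, and the exact noise schedule inside `add_noise_to_image_conditioning_latents` is an assumption here:

```python
import torch

def renoise_hard_conditioning(latents, init_latents, conditioning_mask, t, noise_scale, eps=1e-6):
    # Only hard-conditioned tokens (strength == 1.0) are refreshed from the
    # pristine init_latents plus timestep-scaled noise (schedule assumed).
    noise = torch.randn_like(latents)
    hard = (conditioning_mask > 1.0 - eps).unsqueeze(-1)
    noised_init = init_latents + noise_scale * (t / 1000.0) * noise
    return torch.where(hard, noised_init, latents)

lat = torch.zeros(1, 4, 8)
mask = torch.tensor([[1.0, 1.0, 0.5, 0.0]])  # per-token strengths
out = renoise_hard_conditioning(lat, lat.clone(), mask, t=800.0, noise_scale=0.15)
```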
@@ -1086,16 +1092,18 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
                     )
 
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                conditioning_mask_model_input = (
-                    torch.cat([conditioning_mask, conditioning_mask])
-                    if self.do_classifier_free_guidance
-                    else conditioning_mask
-                )
+                if is_conditioning_image_or_video:
+                    conditioning_mask_model_input = (
+                        torch.cat([conditioning_mask, conditioning_mask])
+                        if self.do_classifier_free_guidance
+                        else conditioning_mask
+                    )
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-                timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+                if is_conditioning_image_or_video:
+                    timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
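The `torch.min` line implements per-token timesteps: a token conditioned with strength `s` is clamped to at most `(1 - s) * 1000`, so fully conditioned tokens sit at `t = 0` and are treated as already denoised. A standalone check:

```python
import torch

timestep = torch.full((1, 6), 800.0)  # current step, broadcast per token
conditioning_mask = torch.tensor([[1.0, 1.0, 0.5, 0.0, 0.0, 0.0]])

timestep = torch.min(timestep, (1 - conditioning_mask) * 1000.0)
print(timestep)  # tensor([[  0.,   0., 500., 800., 800., 800.]])
```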
@@ -1115,8 +1123,11 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin)
                 denoised_latents = self.scheduler.step(
                     -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
                 )[0]
-                tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
-                latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                if is_conditioning_image_or_video:
+                    tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+                    latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                else:
+                    latents = denoised_latents
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
if XLA_AVAILABLE: if XLA_AVAILABLE:
xm.mark_step() xm.mark_step()
latents = latents[:, extra_conditioning_num_latents:] if is_conditioning_image_or_video:
latents = latents[:, extra_conditioning_num_latents:]
latents = self._unpack_latents( latents = self._unpack_latents(
latents, latents,
latent_num_frames, latent_num_frames,