Unverified Commit e8fc8b1f authored by kakukakujirori's avatar kakukakujirori Committed by GitHub
Browse files

Bug fix in LTXImageToVideoPipeline.prepare_latents() when latents is already set (#10918)



* Bug fix in ltx

* Assume packed latents.

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent d6f4774c
...@@ -487,18 +487,20 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo ...@@ -487,18 +487,20 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
) -> torch.Tensor: ) -> torch.Tensor:
height = height // self.vae_spatial_compression_ratio height = height // self.vae_spatial_compression_ratio
width = width // self.vae_spatial_compression_ratio width = width // self.vae_spatial_compression_ratio
num_frames = ( num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
(num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
)
shape = (batch_size, num_channels_latents, num_frames, height, width) shape = (batch_size, num_channels_latents, num_frames, height, width)
mask_shape = (batch_size, 1, num_frames, height, width) mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None: if latents is not None:
conditioning_mask = latents.new_zeros(shape) conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0 conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents( conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)
if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
) )
return latents.to(device=device, dtype=dtype), conditioning_mask return latents.to(device=device, dtype=dtype), conditioning_mask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment