Unverified Commit 2e83cbbb authored by Aryan, committed by GitHub

LTX 0.9.5 (#10968)



* update


---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
parent 33d10af2
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__
## LTXConditionPipeline
[[autodoc]] LTXConditionPipeline
- all
- __call__
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
"last_scale_shift_table": "scale_shift_table",
}
VAE_095_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
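# Illustrative sketch (not part of this PR): rename tables such as VAE_095_RENAME_DICT are
# applied as substring replacements over the original checkpoint keys, so any key containing
# "up_blocks.1" ends up containing "up_blocks.0.upsamplers.0" in the diffusers layout.
def _demo_remap(state_dict, rename_dict):
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rename_dict.items():
            new_key = new_key.replace(old, new)
        remapped[new_key] = value
    return remapped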
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
"model.diffusion_model": remove_keys_,
}
-VAE_091_SPECIAL_KEYS_REMAP = {
-"timestep_scale_multiplier": remove_keys_,
-}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
version: str = "0.9.0",
):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
config = {}
if version == "0.9.5":
config["_use_causal_rope_fix"] = True
with init_empty_weights():
-transformer = LTXVideoTransformer3DModel()
transformer = LTXVideoTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (128, 256, 512, 512), "decoder_block_out_channels": (128, 256, 512, 512),
"layers_per_block": (4, 3, 3, 3, 4), "layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (4, 3, 3, 3, 4), "decoder_layers_per_block": (4, 3, 3, 3, 4),
"spatio_temporal_scaling": (True, True, True, False), "spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True, False), "decoder_spatio_temporal_scaling": (True, True, True, False),
"decoder_inject_noise": (False, False, False, False, False), "decoder_inject_noise": (False, False, False, False, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (False, False, False, False), "upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1), "upsample_factor": (1, 1, 1, 1),
"patch_size": 4, "patch_size": 4,
...@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]: ...@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3, "out_channels": 3,
"latent_channels": 128, "latent_channels": 128,
"block_out_channels": (128, 256, 512, 512), "block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024), "decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 3, 3, 3, 4), "layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (5, 6, 7, 8), "decoder_layers_per_block": (5, 6, 7, 8),
"spatio_temporal_scaling": (True, True, True, False), "spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True), "decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False), "decoder_inject_noise": (True, True, True, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (True, True, True), "upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2), "upsample_factor": (2, 2, 2),
"timestep_conditioning": True, "timestep_conditioning": True,
...@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]: ...@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False, "decoder_causal": False,
} }
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP) elif version == "0.9.5":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 1024, 2048),
"down_block_types": (
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
return config
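# Usage sketch (assumption, mirroring how the script consumes this config): requesting "0.9.5"
# also updates VAE_KEYS_RENAME_DICT with VAE_095_RENAME_DICT as a side effect, and the returned
# dict is unpacked into the diffusers VAE class, e.g.
#   config = get_vae_config(version="0.9.5")
#   vae = AutoencoderKLLTXVideo(**config)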
@@ -223,7 +294,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
-"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
)
return parser.parse_args()
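# Example invocation (sketch; only the flags visible in this hunk are shown, the checkpoint-path
# arguments defined elsewhere in the script are omitted):
#   python scripts/convert_ltx_to_diffusers.py --version 0.9.5 --dtype fp32 --output_path ./ltx-095-diffusers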
@@ -277,14 +348,17 @@ if __name__ == "__main__":
for param in text_encoder.parameters():
param.data = param.data.contiguous()
-scheduler = FlowMatchEulerDiscreteScheduler(
-use_dynamic_shifting=True,
-base_shift=0.95,
-max_shift=2.05,
-base_image_seq_len=1024,
-max_image_seq_len=4096,
-shift_terminal=0.1,
-)
if args.version == "0.9.5":
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTXPipeline(
scheduler=scheduler,
...
@@ -402,6 +402,7 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline", "LTXImageToVideoPipeline",
"LTXPipeline", "LTXPipeline",
"Lumina2Pipeline", "Lumina2Pipeline",
...@@ -947,6 +948,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -947,6 +948,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline, LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL, LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXPipeline,
Lumina2Pipeline,
...
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
return hidden_states
class LTXVideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
padding_mode: str = "zeros",
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
residual = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
residual = residual.unflatten(1, (-1, self.group_size))
residual = residual.mean(dim=2)
hidden_states = self.conv(hidden_states)
hidden_states = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
hidden_states = hidden_states + residual
return hidden_states
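# Shape walk-through (illustrative, not part of this PR): with hypothetical in_channels=64,
# out_channels=128 and stride=(2, 2, 2), group_size = (64 * 8) // 128 = 4 and the conv branch
# produces 128 // 8 = 16 channels; folding each 2x2x2 patch into channels then halves F, H and W.
# The residual branch alone can be reproduced with plain tensor ops:
import torch

x = torch.randn(1, 64, 9, 32, 32)                     # (B, C, F, H, W); F must be 1 mod stride_t
x = torch.cat([x[:, :, :1], x], dim=2)                # causal pad with the first frame -> F = 10
r = x.unflatten(4, (-1, 2)).unflatten(3, (-1, 2)).unflatten(2, (-1, 2))
r = r.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)   # fold 2x2x2 patches into channels -> (1, 512, 5, 16, 16)
r = r.unflatten(1, (-1, 4)).mean(dim=2)               # average channel groups of 4 -> (1, 128, 5, 16, 16)
print(r.shape)                                        # torch.Size([1, 128, 5, 16, 16])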
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
return hidden_states
class LTXVideo095DownBlock3D(nn.Module):
r"""
Down block used in the LTXVideo model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
spatio_temporal_scale (`bool`, defaults to `True`):
Whether or not to use a downsampling layer (of the kind selected by `downsample_type`). If not used, the output dimensions are the same as the input dimensions.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
downsample_type: str = "conv",
):
super().__init__()
out_channels = out_channels or in_channels
resnets = []
for _ in range(num_layers):
resnets.append(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
)
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList()
if downsample_type == "conv":
self.downsamplers.append(
LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=(2, 2, 2),
is_causal=is_causal,
)
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
)
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
)
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
)
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
else:
hidden_states = resnet(hidden_states, temb, generator)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
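# Construction sketch (illustrative; values taken from the 0.9.5 VAE config in the conversion
# script above): the first encoder stage pairs four resnets with a purely spatial downsampler,
# so only height and width are halved.
#   block = LTXVideo095DownBlock3D(
#       in_channels=128,
#       out_channels=256,
#       num_layers=4,
#       spatio_temporal_scale=True,
#       is_causal=True,
#       downsample_type="spatial",
#   )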
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXVideoMidBlock3d(nn.Module):
r"""
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
in_channels: int = 3,
out_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
)
# down blocks
-num_block_out_channels = len(block_out_channels)
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
self.down_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
-output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
-down_block = LTXVideoDownBlock3D(
-in_channels=input_channel,
-out_channels=output_channel,
-num_layers=layers_per_block[i],
-resnet_eps=resnet_norm_eps,
-spatio_temporal_scale=spatio_temporal_scaling[i],
-is_causal=is_causal,
-)
if not is_ltx_095:
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
else:
output_channel = block_out_channels[i + 1]
if down_block_types[i] == "LTXVideoDownBlock3D":
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
elif down_block_types[i] == "LTXVideo095DownBlock3D":
down_block = LTXVideo095DownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
downsample_type=downsample_type[i],
)
else:
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
self.down_blocks.append(down_block)
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
self.timestep_scale_multiplier = None
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if self.timestep_scale_multiplier is not None:
temb = temb * self.timestep_scale_multiplier
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
) -> None:
super().__init__()
@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
down_block_types=down_block_types,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
downsample_type=downsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
-self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
-self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
self.spatial_compression_ratio = (
patch_size * 2 ** sum(spatio_temporal_scaling)
if spatial_compression_ratio is None
else spatial_compression_ratio
)
self.temporal_compression_ratio = (
patch_size_t * 2 ** sum(spatio_temporal_scaling)
if temporal_compression_ratio is None
else temporal_compression_ratio
)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
...
@@ -14,7 +14,7 @@
# limitations under the License.
import math
-from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -113,20 +113,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
self.patch_size_t = patch_size_t
self.theta = theta
-def forward(
-self,
-hidden_states: torch.Tensor,
-num_frames: int,
-height: int,
-width: int,
-rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-batch_size = hidden_states.size(0)
def _prepare_video_coords(
self,
batch_size: int,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Tuple[torch.Tensor, float, float],
device: torch.device,
) -> torch.Tensor:
# Always compute rope in fp32
-grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
-grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
-grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
grid_h = torch.arange(height, dtype=torch.float32, device=device)
grid_w = torch.arange(width, dtype=torch.float32, device=device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
@@ -138,6 +137,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):
grid = grid.flatten(2, 4).transpose(1, 2)
return grid
def forward(
self,
hidden_states: torch.Tensor,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
video_coords: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
if video_coords is None:
grid = self._prepare_video_coords(
batch_size,
num_frames,
height,
width,
rope_interpolation_scale=rope_interpolation_scale,
device=hidden_states.device,
)
else:
grid = torch.stack(
[
video_coords[:, 0] / self.base_num_frames,
video_coords[:, 1] / self.base_height,
video_coords[:, 2] / self.base_width,
],
dim=-1,
)
start = 1.0
end = self.theta
freqs = self.theta ** torch.linspace(
@@ -367,10 +398,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
-num_frames: int,
-height: int,
-width: int,
-rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
video_coords: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
@@ -389,7 +421,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
-image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
...
@@ -264,7 +264,7 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
-_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
@@ -618,7 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
-from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
...
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
...@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import * from ...utils.dummy_torch_and_transformers_objects import *
else: else:
from .pipeline_ltx import LTXPipeline from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
else:
...
@@ -694,9 +694,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
-latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
-1 / latent_frame_rate,
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
...
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import PIL.Image
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from ...models.autoencoders import AutoencoderKLLTXVideo
from ...models.transformers import LTXVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import LTXPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
>>> from diffusers.utils import export_to_video, load_video, load_image
>>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> # Load input image and video
>>> video = load_video(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
... )
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
... )
>>> # Create conditioning objects
>>> condition1 = LTXVideoCondition(
... image=image,
... frame_index=0,
... )
>>> condition2 = LTXVideoCondition(
... video=video,
... frame_index=80,
... )
>>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> # Generate video
>>> generator = torch.Generator("cuda").manual_seed(0)
>>> video = pipe(
... conditions=[condition1, condition2],
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=768,
... height=512,
... num_frames=161,
... num_inference_steps=40,
... generator=generator,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```
"""
@dataclass
class LTXVideoCondition:
"""
Defines a single frame-conditioning item for LTX Video - a single frame or a sequence of frames.
Attributes:
image (`PIL.Image.Image`):
The image to condition the video on.
video (`List[PIL.Image.Image]`):
The video to condition the video on.
frame_index (`int`):
The frame index at which the image or video will conditionally affect the video generation.
strength (`float`, defaults to `1.0`):
The strength of the conditioning effect. A value of `1.0` means the conditioning effect is fully applied.
"""
image: Optional[PIL.Image.Image] = None
video: Optional[List[PIL.Image.Image]] = None
frame_index: int = 0
strength: float = 1.0
# from LTX-Video/ltx_video/schedulers/rf.py
def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None):
if linear_steps is None:
linear_steps = num_steps // 2
if num_steps < 2:
return torch.tensor([1.0])
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
quadratic_steps = num_steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
const = quadratic_coef * (linear_steps**2)
quadratic_sigma_schedule = [
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return torch.tensor(sigma_schedule[:-1])
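# Usage sketch (illustrative, not part of this PR): the schedule is linear up to `threshold_noise`
# and quadratic afterwards, returned as descending sigmas starting at 1.0, e.g.
#   sigmas = linear_quadratic_schedule(num_steps=40, threshold_noise=0.025)
#   # -> tensor of 40 values, decreasing from 1.0 towards 0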
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
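# Worked example (illustrative, not part of this PR): with the defaults, mu interpolates the shift
# linearly over sequence length, so calculate_shift(256) == 0.5 and calculate_shift(4096) == 1.15
# (up to floating point rounding); intermediate lengths fall in between.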
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for image-to-video generation.
Reference: https://github.com/Lightricks/LTX-Video
Args:
transformer ([`LTXVideoTransformer3DModel`]):
Conditional Transformer architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLLTXVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLLTXVideo,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: LTXVideoTransformer3DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.transformer_spatial_patch_size = (
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
)
self.transformer_temporal_patch_size = (
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
)
self.default_height = 512
self.default_width = 704
self.default_frames = 121
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
def check_inputs(
self,
prompt,
conditions,
image,
video,
frame_index,
strength,
height,
width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
if conditions is not None and (image is not None or video is not None):
raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
if conditions is None and (image is None and video is None):
raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
if conditions is None:
if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
raise ValueError(
"If `conditions` is not provided, `image` and `frame_index` must be of the same length."
)
elif isinstance(image, list) and isinstance(strength, list) and len(image) != len(strength):
raise ValueError("If `conditions` is not provided, `image` and `strength` must be of the same length.")
elif isinstance(video, list) and isinstance(frame_index, list) and len(video) != len(frame_index):
raise ValueError(
"If `conditions` is not provided, `video` and `frame_index` must be of the same length."
)
elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
@staticmethod
def _prepare_video_ids(
batch_size: int,
num_frames: int,
height: int,
width: int,
patch_size: int = 1,
patch_size_t: int = 1,
device: torch.device = None,
) -> torch.Tensor:
latent_sample_coords = torch.meshgrid(
torch.arange(0, num_frames, patch_size_t, device=device),
torch.arange(0, height, patch_size, device=device),
torch.arange(0, width, patch_size, device=device),
indexing="ij",
)
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = latent_coords.reshape(batch_size, -1, num_frames * height * width)
return latent_coords
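# Shape note (illustrative, not part of this PR): with patch_size = patch_size_t = 1 the returned
# ids have shape (batch_size, 3, num_frames * height * width), one (frame, row, col) triple per
# latent position; e.g. a 161-frame, 512x704 video gives latent dims F=21, H=16, W=22 -> (B, 3, 7392).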
@staticmethod
def _scale_video_ids(
video_ids: torch.Tensor,
scale_factor: int = 32,
scale_factor_t: int = 8,
frame_index: int = 0,
device: torch.device = None,
) -> torch.Tensor:
scaled_latent_coords = (
video_ids
* torch.tensor([scale_factor_t, scale_factor, scale_factor], device=video_ids.device)[None, :, None]
)
scaled_latent_coords[:, 0] = (scaled_latent_coords[:, 0] + 1 - scale_factor_t).clamp(min=0)
scaled_latent_coords[:, 0] += frame_index
return scaled_latent_coords
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
batch_size, num_channels, num_frames, height, width = latents.shape
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size
latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
def _unpack_latents(
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
# what happens in the `_pack_latents` method.
batch_size = latents.size(0)
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
def trim_conditioning_sequence(self, start_frame: int, sequence_num_frames: int, target_num_frames: int):
"""
Trim a conditioning sequence to the allowed number of frames.
Args:
start_frame (int): The target frame number of the first frame in the sequence.
sequence_num_frames (int): The number of frames in the sequence.
target_num_frames (int): The target number of frames in the generated video.
Returns:
int: updated sequence length
"""
scale_factor = self.vae_temporal_compression_ratio
num_frames = min(sequence_num_frames, target_num_frames - start_frame)
# Trim down to a multiple of temporal_scale_factor frames plus 1
num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
return num_frames
@staticmethod
def add_noise_to_image_conditioning_latents(
t: float,
init_latents: torch.Tensor,
latents: torch.Tensor,
noise_scale: float,
conditioning_mask: torch.Tensor,
generator,
eps=1e-6,
):
"""
Add timestep-dependent noise to the hard-conditioning latents. This helps with motion continuity, especially
when conditioned on a single frame.
"""
noise = randn_tensor(
latents.shape,
generator=generator,
device=latents.device,
dtype=latents.dtype,
)
# Add noise only to hard-conditioning latents (conditioning_mask = 1.0)
need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1)
noised_latents = init_latents + noise_scale * noise * (t**2)
latents = torch.where(need_to_noise, noised_latents, latents)
return latents
def prepare_latents(
self,
conditions: List[torch.Tensor],
condition_strength: List[float],
condition_frame_index: List[int],
batch_size: int = 1,
num_channels_latents: int = 128,
height: int = 512,
width: int = 704,
num_frames: int = 161,
num_prefix_latent_frames: int = 2,
generator: Optional[torch.Generator] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
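# Returns the packed noise latents (with any extra conditioning tokens prepended), the per-token conditioning mask,
# the per-token (frame, height, width) coordinates, and the number of extra conditioning tokens.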
num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
extra_conditioning_latents = []
extra_conditioning_video_ids = []
extra_conditioning_mask = []
extra_conditioning_num_latents = 0
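# Encode each conditioning image/video, blend its latents into the sampled noise at the requested frame index,
# and collect any prefix latents that must be passed to the transformer as extra (out-of-grid) tokens.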
for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
condition_latents = self._normalize_latents(
condition_latents, self.vae.latents_mean, self.vae.latents_std
).to(device, dtype=dtype)
num_data_frames = data.size(2)
num_cond_frames = condition_latents.size(2)
if frame_index == 0:
latents[:, :, :num_cond_frames] = torch.lerp(
latents[:, :, :num_cond_frames], condition_latents, strength
)
condition_latent_frames_mask[:, :num_cond_frames] = strength
else:
if num_data_frames > 1:
if num_cond_frames < num_prefix_latent_frames:
raise ValueError(
f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
)
if num_cond_frames > num_prefix_latent_frames:
start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
latents[:, :, start_frame:end_frame] = torch.lerp(
latents[:, :, start_frame:end_frame],
condition_latents[:, :, num_prefix_latent_frames:],
strength,
)
condition_latent_frames_mask[:, start_frame:end_frame] = strength
condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
condition_latents = torch.lerp(noise, condition_latents, strength)
condition_video_ids = self._prepare_video_ids(
batch_size,
condition_latents.size(2),
latent_height,
latent_width,
patch_size=self.transformer_spatial_patch_size,
patch_size_t=self.transformer_temporal_patch_size,
device=device,
)
condition_video_ids = self._scale_video_ids(
condition_video_ids,
scale_factor=self.vae_spatial_compression_ratio,
scale_factor_t=self.vae_temporal_compression_ratio,
frame_index=frame_index,
device=device,
)
condition_latents = self._pack_latents(
condition_latents,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
condition_conditioning_mask = torch.full(
condition_latents.shape[:2], strength, device=device, dtype=dtype
)
extra_conditioning_latents.append(condition_latents)
extra_conditioning_video_ids.append(condition_video_ids)
extra_conditioning_mask.append(condition_conditioning_mask)
extra_conditioning_num_latents += condition_latents.size(1)
video_ids = self._prepare_video_ids(
batch_size,
num_latent_frames,
latent_height,
latent_width,
patch_size_t=self.transformer_temporal_patch_size,
patch_size=self.transformer_spatial_patch_size,
device=device,
)
conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
video_ids = self._scale_video_ids(
video_ids,
scale_factor=self.vae_spatial_compression_ratio,
scale_factor_t=self.vae_temporal_compression_ratio,
frame_index=0,
device=device,
)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
if len(extra_conditioning_latents) > 0:
latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
conditions: Union[LTXVideoCondition, List[LTXVideoCondition]] = None,
image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
video: List[PipelineImageInput] = None,
frame_index: Union[int, List[int]] = 0,
strength: Union[float, List[float]] = 1.0,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 704,
num_frames: int = 161,
frame_rate: int = 25,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
image_cond_noise_scale: float = 0.15,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
conditions (`List[LTXVideoCondition], *optional*`):
The list of frame-conditioning items for the video generation. If not provided, conditions will be
created using `image`, `video`, `frame_index` and `strength`.
image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
The image or images to condition the video generation. If not provided, one has to pass `video` or
`conditions`.
video (`List[PipelineImageInput]`, *optional*):
The video to condition the video generation. If not provided, one has to pass `image` or `conditions`.
frame_index (`int` or `List[int]`, *optional*):
The frame index or frame indices at which the image or video will conditionally affect the video
generation. If not provided, one has to pass `conditions`.
strength (`float` or `List[float]`, *optional*):
The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, defaults to `512`):
The height in pixels of the generated video. This is set to 512 by default for the best results.
width (`int`, defaults to `704`):
The width in pixels of the generated video. This is set to 704 by default for the best results.
num_frames (`int`, defaults to `161`):
The number of video frames to generate.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `3`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated from
the `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `256`):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated videos.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if latents is not None:
raise ValueError("Passing latents is not yet supported.")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
conditions=conditions,
image=image,
video=video,
frame_index=frame_index,
strength=strength,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if conditions is not None:
if not isinstance(conditions, list):
conditions = [conditions]
strength = [condition.strength for condition in conditions]
frame_index = [condition.frame_index for condition in conditions]
image = [condition.image for condition in conditions]
video = [condition.video for condition in conditions]
else:
if not isinstance(image, list):
image = [image]
num_conditions = 1
elif isinstance(image, list):
num_conditions = len(image)
if not isinstance(video, list):
video = [video]
num_conditions = 1
elif isinstance(video, list):
num_conditions = len(video)
if not isinstance(frame_index, list):
frame_index = [frame_index] * num_conditions
if not isinstance(strength, list):
strength = [strength] * num_conditions
device = self._execution_device
# 3. Prepare text embeddings
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
vae_dtype = self.vae.dtype
conditioning_tensors = []
for condition_image, condition_video, condition_frame_index, condition_strength in zip(
image, video, frame_index, strength
):
if condition_image is not None:
condition_tensor = (
self.video_processor.preprocess(condition_image, height, width)
.unsqueeze(2)
.to(device, dtype=vae_dtype)
)
elif condition_video is not None:
condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
num_frames_input = condition_tensor.size(2)
num_frames_output = self.trim_conditioning_sequence(
condition_frame_index, num_frames_input, num_frames
)
condition_tensor = condition_tensor[:, :, :num_frames_output]
condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
else:
raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
raise ValueError(
f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
f"but got {condition_tensor.size(2)} frames."
)
conditioning_tensors.append(condition_tensor)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
conditioning_tensors,
strength,
frame_index,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames=num_frames,
generator=generator,
device=device,
dtype=torch.float32,
)
video_coords = video_coords.float()
video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
init_latents = latents.clone()
if self.do_classifier_free_guidance:
video_coords = torch.cat([video_coords, video_coords], dim=0)
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
sigmas = linear_quadratic_schedule(num_inference_steps)
timesteps = sigmas * 1000
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps=timesteps,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if image_cond_noise_scale > 0:
# Add timestep-dependent noise to the hard-conditioning latents
# This helps with motion continuity, especially when conditioned on a single frame
latents = self.add_noise_to_image_conditioning_latents(
t / 1000.0,
init_latents,
latents,
image_cond_noise_scale,
conditioning_mask,
generator,
)
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
conditioning_mask_model_input = (
torch.cat([conditioning_mask, conditioning_mask])
if self.do_classifier_free_guidance
else conditioning_mask
)
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
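# Clamp the per-token timestep with the conditioning mask: hard-conditioned tokens (mask == 1) get timestep 0
# and are treated as clean, while softer-conditioned tokens are capped at (1 - strength) * 1000.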
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
video_coords=video_coords,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
timestep, _ = timestep.chunk(2)
denoised_latents = self.scheduler.step(
-noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
)[0]
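# Only overwrite tokens whose conditioning allows denoising at this noise level: a token is updated once
# t / 1000 falls below (1 - strength), so hard-conditioned tokens (strength 1) keep their latents.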
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
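# Drop the extra conditioning tokens that were prepended in `prepare_latents` before unpacking to video shape.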
latents = latents[:, extra_conditioning_num_latents:]
latents = self._unpack_latents(
latents,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
if output_type == "latent":
video = latents
else:
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)
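For orientation, a minimal usage sketch of the conditioning pipeline defined above; the checkpoint id, prompt, and output path are illustrative assumptions rather than values taken from this diff:

import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_image

# Hypothetical checkpoint id; substitute whichever LTX 0.9.5 weights you actually use.
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Condition the first frame of the generated clip on a single image.
image = load_image("input_frame.png")
condition = LTXVideoCondition(image=image, frame_index=0)

video = pipe(
    conditions=[condition],
    prompt="A hot air balloon drifting over snowy mountains at sunrise",
    negative_prompt="worst quality, blurry, jittery",
    width=704,
    height=512,
    num_frames=161,  # must be of the form 8 * k + 1
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=24)

Conditioning items can also be built from short clips via `LTXVideoCondition(video=..., frame_index=...)`.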
...
@@ -764,9 +764,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
         self._num_timesteps = len(timesteps)

         # 6. Prepare micro-conditions
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
         rope_interpolation_scale = (
-            1 / latent_frame_rate,
+            self.vae_temporal_compression_ratio / frame_rate,
             self.vae_spatial_compression_ratio,
             self.vae_spatial_compression_ratio,
         )
...
...
@@ -377,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         s_tmax: float = float("inf"),
         s_noise: float = 1.0,
         generator: Optional[torch.Generator] = None,
+        per_token_timesteps: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
         """
@@ -397,6 +398,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 Scaling factor for noise added to the sample.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
+            per_token_timesteps (`torch.Tensor`, *optional*):
+                The timesteps for each token in the sample.
             return_dict (`bool`):
                 Whether or not to return a
                 [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
@@ -427,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # Upcast to avoid precision issues when computing prev_sample
         sample = sample.to(torch.float32)

-        sigma = self.sigmas[self.step_index]
-        sigma_next = self.sigmas[self.step_index + 1]
-
-        prev_sample = sample + (sigma_next - sigma) * model_output
-
-        # Cast sample back to model compatible dtype
-        prev_sample = prev_sample.to(model_output.dtype)
+        if per_token_timesteps is not None:
+            per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
+            sigmas = self.sigmas[:, None, None]
+            lower_mask = sigmas < per_token_sigmas[None] - 1e-6
+            lower_sigmas = lower_mask * sigmas
+            lower_sigmas, _ = lower_sigmas.max(dim=0)
+            dt = (per_token_sigmas - lower_sigmas)[..., None]
+        else:
+            sigma = self.sigmas[self.step_index]
+            sigma_next = self.sigmas[self.step_index + 1]
+            dt = sigma_next - sigma
+
+        prev_sample = sample + dt * model_output

         # upon completion increase step index by one
         self._step_index += 1

+        if per_token_timesteps is None:
+            # Cast sample back to model compatible dtype
+            prev_sample = prev_sample.to(model_output.dtype)
+
         if not return_dict:
             return (prev_sample,)
...
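For intuition, here is a small self-contained sketch (not part of the diff) of the per-token step size introduced above: each token steps from its own sigma down to the largest schedule sigma strictly below it, so tokens that are already clean take a zero-size step.

import torch

sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])   # descending noise schedule
per_token_sigmas = torch.tensor([[1.0, 0.6, 0.0]])    # [batch, tokens]

lower_mask = sigmas[:, None, None] < per_token_sigmas[None] - 1e-6
lower_sigmas = (lower_mask * sigmas[:, None, None]).max(dim=0).values
dt = per_token_sigmas - lower_sigmas                  # approx. tensor([[0.25, 0.10, 0.00]])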
...
@@ -1217,6 +1217,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])


+class LTXConditionPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class LTXImageToVideoPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
...
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLLTXVideo,
FlowMatchEulerDiscreteScheduler,
LTXConditionPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTXConditionPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
out_channels=8,
patch_size=1,
patch_size_t=1,
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
num_layers=1,
caption_channels=32,
)
torch.manual_seed(0)
vae = AutoencoderKLLTXVideo(
in_channels=3,
out_channels=3,
latent_channels=8,
block_out_channels=(8, 8, 8, 8),
decoder_block_out_channels=(8, 8, 8, 8),
layers_per_block=(1, 1, 1, 1, 1),
decoder_layers_per_block=(1, 1, 1, 1, 1),
spatio_temporal_scaling=(True, True, False, False),
decoder_spatio_temporal_scaling=(True, True, False, False),
decoder_inject_noise=(False, False, False, False, False),
upsample_residual=(False, False, False, False),
upsample_factor=(1, 1, 1, 1),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0, use_conditions=False):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
if use_conditions:
conditions = LTXVideoCondition(
image=image,
)
else:
conditions = None
inputs = {
"conditions": conditions,
"image": None if use_conditions else image,
"prompt": "dance monkey",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"height": 32,
"width": 32,
# 8 * k + 1 is the recommendation
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs2 = self.get_dummy_inputs(device, use_conditions=True)
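# Passing the image directly and wrapping it in an LTXVideoCondition should produce (nearly) identical videos.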
video = pipe(**inputs).frames
generated_video = video[0]
video2 = pipe(**inputs2).frames
generated_video2 = video2[0]
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
max_diff = np.abs(generated_video - generated_video2).max()
self.assertLessEqual(max_diff, 1e-3)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)