Unverified commit 2e83cbbb, authored by Aryan and committed by GitHub

LTX 0.9.5 (#10968)



* update


---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
parent 33d10af2
diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md
@@ -196,6 +196,12 @@ export_to_video(video, "ship.mp4", fps=24)
  - all
  - __call__
## LTXConditionPipeline
[[autodoc]] LTXConditionPipeline
- all
- __call__
## LTXPipelineOutput

[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
    "last_scale_shift_table": "scale_shift_table",
}

VAE_095_RENAME_DICT = {
    # decoder
    "up_blocks.0": "mid_block",
    "up_blocks.1": "up_blocks.0.upsamplers.0",
    "up_blocks.2": "up_blocks.0",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
    "up_blocks.4": "up_blocks.1",
    "up_blocks.5": "up_blocks.2.upsamplers.0",
    "up_blocks.6": "up_blocks.2",
    "up_blocks.7": "up_blocks.3.upsamplers.0",
    "up_blocks.8": "up_blocks.3",
    # encoder
    "down_blocks.0": "down_blocks.0",
    "down_blocks.1": "down_blocks.0.downsamplers.0",
    "down_blocks.2": "down_blocks.1",
    "down_blocks.3": "down_blocks.1.downsamplers.0",
    "down_blocks.4": "down_blocks.2",
    "down_blocks.5": "down_blocks.2.downsamplers.0",
    "down_blocks.6": "down_blocks.3",
    "down_blocks.7": "down_blocks.3.downsamplers.0",
    "down_blocks.8": "mid_block",
    # common
    "last_time_embedder": "time_embedder",
    "last_scale_shift_table": "scale_shift_table",
}
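For orientation, this is roughly how the conversion script consumes a rename table like `VAE_095_RENAME_DICT`: every checkpoint key is passed through each `(old, new)` substring replacement in dict-insertion order (the ordering above is deliberate, so already-renamed prefixes are not matched a second time), and the key is swapped in place. A minimal sketch, assuming a plain `dict` state dict; the loop mirrors the script's `update_state_dict_inplace` helper rather than reproducing it verbatim:

from typing import Any, Dict

def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    # Move the tensor to its new key without copying it.
    state_dict[new_key] = state_dict.pop(old_key)

def rename_keys(state_dict: Dict[str, Any], rename_dict: Dict[str, str]) -> Dict[str, Any]:
    for key in list(state_dict.keys()):
        new_key = key
        for old, new in rename_dict.items():  # insertion order matters here
            new_key = new_key.replace(old, new)
        if new_key != key:
            update_state_dict_inplace(state_dict, key, new_key)
    return state_dict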
VAE_SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.channel": remove_keys_,
    "per_channel_statistics.mean-of-means": remove_keys_,
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
    "model.diffusion_model": remove_keys_,
}

-VAE_091_SPECIAL_KEYS_REMAP = {
-    "timestep_scale_multiplier": remove_keys_,
-}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
    ckpt_path: str,
    dtype: torch.dtype,
    version: str = "0.9.0",
):
    PREFIX_KEY = "model.diffusion_model."

    original_state_dict = get_state_dict(load_file(ckpt_path))
    config = {}
    if version == "0.9.5":
        config["_use_causal_rope_fix"] = True
    with init_empty_weights():
-        transformer = LTXVideoTransformer3DModel()
+        transformer = LTXVideoTransformer3DModel(**config)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 512),
            "down_block_types": (
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
            ),
            "decoder_block_out_channels": (128, 256, 512, 512),
            "layers_per_block": (4, 3, 3, 3, 4),
            "decoder_layers_per_block": (4, 3, 3, 3, 4),
            "spatio_temporal_scaling": (True, True, True, False),
            "decoder_spatio_temporal_scaling": (True, True, True, False),
            "decoder_inject_noise": (False, False, False, False, False),
            "downsample_type": ("conv", "conv", "conv", "conv"),
            "upsample_residual": (False, False, False, False),
            "upsample_factor": (1, 1, 1, 1),
            "patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 512),
            "down_block_types": (
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
                "LTXVideoDownBlock3D",
            ),
            "decoder_block_out_channels": (256, 512, 1024),
            "layers_per_block": (4, 3, 3, 3, 4),
            "decoder_layers_per_block": (5, 6, 7, 8),
            "spatio_temporal_scaling": (True, True, True, False),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (True, True, True, False),
            "downsample_type": ("conv", "conv", "conv", "conv"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
            "decoder_causal": False,
        }
        VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
-        VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
+    elif version == "0.9.5":
        config = {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 128,
            "block_out_channels": (128, 256, 512, 1024, 2048),
            "down_block_types": (
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
                "LTXVideo095DownBlock3D",
            ),
            "decoder_block_out_channels": (256, 512, 1024),
            "layers_per_block": (4, 6, 6, 2, 2),
            "decoder_layers_per_block": (5, 5, 5, 5),
            "spatio_temporal_scaling": (True, True, True, True),
            "decoder_spatio_temporal_scaling": (True, True, True),
            "decoder_inject_noise": (False, False, False, False),
            "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
            "upsample_residual": (True, True, True),
            "upsample_factor": (2, 2, 2),
            "timestep_conditioning": True,
            "patch_size": 4,
            "patch_size_t": 1,
            "resnet_norm_eps": 1e-6,
            "scaling_factor": 1.0,
            "encoder_causal": True,
            "decoder_causal": False,
            "spatial_compression_ratio": 32,
            "temporal_compression_ratio": 8,
        }
        VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
    return config
@@ -223,7 +294,7 @@ def get_args():
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
    parser.add_argument(
-        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
+        "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
    )
    return parser.parse_args()
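With the new choice in place, a 0.9.5 conversion can be kicked off roughly as below. Only `--version`, `--dtype`, and `--output_path` appear in this diff; the checkpoint-path flags and the safetensors filename are assumptions for illustration:

# Hypothetical invocation; adjust the checkpoint flags/filename to the actual script and release.
python scripts/convert_ltx_to_diffusers.py \
    --transformer_ckpt_path ./ltx-video-0.9.5.safetensors \
    --vae_ckpt_path ./ltx-video-0.9.5.safetensors \
    --version 0.9.5 \
    --dtype fp32 \
    --output_path ./ltx-0.9.5-diffusers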
@@ -277,14 +348,17 @@ if __name__ == "__main__":
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

-    scheduler = FlowMatchEulerDiscreteScheduler(
-        use_dynamic_shifting=True,
-        base_shift=0.95,
-        max_shift=2.05,
-        base_image_seq_len=1024,
-        max_image_seq_len=4096,
-        shift_terminal=0.1,
-    )
+    if args.version == "0.9.5":
+        scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
+    else:
+        scheduler = FlowMatchEulerDiscreteScheduler(
+            use_dynamic_shifting=True,
+            base_shift=0.95,
+            max_shift=2.05,
+            base_image_seq_len=1024,
+            max_image_seq_len=4096,
+            shift_terminal=0.1,
+        )
    pipe = LTXPipeline(
        scheduler=scheduler,

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
@@ -402,6 +402,7 @@ else:
        "LDMTextToImagePipeline",
        "LEditsPPPipelineStableDiffusion",
        "LEditsPPPipelineStableDiffusionXL",
        "LTXConditionPipeline",
        "LTXImageToVideoPipeline",
        "LTXPipeline",
        "Lumina2Pipeline",
@@ -947,6 +948,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LDMTextToImagePipeline,
        LEditsPPPipelineStableDiffusion,
        LEditsPPPipelineStableDiffusionXL,
        LTXConditionPipeline,
        LTXImageToVideoPipeline,
        LTXPipeline,
        Lumina2Pipeline,

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
        return hidden_states

class LTXVideoDownsampler3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: Union[int, Tuple[int, int, int]] = 1,
        is_causal: bool = True,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()

        self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
        self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels

        out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])

        self.conv = LTXVideoCausalConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            is_causal=is_causal,
            padding_mode=padding_mode,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)

        residual = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        residual = residual.unflatten(1, (-1, self.group_size))
        residual = residual.mean(dim=2)

        hidden_states = self.conv(hidden_states)
        hidden_states = (
            hidden_states.unflatten(4, (-1, self.stride[2]))
            .unflatten(3, (-1, self.stride[1]))
            .unflatten(2, (-1, self.stride[0]))
        )
        hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
        hidden_states = hidden_states + residual

        return hidden_states
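In short, `LTXVideoDownsampler3d` is a pixel-unshuffle-style reduction: pad `stride[0] - 1` frames causally, run a stride-1 causal 3x3x3 conv that emits `out_channels / prod(stride)` channels, fold each `stride`-sized neighborhood into the channel axis, and add a parameter-free residual formed by mean-pooling channel groups of size `group_size`. A quick shape check, with toy sizes chosen as an assumption (frame counts must satisfy `(frames + stride[0] - 1) % stride[0] == 0`, which the usual 8k+1 LTX counts do):

import torch

down = LTXVideoDownsampler3d(in_channels=128, out_channels=256, stride=(2, 2, 2))
x = torch.randn(1, 128, 9, 64, 64)  # (batch, channels, frames, height, width)
y = down(x)
# the conv emits 256 // 8 = 32 channels; repacking each 2x2x2 neighborhood restores 32 * 8 = 256
print(y.shape)  # torch.Size([1, 256, 5, 32, 32])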
class LTXVideoUpsampler3d(nn.Module):
    def __init__(
        self,
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
        is_causal: bool = True,
        residual: bool = False,
        upscale_factor: int = 1,
        padding_mode: str = "zeros",
    ) -> None:
        super().__init__()
...@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module): ...@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
kernel_size=3, kernel_size=3,
stride=1, stride=1,
is_causal=is_causal, is_causal=is_causal,
padding_mode=padding_mode,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module): ...@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
return hidden_states return hidden_states
class LTXVideo095DownBlock3D(nn.Module):
    r"""
    Down block used in the LTXVideo 0.9.5 model.

    Args:
        in_channels (`int`):
            Number of input channels.
        out_channels (`int`, *optional*):
            Number of output channels. If None, defaults to `in_channels`.
        num_layers (`int`, defaults to `1`):
            Number of resnet layers.
        dropout (`float`, defaults to `0.0`):
            Dropout rate.
        resnet_eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        resnet_act_fn (`str`, defaults to `"swish"`):
            Activation function to use.
        spatio_temporal_scale (`bool`, defaults to `True`):
            Whether or not to use a downsampling layer. If not used, the output dimensions match the input
            dimensions.
        is_causal (`bool`, defaults to `True`):
            Whether this layer behaves causally (future frames depend only on past frames) or not.
        downsample_type (`str`, defaults to `"conv"`):
            One of `"conv"`, `"spatial"`, `"temporal"` or `"spatiotemporal"`, selecting which axes the downsampler
            halves.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        spatio_temporal_scale: bool = True,
        is_causal: bool = True,
        downsample_type: str = "conv",
    ):
        super().__init__()

        out_channels = out_channels or in_channels

        resnets = []
        for _ in range(num_layers):
            resnets.append(
                LTXVideoResnetBlock3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    dropout=dropout,
                    eps=resnet_eps,
                    non_linearity=resnet_act_fn,
                    is_causal=is_causal,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        self.downsamplers = None
        if spatio_temporal_scale:
            self.downsamplers = nn.ModuleList()

            if downsample_type == "conv":
                self.downsamplers.append(
                    LTXVideoCausalConv3d(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        kernel_size=3,
                        stride=(2, 2, 2),
                        is_causal=is_causal,
                    )
                )
            elif downsample_type == "spatial":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
                    )
                )
            elif downsample_type == "temporal":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
                    )
                )
            elif downsample_type == "spatiotemporal":
                self.downsamplers.append(
                    LTXVideoDownsampler3d(
                        in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
                    )
                )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        r"""Forward method of the `LTXVideo095DownBlock3D` class."""

        for i, resnet in enumerate(self.resnets):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
            else:
                hidden_states = resnet(hidden_states, temb, generator)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

        return hidden_states
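The `downsample_type` strings wire up the downsamplers as follows; this is only a summary of the branches above, not additional API:

# "conv"           -> LTXVideoCausalConv3d, stride (2, 2, 2), no residual path
# "spatial"        -> LTXVideoDownsampler3d, stride (1, 2, 2): halves height/width only
# "temporal"       -> LTXVideoDownsampler3d, stride (2, 1, 1): halves frame count only
# "spatiotemporal" -> LTXVideoDownsampler3d, stride (2, 2, 2): halves all three axes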
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXVideoMidBlock3d(nn.Module):
    r"""
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
        in_channels: int = 3,
        out_channels: int = 128,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        down_block_types: Tuple[str, ...] = (
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
        ),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
        patch_size: int = 4,
        patch_size_t: int = 1,
        resnet_norm_eps: float = 1e-6,
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
        )

        # down blocks
-        num_block_out_channels = len(block_out_channels)
+        is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
+        num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
        self.down_blocks = nn.ModuleList([])
        for i in range(num_block_out_channels):
            input_channel = output_channel
-            output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
-
-            down_block = LTXVideoDownBlock3D(
-                in_channels=input_channel,
-                out_channels=output_channel,
-                num_layers=layers_per_block[i],
-                resnet_eps=resnet_norm_eps,
-                spatio_temporal_scale=spatio_temporal_scaling[i],
-                is_causal=is_causal,
-            )
+            if not is_ltx_095:
+                output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
+            else:
+                output_channel = block_out_channels[i + 1]
+
+            if down_block_types[i] == "LTXVideoDownBlock3D":
+                down_block = LTXVideoDownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=resnet_norm_eps,
+                    spatio_temporal_scale=spatio_temporal_scaling[i],
+                    is_causal=is_causal,
+                )
+            elif down_block_types[i] == "LTXVideo095DownBlock3D":
+                down_block = LTXVideo095DownBlock3D(
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    num_layers=layers_per_block[i],
+                    resnet_eps=resnet_norm_eps,
+                    spatio_temporal_scale=spatio_temporal_scaling[i],
+                    is_causal=is_causal,
+                    downsample_type=downsample_type[i],
+                )
+            else:
+                raise ValueError(f"Unknown down block type: {down_block_types[i]}")

            self.down_blocks.append(down_block)
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
        # timestep embedding
        self.time_embedder = None
        self.scale_shift_table = None
        self.timestep_scale_multiplier = None
        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
            self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
            self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        hidden_states = self.conv_in(hidden_states)

        if self.timestep_scale_multiplier is not None:
            temb = temb * self.timestep_scale_multiplier

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        out_channels: int = 3,
        latent_channels: int = 128,
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        down_block_types: Tuple[str, ...] = (
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
            "LTXVideoDownBlock3D",
        ),
        decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
        spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
        decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
        downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
        upsample_residual: Tuple[bool, ...] = (False, False, False, False),
        upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
        timestep_conditioning: bool = False,
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        scaling_factor: float = 1.0,
        encoder_causal: bool = True,
        decoder_causal: bool = False,
        spatial_compression_ratio: int = None,
        temporal_compression_ratio: int = None,
    ) -> None:
        super().__init__()
@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=block_out_channels,
            down_block_types=down_block_types,
            spatio_temporal_scaling=spatio_temporal_scaling,
            layers_per_block=layers_per_block,
            downsample_type=downsample_type,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            resnet_norm_eps=resnet_norm_eps,
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        self.register_buffer("latents_mean", latents_mean, persistent=True)
        self.register_buffer("latents_std", latents_std, persistent=True)

-        self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
-        self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
+        self.spatial_compression_ratio = (
+            patch_size * 2 ** sum(spatio_temporal_scaling)
+            if spatial_compression_ratio is None
+            else spatial_compression_ratio
+        )
+        self.temporal_compression_ratio = (
+            patch_size_t * 2 ** sum(spatio_temporal_scaling)
+            if temporal_compression_ratio is None
+            else temporal_compression_ratio
+        )

        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
        # to perform decoding of a single video latent at a time.
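The explicit override matters for 0.9.5: its encoder scales in all four stages, so the fallback formula would overestimate both ratios, since the new "spatial" and "temporal" downsamplers each halve only some axes. A worked check against the 0.9.5 config above:

patch_size, patch_size_t = 4, 1
spatio_temporal_scaling = (True, True, True, True)

patch_size * 2 ** sum(spatio_temporal_scaling)    # 64, but the true spatial ratio is 32
patch_size_t * 2 ** sum(spatio_temporal_scaling)  # 16, but the true temporal ratio is 8

# downsample_type ("spatial", "temporal", "spatiotemporal", "spatiotemporal") halves
# height/width in 3 of 4 stages (4 * 2**3 = 32) and time in 3 of 4 stages (1 * 2**3 = 8),
# hence the explicit spatial_compression_ratio=32 and temporal_compression_ratio=8.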
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
@@ -14,7 +14,7 @@
# limitations under the License.

import math
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -113,20 +113,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
        self.patch_size_t = patch_size_t
        self.theta = theta

-    def forward(
+    def _prepare_video_coords(
        self,
-        hidden_states: torch.Tensor,
+        batch_size: int,
        num_frames: int,
        height: int,
        width: int,
-        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size = hidden_states.size(0)
+        rope_interpolation_scale: Tuple[torch.Tensor, float, float],
+        device: torch.device,
+    ) -> torch.Tensor:
        # Always compute rope in fp32
-        grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
-        grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
-        grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
+        grid_h = torch.arange(height, dtype=torch.float32, device=device)
+        grid_w = torch.arange(width, dtype=torch.float32, device=device)
+        grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
        grid = torch.stack(grid, dim=0)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
@@ -138,6 +137,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):
        grid = grid.flatten(2, 4).transpose(1, 2)

        return grid
    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
        video_coords: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = hidden_states.size(0)

        if video_coords is None:
            grid = self._prepare_video_coords(
                batch_size,
                num_frames,
                height,
                width,
                rope_interpolation_scale=rope_interpolation_scale,
                device=hidden_states.device,
            )
        else:
            grid = torch.stack(
                [
                    video_coords[:, 0] / self.base_num_frames,
                    video_coords[:, 1] / self.base_height,
                    video_coords[:, 2] / self.base_width,
                ],
                dim=-1,
            )
        start = 1.0
        end = self.theta
        freqs = self.theta ** torch.linspace(
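The new `video_coords` argument lets a caller (the condition pipeline, in this PR) hand the RoPE layer explicit per-token (frame, row, column) positions instead of describing a dense `num_frames x height x width` grid. A minimal sketch of building such a tensor for a dense grid, with toy sizes as an assumption:

import torch

num_frames, height, width = 3, 4, 4
grid = torch.meshgrid(
    torch.arange(num_frames, dtype=torch.float32),
    torch.arange(height, dtype=torch.float32),
    torch.arange(width, dtype=torch.float32),
    indexing="ij",
)
video_coords = torch.stack(grid, dim=0).flatten(1).unsqueeze(0)  # (1, 3, 48): frame/row/col per token
# forward() then divides the three rows by base_num_frames, base_height and base_width
# before building the rotary frequencies.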
@@ -367,10 +398,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
-        num_frames: int,
-        height: int,
-        width: int,
-        rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
+        num_frames: Optional[int] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
+        video_coords: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> torch.Tensor:
@@ -389,7 +421,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:

diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
@@ -264,7 +264,7 @@ else:
        ]
    )
    _import_structure["latte"] = ["LattePipeline"]
-    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
+    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
    _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
    _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
    _import_structure["marigold"].extend(
@@ -618,7 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LEditsPPPipelineStableDiffusion,
        LEditsPPPipelineStableDiffusionXL,
    )
-    from .ltx import LTXImageToVideoPipeline, LTXPipeline
+    from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
    from .lumina import LuminaPipeline, LuminaText2ImgPipeline
    from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
    from .marigold import (

diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_ltx"] = ["LTXPipeline"]
    _import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
    _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_ltx import LTXPipeline
        from .pipeline_ltx_condition import LTXConditionPipeline
        from .pipeline_ltx_image2video import LTXImageToVideoPipeline
else:

diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -694,9 +694,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
        self._num_timesteps = len(timesteps)

        # 6. Prepare micro-conditions
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
        rope_interpolation_scale = (
-            1 / latent_frame_rate,
+            self.vae_temporal_compression_ratio / frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )
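Note this change (here and in the image-to-video pipeline below) is an algebraic no-op that just drops the intermediate variable:

# 1 / latent_frame_rate == 1 / (frame_rate / ratio) == ratio / frame_rate
# e.g. frame_rate = 25, vae_temporal_compression_ratio = 8: both forms give 8 / 25 = 0.32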
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
(new file; this diff is collapsed in the original view and not reproduced here)

diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -764,9 +764,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
        self._num_timesteps = len(timesteps)

        # 6. Prepare micro-conditions
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
        rope_interpolation_scale = (
-            1 / latent_frame_rate,
+            self.vae_temporal_compression_ratio / frame_rate,
            self.vae_spatial_compression_ratio,
            self.vae_spatial_compression_ratio,
        )

diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -377,6 +377,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
@@ -397,6 +398,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
@@ -427,16 +430,26 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

-        sigma = self.sigmas[self.step_index]
-        sigma_next = self.sigmas[self.step_index + 1]
-        prev_sample = sample + (sigma_next - sigma) * model_output
+        if per_token_timesteps is not None:
+            per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps
+            sigmas = self.sigmas[:, None, None]
+            lower_mask = sigmas < per_token_sigmas[None] - 1e-6
+            lower_sigmas = lower_mask * sigmas
+            lower_sigmas, _ = lower_sigmas.max(dim=0)
+            dt = (per_token_sigmas - lower_sigmas)[..., None]
+        else:
+            sigma = self.sigmas[self.step_index]
+            sigma_next = self.sigmas[self.step_index + 1]
+            dt = sigma_next - sigma
+
+        prev_sample = sample + dt * model_output

-        # Cast sample back to model compatible dtype
-        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

+        if per_token_timesteps is None:
+            # Cast sample back to model compatible dtype
+            prev_sample = prev_sample.to(model_output.dtype)

        if not return_dict:
            return (prev_sample,)
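The per-token branch effectively gives every token its own Euler step: for each token's sigma it finds the largest grid sigma strictly below it and steps down to that value. A numeric sketch with a toy sigma grid (values are illustrative only):

import torch

sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])   # scheduler grid, descending
per_token_sigmas = torch.tensor([[0.9, 0.6, 0.25]])  # (batch, tokens)

lower_mask = sigmas[:, None, None] < per_token_sigmas[None] - 1e-6
lower_sigmas = (lower_mask * sigmas[:, None, None]).max(dim=0).values
dt = per_token_sigmas - lower_sigmas
print(lower_sigmas)  # tensor([[0.7500, 0.5000, 0.0000]])
print(dt)            # tensor([[0.1500, 0.1000, 0.2500]])  -> per-token step sizes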
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1217,6 +1217,21 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])

class LTXConditionPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
class LTXImageToVideoPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
    AutoencoderKLLTXVideo,
    FlowMatchEulerDiscreteScheduler,
    LTXConditionPipeline,
    LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class LTXConditionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = LTXConditionPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = LTXVideoTransformer3DModel(
            in_channels=8,
            out_channels=8,
            patch_size=1,
            patch_size_t=1,
            num_attention_heads=4,
            attention_head_dim=8,
            cross_attention_dim=32,
            num_layers=1,
            caption_channels=32,
        )

        torch.manual_seed(0)
        vae = AutoencoderKLLTXVideo(
            in_channels=3,
            out_channels=3,
            latent_channels=8,
            block_out_channels=(8, 8, 8, 8),
            decoder_block_out_channels=(8, 8, 8, 8),
            layers_per_block=(1, 1, 1, 1, 1),
            decoder_layers_per_block=(1, 1, 1, 1, 1),
            spatio_temporal_scaling=(True, True, False, False),
            decoder_spatio_temporal_scaling=(True, True, False, False),
            decoder_inject_noise=(False, False, False, False, False),
            upsample_residual=(False, False, False, False),
            upsample_factor=(1, 1, 1, 1),
            timestep_conditioning=False,
            patch_size=1,
            patch_size_t=1,
            encoder_causal=True,
            decoder_causal=False,
        )
        vae.use_framewise_encoding = False
        vae.use_framewise_decoding = False

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components
    def get_dummy_inputs(self, device, seed=0, use_conditions=False):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
        if use_conditions:
            conditions = LTXVideoCondition(
                image=image,
            )
        else:
            conditions = None

        inputs = {
            "conditions": conditions,
            "image": None if use_conditions else image,
            "prompt": "dance monkey",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "height": 32,
            "width": 32,
            # 8 * k + 1 is the recommendation
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs
    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs2 = self.get_dummy_inputs(device, use_conditions=True)
        video = pipe(**inputs).frames
        generated_video = video[0]
        video2 = pipe(**inputs2).frames
        generated_video2 = video2[0]

        self.assertEqual(generated_video.shape, (9, 3, 32, 32))
        max_diff = np.abs(generated_video - generated_video2).max()
        self.assertLessEqual(max_diff, 1e-3)
    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs
            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs
            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10
    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )
    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )