Unverified Commit b8093e66 authored by hlky, committed by GitHub

Fix LTX 0.9.5 single file (#11271)

parent e121d0ef
@@ -177,6 +177,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
     "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
     "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+    "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
     "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
@@ -638,7 +639,9 @@ def infer_diffusers_model_type(checkpoint):
             model_type = "flux-schnell"
 
     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
-        if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+        if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+            model_type = "ltx-video-0.9.5"
+        elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
             model_type = "ltx-video-0.9.1"
         else:
             model_type = "ltx-video"
@@ -2403,13 +2406,41 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
         "last_scale_shift_table": "scale_shift_table",
     }
 
+    VAE_095_RENAME_DICT = {
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0.upsamplers.0",
+        "up_blocks.2": "up_blocks.0",
+        "up_blocks.3": "up_blocks.1.upsamplers.0",
+        "up_blocks.4": "up_blocks.1",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.upsamplers.0",
+        "up_blocks.8": "up_blocks.3",
+        # encoder
+        "down_blocks.0": "down_blocks.0",
+        "down_blocks.1": "down_blocks.0.downsamplers.0",
+        "down_blocks.2": "down_blocks.1",
+        "down_blocks.3": "down_blocks.1.downsamplers.0",
+        "down_blocks.4": "down_blocks.2",
+        "down_blocks.5": "down_blocks.2.downsamplers.0",
+        "down_blocks.6": "down_blocks.3",
+        "down_blocks.7": "down_blocks.3.downsamplers.0",
+        "down_blocks.8": "mid_block",
+        # common
+        "last_time_embedder": "time_embedder",
+        "last_scale_shift_table": "scale_shift_table",
+    }
+
     VAE_SPECIAL_KEYS_REMAP = {
         "per_channel_statistics.channel": remove_keys_,
         "per_channel_statistics.mean-of-means": remove_keys_,
         "per_channel_statistics.mean-of-stds": remove_keys_,
     }
 
-    if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+    if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+        VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
+    elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
         VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
 
     for key in list(converted_state_dict.keys()):
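With the new entry in DIFFUSERS_DEFAULT_PIPELINE_PATHS and the detection above, a 0.9.5 single-file checkpoint should resolve to the Lightricks/LTX-Video-0.9.5 config automatically. A rough usage sketch, assuming pipeline-level single-file loading and a locally available checkpoint (the file path and prompt are placeholders):

import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

# Placeholder path to an LTX-Video 0.9.5 single-file checkpoint.
pipe = LTXPipeline.from_single_file(
    "ltx-video-0.9.5.safetensors",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Placeholder prompt; other generation arguments are left at the pipeline defaults.
frames = pipe(prompt="A sailboat gliding across a calm lake at sunset").frames[0]
export_to_video(frames, "output.mp4", fps=24)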