Commit 8bf5c2e1 authored by wangshankun

LoRA format conversion

parent 398b598a
### Use this script to convert ViGen-DiT LoRA format to Lightx2v.
###
### Cmd line: python convert_vigen_to_x2v_lora.py model_lora.pt model_lora_converted.safetensors
###
### ViGen-DiT project URL: https://github.com/yl-1993/ViGen-DiT
###
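### Illustrative key mapping (assumed example keys, derived from the mapping tables
### below; actual checkpoint keys may also carry a ".default" suffix, which the
### script strips before matching):
###   blocks.0.attn1.to_q.lora_A.weight  ->  diffusion_model.blocks.0.self_attn.q.lora_down.weight
###   blocks.0.ffn.net.2.lora_B.weight   ->  diffusion_model.blocks.0.ffn.2.lora_up.weight
###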
import torch
from safetensors.torch import save_file
import sys
import os
if len(sys.argv) != 3:
    print("Usage: python convert_vigen_to_x2v_lora.py <input.pt> <output.safetensors>")
    sys.exit(1)
ckpt_path = sys.argv[1]
output_path = sys.argv[2]
if not os.path.exists(ckpt_path):
    print(f"❌ Input file does not exist: {ckpt_path}")
    sys.exit(1)
state_dict = torch.load(ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
elif "model" in state_dict:
state_dict = state_dict["model"]
mapped_dict = {}
# Mapping table definitions
attn_map = {
    "attn1": "self_attn",
    "attn2": "cross_attn",
}
proj_map = {
    "to_q": "q",
    "to_k": "k",
    "to_v": "v",
    "to_out": "o",
}
lora_map = {
    "lora_A": "lora_down",
    "lora_B": "lora_up",
}
for k, v in state_dict.items():
    # Preprocessing: normalize to_out.0 / to_out.1 to to_out
    k = k.replace("to_out.0", "to_out").replace("to_out.1", "to_out")
    k = k.replace(".default", "")  # strip the ".default" suffix
    parts = k.split(".")
    # === Attention Blocks ===
    if k.startswith("blocks.") and len(parts) >= 5:
        block_id = parts[1]
        if parts[2].startswith("attn"):
            attn_raw = parts[2]
            proj_raw = parts[3]
            lora_raw = parts[4]
            if attn_raw in attn_map and proj_raw in proj_map and lora_raw in lora_map:
                attn_name = attn_map[attn_raw]
                proj_name = proj_map[proj_raw]
                lora_name = lora_map[lora_raw]
                new_k = f"diffusion_model.blocks.{block_id}.{attn_name}.{proj_name}.{lora_name}.weight"
                mapped_dict[new_k] = v
                continue
            else:
                print(f"Cannot map attention key: {k}")
                continue
        # === FFN Blocks ===
        elif parts[2] == "ffn":
            if parts[3:6] == ["net", "0", "proj"]:
                layer_id = "0"
                lora_raw = parts[6]
            elif parts[3:5] == ["net", "2"]:
                layer_id = "2"
                lora_raw = parts[5]
            else:
                print(f"Cannot parse FFN key: {k}")
                continue
            if lora_raw not in lora_map:
                print(f"Unknown FFN LoRA type: {k}")
                continue
            lora_name = lora_map[lora_raw]
            new_k = f"diffusion_model.blocks.{block_id}.ffn.{layer_id}.{lora_name}.weight"
            mapped_dict[new_k] = v
            continue
    # === Text Embedding ===
    elif k.startswith("condition_embedder.text_embedder.linear_"):
        layer_id = parts[2].split("_")[1]
        lora_raw = parts[3]
        if lora_raw in lora_map:
            lora_name = lora_map[lora_raw]
            new_k = f"diffusion_model.text_embedding.{layer_id}.{lora_name}.weight"
            mapped_dict[new_k] = v
            continue
        else:
            print(f"Unknown text_embedder LoRA type: {k}")
            continue
    # === Time Embedding === (disabled)
    # elif k.startswith("condition_embedder.time_embedder.linear_"):
    #     layer_id = parts[2].split("_")[1]
    #     lora_raw = parts[3]
    #     if lora_raw in lora_map:
    #         lora_name = lora_map[lora_raw]
    #         new_k = f"diffusion_model.time_embedding.{layer_id}.{lora_name}.weight"
    #         mapped_dict[new_k] = v
    #         continue
    #     else:
    #         print(f"Unknown time_embedder LoRA type: {k}")
    #         continue
    # === Time Projection === (disabled)
    # elif k.startswith("condition_embedder.time_proj."):
    #     lora_raw = parts[2]
    #     if lora_raw in lora_map:
    #         lora_name = lora_map[lora_raw]
    #         new_k = f"diffusion_model.time_projection.1.{lora_name}.weight"
    #         mapped_dict[new_k] = v
    #         continue
    #     else:
    #         print(f"Unknown time_proj LoRA type: {k}")
    #         continue
    # fallback
    print(f"Unrecognized key structure: {k}")

# Save result
print(f"\n✅ Successfully renamed {len(mapped_dict)} LoRA parameters")
save_file(mapped_dict, output_path)
print(f"💾 Saved to: {output_path}")