Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
8bf5c2e1
Commit
8bf5c2e1
authored
Jul 09, 2025
by
wangshankun
Browse files
Lora format convert
parent
398b598a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
138 additions
and
0 deletions
+138
-0
tools/extract/convert_vigen_to_x2v_lora.py
tools/extract/convert_vigen_to_x2v_lora.py
+138
-0
No files found.
tools/extract/convert_vigen_to_x2v_lora.py
0 → 100644
View file @
8bf5c2e1
### Using this script to convert ViGen-DiT Lora Format to Lightx2v
###
### Cmd line:python convert_vigen_to_x2v_lora.py model_lora.pt model_lora_converted.safetensors
###
### ViGen-DiT Project Url: https://github.com/yl-1993/ViGen-DiT
###
import
torch
from
safetensors.torch
import
save_file
import
sys
import
os
if
len
(
sys
.
argv
)
!=
3
:
print
(
"用法: python convert_lora.py <输入文件.pt> <输出文件.safetensors>"
)
sys
.
exit
(
1
)
ckpt_path
=
sys
.
argv
[
1
]
output_path
=
sys
.
argv
[
2
]
if
not
os
.
path
.
exists
(
ckpt_path
):
print
(
f
"❌ 输入文件不存在:
{
ckpt_path
}
"
)
sys
.
exit
(
1
)
state_dict
=
torch
.
load
(
ckpt_path
,
map_location
=
"cpu"
)
if
"state_dict"
in
state_dict
:
state_dict
=
state_dict
[
"state_dict"
]
elif
"model"
in
state_dict
:
state_dict
=
state_dict
[
"model"
]
mapped_dict
=
{}
# 映射表定义
attn_map
=
{
"attn1"
:
"self_attn"
,
"attn2"
:
"cross_attn"
,
}
proj_map
=
{
"to_q"
:
"q"
,
"to_k"
:
"k"
,
"to_v"
:
"v"
,
"to_out"
:
"o"
,
}
lora_map
=
{
"lora_A"
:
"lora_down"
,
"lora_B"
:
"lora_up"
,
}
for
k
,
v
in
state_dict
.
items
():
# 预处理:将 to_out.0 / to_out.1 统一替换为 to_out
k
=
k
.
replace
(
"to_out.0"
,
"to_out"
).
replace
(
"to_out.1"
,
"to_out"
)
k
=
k
.
replace
(
".default"
,
""
)
# 去除.default
parts
=
k
.
split
(
"."
)
# === Attention Blocks ===
if
k
.
startswith
(
"blocks."
)
and
len
(
parts
)
>=
5
:
block_id
=
parts
[
1
]
if
parts
[
2
].
startswith
(
"attn"
):
attn_raw
=
parts
[
2
]
proj_raw
=
parts
[
3
]
lora_raw
=
parts
[
4
]
if
attn_raw
in
attn_map
and
proj_raw
in
proj_map
and
lora_raw
in
lora_map
:
attn_name
=
attn_map
[
attn_raw
]
proj_name
=
proj_map
[
proj_raw
]
lora_name
=
lora_map
[
lora_raw
]
new_k
=
f
"diffusion_model.blocks.
{
block_id
}
.
{
attn_name
}
.
{
proj_name
}
.
{
lora_name
}
.weight"
mapped_dict
[
new_k
]
=
v
continue
else
:
print
(
f
"无法映射 attention key:
{
k
}
"
)
continue
# === FFN Blocks ===
elif
parts
[
2
]
==
"ffn"
:
if
parts
[
3
:
6
]
==
[
"net"
,
"0"
,
"proj"
]:
layer_id
=
"0"
lora_raw
=
parts
[
6
]
elif
parts
[
3
:
5
]
==
[
"net"
,
"2"
]:
layer_id
=
"2"
lora_raw
=
parts
[
5
]
else
:
print
(
f
"无法解析 FFN key:
{
k
}
"
)
continue
if
lora_raw
not
in
lora_map
:
print
(
f
"未知 FFN LoRA 类型:
{
k
}
"
)
continue
lora_name
=
lora_map
[
lora_raw
]
new_k
=
f
"diffusion_model.blocks.
{
block_id
}
.ffn.
{
layer_id
}
.
{
lora_name
}
.weight"
mapped_dict
[
new_k
]
=
v
continue
# === Text Embedding ===
elif
k
.
startswith
(
"condition_embedder.text_embedder.linear_"
):
layer_id
=
parts
[
2
].
split
(
"_"
)[
1
]
lora_raw
=
parts
[
3
]
if
lora_raw
in
lora_map
:
lora_name
=
lora_map
[
lora_raw
]
new_k
=
f
"diffusion_model.text_embedding.
{
layer_id
}
.
{
lora_name
}
.weight"
mapped_dict
[
new_k
]
=
v
continue
else
:
print
(
f
"text_embedder 未知 LoRA 类型:
{
k
}
"
)
continue
'''
# === Time Embedding ===
elif k.startswith("condition_embedder.time_embedder.linear_"):
layer_id = parts[2].split("_")[1]
lora_raw = parts[3]
if lora_raw in lora_map:
lora_name = lora_map[lora_raw]
new_k = f"diffusion_model.time_embedding.{layer_id}.{lora_name}.weight"
mapped_dict[new_k] = v
continue
else:
print(f"time_embedder 未知 LoRA 类型: {k}")
continue
# === Time Projection ===
elif k.startswith("condition_embedder.time_proj."):
lora_raw = parts[2]
if lora_raw in lora_map:
lora_name = lora_map[lora_raw]
new_k = f"diffusion_model.time_projection.1.{lora_name}.weight"
mapped_dict[new_k] = v
continue
else:
print(f"time_proj 未知 LoRA 类型: {k}")
continue
'''
# fallback
print
(
f
"未识别结构 key:
{
k
}
"
)
# 保存
print
(
f
"
\n
✅ 成功重命名
{
len
(
mapped_dict
)
}
个 LoRA 参数"
)
save_file
(
mapped_dict
,
output_path
)
print
(
f
"💾 已保存为:
{
output_path
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment