Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
c7eb4631
Commit
c7eb4631
authored
Aug 07, 2025
by
wangshankun
Browse files
Bug Fix: Fix incomplete parallel loading of audio model
parent
dd958c79
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
92 additions
and
15 deletions
+92
-15
configs/audio_driven/wan_i2v_audio.json
configs/audio_driven/wan_i2v_audio.json
+2
-1
configs/audio_driven/wan_i2v_audio_dist.json
configs/audio_driven/wan_i2v_audio_dist.json
+23
-0
lightx2v/models/networks/wan/audio_adapter.py
lightx2v/models/networks/wan/audio_adapter.py
+31
-8
lightx2v/models/networks/wan/audio_model.py
lightx2v/models/networks/wan/audio_model.py
+0
-2
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+1
-0
lightx2v/models/networks/wan/infer/utils.py
lightx2v/models/networks/wan/infer/utils.py
+3
-3
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+2
-1
scripts/wan/run_wan_i2v_audio_dist.sh
scripts/wan/run_wan_i2v_audio_dist.sh
+30
-0
No files found.
configs/audio_driven/wan_i2v_audio.json
View file @
c7eb4631
{
"infer_steps"
:
4
,
"target_fps"
:
16
,
"video_duration"
:
1
6
,
"video_duration"
:
1
2
,
"audio_sr"
:
16000
,
"target_video_length"
:
81
,
"target_height"
:
720
,
...
...
@@ -14,5 +14,6 @@
"sample_shift"
:
5
,
"enable_cfg"
:
false
,
"cpu_offload"
:
false
,
"adaptive_resize"
:
true
,
"use_31_block"
:
false
}
configs/audio_driven/wan_i2v_audio_dist.json
0 → 100644
View file @
c7eb4631
{
"infer_steps"
:
4
,
"target_fps"
:
16
,
"video_duration"
:
12
,
"audio_sr"
:
16000
,
"target_video_length"
:
81
,
"target_height"
:
720
,
"target_width"
:
1280
,
"self_attn_1_type"
:
"flash_attn3"
,
"cross_attn_1_type"
:
"flash_attn3"
,
"cross_attn_2_type"
:
"flash_attn3"
,
"seed"
:
42
,
"sample_guide_scale"
:
1.0
,
"sample_shift"
:
5
,
"enable_cfg"
:
false
,
"cpu_offload"
:
false
,
"use_31_block"
:
false
,
"adaptive_resize"
:
true
,
"parallel"
:
{
"seq_p_size"
:
4
,
"seq_p_attn_type"
:
"ulysses"
}
}
lightx2v/models/networks/wan/audio_adapter.py
View file @
c7eb4631
...
...
@@ -7,14 +7,12 @@ import os
import
safetensors
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
diffusers.models.embeddings
import
TimestepEmbedding
,
Timesteps
from
einops
import
rearrange
from
transformers
import
AutoModel
import
torch.distributed
as
dist
from
loguru
import
logger
from
lightx2v.utils.envs
import
*
...
...
@@ -54,15 +52,40 @@ def load_pt_safetensors(in_path: str):
return
state_dict
def
rank0_load_state_dict_from_path
(
model
,
in_path
:
str
,
strict
:
bool
=
True
):
import
torch.distributed
as
dist
def
rank0_load_state_dict_from_path
(
model
,
in_path
:
str
,
strict
:
bool
=
True
,
seq_p_group
=
None
):
model
=
model
.
to
(
"cuda"
)
# 确定当前进程是否是(负责加载权重)
is_leader
=
False
if
seq_p_group
is
not
None
and
dist
.
is_initialized
():
group_rank
=
dist
.
get_rank
(
group
=
seq_p_group
)
if
group_rank
==
0
:
is_leader
=
True
elif
not
dist
.
is_initialized
()
or
dist
.
get_rank
()
==
0
:
is_leader
=
True
if
(
dist
.
is_initialized
()
and
dist
.
get_rank
()
==
0
)
or
(
not
dist
.
is_initialized
())
:
if
is_leader
:
state_dict
=
load_pt_safetensors
(
in_path
)
model
.
load_state_dict
(
state_dict
,
strict
=
strict
)
if
dist
.
is_initialized
():
# 将模型状态从领导者同步到组内所有其他进程
if
seq_p_group
is
not
None
and
dist
.
is_initialized
():
dist
.
barrier
(
group
=
seq_p_group
)
src_global_rank
=
dist
.
get_process_group_ranks
(
seq_p_group
)[
0
]
for
param
in
model
.
parameters
():
dist
.
broadcast
(
param
.
data
,
src
=
src_global_rank
,
group
=
seq_p_group
)
for
buffer
in
model
.
buffers
():
dist
.
broadcast
(
buffer
.
data
,
src
=
src_global_rank
,
group
=
seq_p_group
)
elif
dist
.
is_initialized
():
dist
.
barrier
()
return
model
.
to
(
dtype
=
GET_DTYPE
(),
device
=
"cuda"
)
for
param
in
model
.
parameters
():
dist
.
broadcast
(
param
.
data
,
src
=
0
)
for
buffer
in
model
.
buffers
():
dist
.
broadcast
(
buffer
.
data
,
src
=
0
)
return
model
.
to
(
dtype
=
GET_DTYPE
())
def
linear_interpolation
(
features
,
output_len
:
int
):
...
...
lightx2v/models/networks/wan/audio_model.py
View file @
c7eb4631
...
...
@@ -13,8 +13,6 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights
,
)
from
loguru
import
logger
class
WanAudioModel
(
WanModel
):
pre_weight_class
=
WanPreWeights
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
c7eb4631
...
...
@@ -11,6 +11,7 @@ from lightx2v.utils.envs import *
from
.utils
import
apply_rotary_emb
,
apply_rotary_emb_chunk
,
compute_freqs
,
compute_freqs_audio
class
WanTransformerInfer
(
BaseTransformerInfer
):
def
__init__
(
self
,
config
):
self
.
config
=
config
...
...
lightx2v/models/networks/wan/infer/utils.py
View file @
c7eb4631
...
...
@@ -27,9 +27,9 @@ def compute_freqs_audio(c, grid_sizes, freqs):
seq_len
=
f
*
h
*
w
freqs_i
=
torch
.
cat
(
[
freqs
[
0
][:
f
].
view
(
f
,
1
,
1
,
-
1
).
expand
(
f
,
h
,
w
,
-
1
),
freqs
[
1
][:
h
].
view
(
1
,
h
,
1
,
-
1
).
expand
(
f
,
h
,
w
,
-
1
),
freqs
[
2
][:
w
].
view
(
1
,
1
,
w
,
-
1
).
expand
(
f
,
h
,
w
,
-
1
),
freqs
[
0
][:
f
].
view
(
f
,
1
,
1
,
-
1
).
expand
(
f
,
h
,
w
,
-
1
),
# 时间(帧)编码
freqs
[
1
][:
h
].
view
(
1
,
h
,
1
,
-
1
).
expand
(
f
,
h
,
w
,
-
1
),
# 空间(高度)编码
freqs
[
2
][:
w
].
view
(
1
,
1
,
w
,
-
1
).
expand
(
f
,
h
,
w
,
-
1
),
# 空间(宽度)编码
],
dim
=-
1
,
).
reshape
(
seq_len
,
1
,
-
1
)
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
c7eb4631
...
...
@@ -417,7 +417,6 @@ class WanAudioRunner(WanRunner): # type:ignore
time_freq_dim
=
256
,
projection_transformer_layers
=
4
,
)
audio_adapter
=
rank0_load_state_dict_from_path
(
audio_adapter
,
audio_adapter_path
,
strict
=
False
)
# Audio encoder
cpu_offload
=
self
.
config
.
get
(
"cpu_offload"
,
False
)
...
...
@@ -432,6 +431,8 @@ class WanAudioRunner(WanRunner): # type:ignore
else
:
seq_p_group
=
None
audio_adapter
=
rank0_load_state_dict_from_path
(
audio_adapter
,
audio_adapter_path
,
strict
=
False
,
seq_p_group
=
seq_p_group
)
self
.
_audio_adapter_pipe
=
AudioAdapterPipe
(
audio_adapter
,
audio_encoder_repo
=
audio_encoder_repo
,
dtype
=
GET_DTYPE
(),
device
=
device
,
weight
=
1.0
,
cpu_offload
=
cpu_offload
,
seq_p_group
=
seq_p_group
)
...
...
scripts/wan/run_wan_i2v_audio_dist.sh
0 → 100755
View file @
c7eb4631
#!/bin/bash
# set path and first
lightx2v_path
=
model_path
=
# set environment variables
source
${
lightx2v_path
}
/scripts/base/base.sh
export
CUDA_VISIBLE_DEVICES
=
0,1,2,3
export
TORCH_CUDA_ARCH_LIST
=
"9.0"
export
PYTORCH_CUDA_ALLOC_CONF
=
expandable_segments:True
export
ENABLE_GRAPH_MODE
=
false
#for debugging
#export TORCH_NCCL_BLOCKING_WAIT=1 #启用 NCCL 阻塞等待模式(否则 watchdog 会杀死卡顿的进程)
#export NCCL_BLOCKING_WAIT_TIMEOUT=1800 #设置 watchdog 的等待超时
torchrun
--nproc-per-node
4
-m
lightx2v.infer
\
--model_cls
wan2.1_audio
\
--task
i2v
\
--model_path
$model_path
\
--config_json
${
lightx2v_path
}
/configs/audio_driven/wan_i2v_audio_dist.json
\
--prompt
"The video features a old lady is saying something and knitting a sweater."
\
--negative_prompt
色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走
\
--image_path
${
lightx2v_path
}
/assets/inputs/audio/15.png
\
--audio_path
${
lightx2v_path
}
/assets/inputs/audio/15.wav
\
--save_video_path
${
lightx2v_path
}
/save_results/output_lightx2v_wan_i2v_audio.mp4
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment