Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
daa06243
Commit
daa06243
authored
Aug 09, 2025
by
wangshankun
Browse files
BugFix: 1. cfg并行和模型加载group冲突; 2. offload和广播功能冲突; 3. save_video并行中多次save
parent
64948a2e
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
88 additions
and
37 deletions
+88
-37
configs/audio_driven/wan_i2v_audio_dist_offload.json
configs/audio_driven/wan_i2v_audio_dist_offload.json
+28
-0
configs/audio_driven/wan_i2v_audio_offload.json
configs/audio_driven/wan_i2v_audio_offload.json
+24
-0
lightx2v/models/input_encoders/hf/t5/model.py
lightx2v/models/input_encoders/hf/t5/model.py
+1
-1
lightx2v/models/input_encoders/hf/xlm_roberta/model.py
lightx2v/models/input_encoders/hf/xlm_roberta/model.py
+1
-2
lightx2v/models/networks/wan/audio_adapter.py
lightx2v/models/networks/wan/audio_adapter.py
+9
-12
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+12
-9
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+5
-3
lightx2v/models/video_encoders/hf/wan/vae.py
lightx2v/models/video_encoders/hf/wan/vae.py
+1
-3
lightx2v/utils/utils.py
lightx2v/utils/utils.py
+7
-7
No files found.
configs/audio_driven/wan_i2v_audio_dist_offload.json
0 → 100644
View file @
daa06243
{
"infer_steps"
:
4
,
"target_fps"
:
16
,
"video_duration"
:
5
,
"audio_sr"
:
16000
,
"target_video_length"
:
81
,
"target_height"
:
720
,
"target_width"
:
1280
,
"self_attn_1_type"
:
"flash_attn3"
,
"cross_attn_1_type"
:
"flash_attn3"
,
"cross_attn_2_type"
:
"flash_attn3"
,
"seed"
:
42
,
"sample_guide_scale"
:
1.0
,
"sample_shift"
:
5
,
"enable_cfg"
:
false
,
"use_31_block"
:
false
,
"adaptive_resize"
:
true
,
"parallel"
:
{
"seq_p_size"
:
4
,
"seq_p_attn_type"
:
"ulysses"
},
"cpu_offload"
:
true
,
"offload_granularity"
:
"block"
,
"t5_cpu_offload"
:
true
,
"offload_ratio_val"
:
1
,
"t5_offload_granularity"
:
"block"
,
"use_tiling_vae"
:
true
}
configs/audio_driven/wan_i2v_audio_offload.json
0 → 100644
View file @
daa06243
{
"infer_steps"
:
4
,
"target_fps"
:
16
,
"video_duration"
:
5
,
"audio_sr"
:
16000
,
"target_video_length"
:
81
,
"target_height"
:
720
,
"target_width"
:
1280
,
"self_attn_1_type"
:
"flash_attn3"
,
"cross_attn_1_type"
:
"flash_attn3"
,
"cross_attn_2_type"
:
"flash_attn3"
,
"seed"
:
42
,
"sample_guide_scale"
:
1
,
"sample_shift"
:
5
,
"enable_cfg"
:
false
,
"adaptive_resize"
:
true
,
"use_31_block"
:
false
,
"cpu_offload"
:
true
,
"offload_granularity"
:
"block"
,
"t5_cpu_offload"
:
true
,
"offload_ratio_val"
:
1
,
"t5_offload_granularity"
:
"block"
,
"use_tiling_vae"
:
true
}
lightx2v/models/input_encoders/hf/t5/model.py
View file @
daa06243
...
...
@@ -571,7 +571,7 @@ class T5EncoderModel:
.
requires_grad_
(
False
)
)
weights_ditc
=
load_weights_distributed
(
self
.
checkpoint_path
,
seq_p_group
)
weights_ditc
=
load_weights_distributed
(
self
.
checkpoint_path
)
model
.
load_state_dict
(
weights_ditc
)
self
.
model
=
model
...
...
lightx2v/models/input_encoders/hf/xlm_roberta/model.py
View file @
daa06243
...
...
@@ -434,8 +434,7 @@ class CLIPModel:
pretrained
=
False
,
return_transforms
=
True
,
return_tokenizer
=
False
,
dtype
=
dtype
,
device
=
device
,
quantized
=
self
.
quantized
,
quant_scheme
=
quant_scheme
)
self
.
model
=
self
.
model
.
eval
().
requires_grad_
(
False
)
weight_dict
=
load_weights_distributed
(
self
.
checkpoint_path
,
seq_p_group
=
self
.
seq_p_group
)
# weight_dict = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True)
weight_dict
=
load_weights_distributed
(
self
.
checkpoint_path
)
keys
=
list
(
weight_dict
.
keys
())
for
key
in
keys
:
...
...
lightx2v/models/networks/wan/audio_adapter.py
View file @
daa06243
...
...
@@ -53,13 +53,13 @@ def load_pt_safetensors(in_path: str):
return
state_dict
def
rank0_load_state_dict_from_path
(
model
,
in_path
:
str
,
strict
:
bool
=
True
,
seq_p_group
=
None
):
def
rank0_load_state_dict_from_path
(
model
,
in_path
:
str
,
strict
:
bool
=
True
):
model
=
model
.
to
(
"cuda"
)
# 确定当前进程是否是领导者(负责加载权重)
is_leader
=
False
if
seq_p_group
is
not
None
and
dist
.
is_initialized
():
group
_rank
=
dist
.
get_rank
(
group
=
seq_p_group
)
if
group
_rank
==
0
:
if
dist
.
is_initialized
():
current
_rank
=
dist
.
get_rank
()
if
current
_rank
==
0
:
is_leader
=
True
elif
not
dist
.
is_initialized
()
or
dist
.
get_rank
()
==
0
:
is_leader
=
True
...
...
@@ -70,16 +70,13 @@ def rank0_load_state_dict_from_path(model, in_path: str, strict: bool = True, se
model
.
load_state_dict
(
state_dict
,
strict
=
strict
)
# 将模型状态从领导者同步到组内所有其他进程
if
seq_p_group
is
not
None
and
dist
.
is_initialized
():
dist
.
barrier
(
group
=
seq_p_group
,
device_ids
=
[
torch
.
cuda
.
current_device
()])
src_global_rank
=
dist
.
get_process_group_ranks
(
seq_p_group
)[
0
]
if
dist
.
is_initialized
():
dist
.
barrier
(
device_ids
=
[
torch
.
cuda
.
current_device
()])
src_global_rank
=
0
for
param
in
model
.
parameters
():
dist
.
broadcast
(
param
.
data
,
src
=
src_global_rank
,
group
=
seq_p_group
)
dist
.
broadcast
(
param
.
data
,
src
=
src_global_rank
)
for
buffer
in
model
.
buffers
():
dist
.
broadcast
(
buffer
.
data
,
src
=
src_global_rank
,
group
=
seq_p_group
)
dist
.
broadcast
(
buffer
.
data
,
src
=
src_global_rank
)
elif
dist
.
is_initialized
():
dist
.
barrier
(
device_ids
=
[
torch
.
cuda
.
current_device
()])
for
param
in
model
.
parameters
():
...
...
lightx2v/models/networks/wan/model.py
View file @
daa06243
...
...
@@ -191,11 +191,11 @@ class WanModel:
if
weight_dict
is
None
:
is_weight_loader
=
False
if
self
.
seq_p_group
is
None
:
if
self
.
config
.
get
(
"device_mesh"
)
is
None
:
is_weight_loader
=
True
logger
.
info
(
f
"Loading original dit model from
{
self
.
model_path
}
"
)
elif
dist
.
is_initialized
():
if
dist
.
get_rank
(
group
=
self
.
seq_p_group
)
==
0
:
if
dist
.
get_rank
()
==
0
:
is_weight_loader
=
True
logger
.
info
(
f
"Loading original dit model from
{
self
.
model_path
}
"
)
...
...
@@ -209,13 +209,13 @@ class WanModel:
else
:
cpu_weight_dict
=
self
.
_load_quant_split_ckpt
(
unified_dtype
,
sensitive_layer
)
if
self
.
seq_p_group
is
None
:
# 单卡模式
if
self
.
config
.
get
(
"device_mesh"
)
is
None
:
# 单卡模式
self
.
original_weight_dict
=
{}
init_device
=
"cpu"
if
self
.
cpu_offload
else
"cuda"
for
key
,
tensor
in
cpu_weight_dict
.
items
():
self
.
original_weight_dict
[
key
]
=
tensor
.
to
(
"cuda"
,
non_blocking
=
True
)
self
.
original_weight_dict
[
key
]
=
tensor
.
to
(
init_device
,
non_blocking
=
True
)
else
:
seq_p_group
=
self
.
seq_p_group
global_src_rank
=
dist
.
get_process_group_ranks
(
seq_p_group
)[
0
]
global_src_rank
=
0
meta_dict
=
{}
if
is_weight_loader
:
...
...
@@ -223,20 +223,20 @@ class WanModel:
meta_dict
[
key
]
=
{
"shape"
:
tensor
.
shape
,
"dtype"
:
tensor
.
dtype
}
obj_list
=
[
meta_dict
]
if
is_weight_loader
else
[
None
]
dist
.
broadcast_object_list
(
obj_list
,
src
=
global_src_rank
,
group
=
seq_p_group
)
dist
.
broadcast_object_list
(
obj_list
,
src
=
global_src_rank
)
synced_meta_dict
=
obj_list
[
0
]
self
.
original_weight_dict
=
{}
for
key
,
meta
in
synced_meta_dict
.
items
():
self
.
original_weight_dict
[
key
]
=
torch
.
empty
(
meta
[
"shape"
],
dtype
=
meta
[
"dtype"
],
device
=
"cuda"
)
dist
.
barrier
(
group
=
seq_p_group
,
device_ids
=
[
torch
.
cuda
.
current_device
()])
dist
.
barrier
(
device_ids
=
[
torch
.
cuda
.
current_device
()])
for
key
in
sorted
(
synced_meta_dict
.
keys
()):
tensor_to_broadcast
=
self
.
original_weight_dict
[
key
]
if
is_weight_loader
:
tensor_to_broadcast
.
copy_
(
cpu_weight_dict
[
key
],
non_blocking
=
True
)
dist
.
broadcast
(
tensor_to_broadcast
,
src
=
global_src_rank
,
group
=
seq_p_group
)
dist
.
broadcast
(
tensor_to_broadcast
,
src
=
global_src_rank
)
if
is_weight_loader
:
del
cpu_weight_dict
...
...
@@ -252,6 +252,9 @@ class WanModel:
self
.
post_weight
.
load
(
self
.
original_weight_dict
)
self
.
transformer_weights
.
load
(
self
.
original_weight_dict
)
del
self
.
original_weight_dict
torch
.
cuda
.
empty_cache
()
def
_init_infer
(
self
):
self
.
pre_infer
=
self
.
pre_infer_class
(
self
.
config
)
self
.
post_infer
=
self
.
post_infer_class
(
self
.
config
)
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
daa06243
...
...
@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
import
torchaudio
as
ta
from
PIL
import
Image
from
einops
import
rearrange
...
...
@@ -432,7 +433,7 @@ class WanAudioRunner(WanRunner): # type:ignore
else
:
seq_p_group
=
None
audio_adapter
=
rank0_load_state_dict_from_path
(
audio_adapter
,
audio_adapter_path
,
strict
=
False
,
seq_p_group
=
seq_p_group
)
audio_adapter
=
rank0_load_state_dict_from_path
(
audio_adapter
,
audio_adapter_path
,
strict
=
False
)
self
.
_audio_adapter_pipe
=
AudioAdapterPipe
(
audio_adapter
,
audio_encoder_repo
=
audio_encoder_repo
,
dtype
=
GET_DTYPE
(),
device
=
device
,
weight
=
1.0
,
cpu_offload
=
cpu_offload
,
seq_p_group
=
seq_p_group
...
...
@@ -564,6 +565,7 @@ class WanAudioRunner(WanRunner): # type:ignore
comfyui_audio
=
{
"waveform"
:
audio_waveform
,
"sample_rate"
:
self
.
_audio_processor
.
audio_sr
}
# Save video if requested
if
(
self
.
config
.
get
(
"device_mesh"
)
is
not
None
and
dist
.
get_rank
()
==
0
)
or
self
.
config
.
get
(
"device_mesh"
)
is
None
:
if
save_video
and
self
.
config
.
get
(
"save_video_path"
,
None
):
self
.
_save_video_with_audio
(
comfyui_images
,
merge_audio
,
target_fps
)
...
...
lightx2v/models/video_encoders/hf/wan/vae.py
View file @
daa06243
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import
logging
import
torch
import
torch.distributed
as
dist
...
...
@@ -781,8 +780,7 @@ def _video_vae(pretrained_path=None, z_dim=None, device="cpu", seq_p_group=None,
model
=
WanVAE_
(
**
cfg
)
# load checkpoint
logging
.
info
(
f
"loading
{
pretrained_path
}
"
)
weights_dict
=
load_weights_distributed
(
pretrained_path
,
seq_p_group
)
weights_dict
=
load_weights_distributed
(
pretrained_path
)
model
.
load_state_dict
(
weights_dict
,
assign
=
True
)
...
...
lightx2v/utils/utils.py
View file @
daa06243
...
...
@@ -324,13 +324,13 @@ def find_gguf_model_path(config, ckpt_config_key=None, subdir=None):
raise
FileNotFoundError
(
f
"No GGUF model files (.gguf) found.
\n
Please download the model from: https://huggingface.co/lightx2v/ or specify the model path in the configuration file."
)
def
load_weights_distributed
(
checkpoint_path
,
seq_p_group
=
None
):
if
seq_p_group
is
None
or
not
dist
.
is_initialized
():
def
load_weights_distributed
(
checkpoint_path
):
if
not
dist
.
is_initialized
():
logger
.
info
(
f
"Loading weights from
{
checkpoint_path
}
"
)
return
torch
.
load
(
checkpoint_path
,
map_location
=
"cpu"
,
weights_only
=
True
)
is_leader
=
False
current_rank
=
dist
.
get_rank
(
seq_p_group
)
current_rank
=
dist
.
get_rank
()
if
current_rank
==
0
:
is_leader
=
True
...
...
@@ -348,22 +348,22 @@ def load_weights_distributed(checkpoint_path, seq_p_group=None):
obj_list
=
[
meta_dict
]
if
is_leader
else
[
None
]
# 获取rank0的全局 rank 用于广播
src_global_rank
=
dist
.
get_process_group_ranks
(
seq_p_group
)[
0
]
dist
.
broadcast_object_list
(
obj_list
,
src
=
src_global_rank
,
group
=
seq_p_group
)
src_global_rank
=
0
dist
.
broadcast_object_list
(
obj_list
,
src
=
src_global_rank
)
synced_meta_dict
=
obj_list
[
0
]
# 所有进程所在的GPU上创建空的权重字典
target_device
=
torch
.
device
(
f
"cuda:
{
current_rank
}
"
)
gpu_weight_dict
=
{
key
:
torch
.
empty
(
meta
[
"shape"
],
dtype
=
meta
[
"dtype"
],
device
=
target_device
)
for
key
,
meta
in
synced_meta_dict
.
items
()}
dist
.
barrier
(
group
=
seq_p_group
,
device_ids
=
[
torch
.
cuda
.
current_device
()])
dist
.
barrier
(
device_ids
=
[
torch
.
cuda
.
current_device
()])
for
key
in
sorted
(
synced_meta_dict
.
keys
()):
tensor_to_broadcast
=
gpu_weight_dict
[
key
]
if
is_leader
:
# rank0将CPU权重拷贝到目标GPU,准备广播
tensor_to_broadcast
.
copy_
(
cpu_weight_dict
[
key
],
non_blocking
=
True
)
dist
.
broadcast
(
tensor_to_broadcast
,
src
=
src_global_rank
,
group
=
seq_p_group
)
dist
.
broadcast
(
tensor_to_broadcast
,
src
=
src_global_rank
)
if
is_leader
:
del
cpu_weight_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment