xuwx1 / LightX2V · Commit 0dc34857

Fix distribute offload bug.

Authored Aug 11, 2025 by gushiqiao; committed by GitHub on Aug 11, 2025.
Parents: 0cbd0544, ec061565
Showing 3 changed files, with 48 additions and 20 deletions (+48 -20):

- lightx2v/models/networks/wan/audio_adapter.py (+2 -2)
- lightx2v/models/networks/wan/model.py (+27 -14)
- lightx2v/utils/utils.py (+19 -4)
lightx2v/models/networks/wan/audio_adapter.py (+2 -2)

In `PerceiverAttentionCA`, the maximum sequence lengths passed to the variable-length attention call are now plain Python ints (`.max().item()`) instead of 0-d tensors (`.max()`):

```diff
@@ -157,8 +157,8 @@ class PerceiverAttentionCA(nn.Module):
             v=v,
             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
-            max_seqlen_q=q_lens.max(),
-            max_seqlen_k=k_lens.max(),
+            max_seqlen_q=q_lens.max().item(),
+            max_seqlen_k=k_lens.max().item(),
             dropout_p=0.0,
             softmax_scale=None,
             causal=False,
```
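Variable-length attention kernels generally expect `max_seqlen_q` / `max_seqlen_k` as Python integers, so passing the 0-d tensor returned by `.max()` can trip argument handling or force an implicit device sync. A minimal, self-contained sketch of the metadata this call builds (the helper name `varlen_metadata` is illustrative, not from the repo):

```python
import torch


def varlen_metadata(seq_lens: torch.Tensor, device: torch.device):
    """Build the cu_seqlens / max_seqlen pair used by varlen attention kernels."""
    # Prepend a zero and take the running sum: [3, 5] -> [0, 3, 8]
    cu_seqlens = torch.cat([seq_lens.new_zeros([1]), seq_lens]).cumsum(0, dtype=torch.int32).to(device, non_blocking=True)
    # .max() returns a 0-d tensor; .item() turns it into the plain int the kernel expects.
    max_seqlen = seq_lens.max().item()
    return cu_seqlens, max_seqlen


if __name__ == "__main__":
    q_lens = torch.tensor([3, 5], dtype=torch.int32)
    cu_seqlens_q, max_seqlen_q = varlen_metadata(q_lens, torch.device("cpu"))
    print(cu_seqlens_q.tolist(), max_seqlen_q)  # [0, 3, 8] 8
```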
lightx2v/models/networks/wan/model.py (+27 -14)

The multi-GPU weight distribution helper is renamed from `_distribute_weights_multi_gpu` to `_load_weights_distribute`, and the call site is updated accordingly:

```diff
@@ -215,7 +215,7 @@ class WanModel:
             weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)

         if self.config.get("device_mesh") is not None:
-            weight_dict = self._distribute_weights_multi_gpu(weight_dict, is_weight_loader)
+            weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
             self.original_weight_dict = weight_dict
         else:
```
The renamed method also changes how weights are broadcast when they are offloaded to CPU: instead of calling `dist.broadcast` directly on the CPU tensor, each weight is staged through a temporary CUDA tensor and copied back, and a final synchronization pass is added:

```diff
@@ -234,48 +234,61 @@ class WanModel:
         del self.original_weight_dict
         torch.cuda.empty_cache()

-    def _distribute_weights_multi_gpu(self, weight_dict, is_weight_loader):
-        """Distribute weights across multiple GPUs or CPUs based on offload config."""
+    def _load_weights_distribute(self, weight_dict, is_weight_loader):
         global_src_rank = 0
-        # Determine target device for distribution
         target_device = "cpu" if self.cpu_offload else "cuda"

         if is_weight_loader:
-            # Create metadata for broadcasting
             meta_dict = {}
             for key, tensor in weight_dict.items():
                 meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
-            # Broadcast metadata to all ranks
             obj_list = [meta_dict]
             dist.broadcast_object_list(obj_list, src=global_src_rank)
             synced_meta_dict = obj_list[0]
         else:
-            # Non-loader ranks receive metadata
             obj_list = [None]
             dist.broadcast_object_list(obj_list, src=global_src_rank)
             synced_meta_dict = obj_list[0]

-        # Create empty tensors on target device for all ranks
         distributed_weight_dict = {}
         for key, meta in synced_meta_dict.items():
             distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)

-        # Synchronize before broadcasting
         if target_device == "cuda":
             dist.barrier(device_ids=[torch.cuda.current_device()])
         else:
             dist.barrier()

-        # Broadcast weights from rank 0 to all ranks
         for key in sorted(synced_meta_dict.keys()):
             if is_weight_loader:
-                # Copy weights to broadcast tensor
                 distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
-            # Broadcast to all ranks
-            dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
+            if target_device == "cpu":
+                if is_weight_loader:
+                    gpu_tensor = distributed_weight_dict[key].cuda()
+                    dist.broadcast(gpu_tensor, src=global_src_rank)
+                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
+                    del gpu_tensor
+                    torch.cuda.empty_cache()
+                else:
+                    gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
+                    dist.broadcast(gpu_tensor, src=global_src_rank)
+                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
+                    del gpu_tensor
+                    torch.cuda.empty_cache()
+                if distributed_weight_dict[key].is_pinned():
+                    distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True)
+            else:
+                dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
+
+        if target_device == "cuda":
+            torch.cuda.synchronize()
+        else:
+            for tensor in distributed_weight_dict.values():
+                if tensor.is_pinned():
+                    tensor.copy_(tensor, non_blocking=False)

         logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")

         return distributed_weight_dict
```
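The overall flow is: the weight-loader rank reads the checkpoint, broadcasts a shape/dtype metadata dict with `broadcast_object_list`, every rank allocates empty buffers from that metadata, and the values are then broadcast key by key. Below is a minimal standalone sketch of that pattern using the gloo backend on CPU so it runs without GPUs; the names `run_worker` and `WORLD_SIZE` are illustrative and not from the repo:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 2  # illustrative


def run_worker(rank: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=WORLD_SIZE)

    is_weight_loader = rank == 0
    # Only the loader rank has real weights; the other ranks start with nothing.
    weight_dict = {"w": torch.randn(4, 4), "b": torch.zeros(4)} if is_weight_loader else None

    # 1) Broadcast shapes/dtypes so non-loader ranks can allocate matching buffers.
    obj_list = [{k: (v.shape, v.dtype) for k, v in weight_dict.items()}] if is_weight_loader else [None]
    dist.broadcast_object_list(obj_list, src=0)
    meta = obj_list[0]

    # 2) Allocate empty tensors everywhere, then broadcast the actual values key by key.
    received = {k: torch.empty(shape, dtype=dtype) for k, (shape, dtype) in meta.items()}
    for k in sorted(meta):
        if is_weight_loader:
            received[k].copy_(weight_dict[k])
        dist.broadcast(received[k], src=0)

    print(f"rank {rank}: received {sorted(received)} w[0,0]={received['w'][0, 0].item():.3f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, nprocs=WORLD_SIZE)
```

The repo's code follows the same shape but runs the broadcasts on a GPU-capable backend; that is presumably why the CPU-offload path now stages each tensor through a temporary CUDA tensor (backends such as NCCL can only move CUDA tensors) and why the pinned CPU copies get an extra blocking pass at the end.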
lightx2v/utils/utils.py (+19 -4)

In `load_weights`, receive buffers are still allocated on CPU when offloading and on this rank's GPU otherwise; the two inline comments on those branches are dropped:

```diff
@@ -360,12 +360,10 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     synced_meta_dict = obj_list[0]

     if cpu_offload:
-        # Multi-GPU + offload: weights on CPU
         target_device = "cpu"
         distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
         dist.barrier()
     else:
-        # Multi-GPU + non-offload: weights on GPU
         target_device = torch.device(f"cuda:{current_rank}")
         distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
         dist.barrier(device_ids=[torch.cuda.current_device()])
```
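The allocation step itself is simple: the offload flag picks the target device once, and every buffer described by the synced metadata is created empty on that device. A reduced, runnable sketch (the function name `allocate_buffers` is illustrative, not from the repo):

```python
import torch


def allocate_buffers(synced_meta_dict, cpu_offload: bool, current_rank: int):
    """Create empty receive buffers on CPU (offload) or on this rank's GPU."""
    target_device = "cpu" if cpu_offload else torch.device(f"cuda:{current_rank}")
    return {
        key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)
        for key, meta in synced_meta_dict.items()
    }


if __name__ == "__main__":
    meta = {"w": {"shape": (2, 3), "dtype": torch.float16}}
    bufs = allocate_buffers(meta, cpu_offload=True, current_rank=0)
    print(bufs["w"].shape, bufs["w"].dtype, bufs["w"].device)  # torch.Size([2, 3]) torch.float16 cpu
```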
The broadcast loop in `load_weights` gets the same fix as `model.py`: with `cpu_offload`, each tensor is broadcast through a temporary CUDA tensor and copied back instead of being broadcast directly from CPU, and the GPU cache is cleared afterwards:

```diff
@@ -374,11 +372,29 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
         tensor_to_broadcast = distributed_weight_dict[key]
         if is_weight_loader:
             tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-        dist.broadcast(tensor_to_broadcast, src=src_global_rank)
+        if cpu_offload:
+            if is_weight_loader:
+                gpu_tensor = tensor_to_broadcast.cuda()
+                dist.broadcast(gpu_tensor, src=src_global_rank)
+                tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True)
+                del gpu_tensor
+                torch.cuda.empty_cache()
+            else:
+                gpu_tensor = torch.empty_like(tensor_to_broadcast, device="cuda")
+                dist.broadcast(gpu_tensor, src=src_global_rank)
+                tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True)
+                del gpu_tensor
+                torch.cuda.empty_cache()
+        else:
+            dist.broadcast(tensor_to_broadcast, src=src_global_rank)

     if is_weight_loader:
         del cpu_weight_dict
+    if cpu_offload:
+        torch.cuda.empty_cache()

     logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")

     return distributed_weight_dict
```
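Isolated, the staging step looks like the helper below; a hedged sketch that assumes an already-initialized GPU-backed process group (for example NCCL) and a current CUDA device, with a hypothetical helper name:

```python
import torch
import torch.distributed as dist


def broadcast_cpu_tensor_via_cuda(cpu_tensor: torch.Tensor, src: int) -> None:
    """Broadcast a CPU-resident tensor by staging it through the GPU.

    NCCL-style backends can only move CUDA tensors, so the source rank uploads
    its data, receivers allocate a CUDA landing buffer, and everyone copies the
    result back into the original CPU tensor.
    """
    if dist.get_rank() == src:
        gpu_tensor = cpu_tensor.cuda()
    else:
        gpu_tensor = torch.empty_like(cpu_tensor, device="cuda")
    dist.broadcast(gpu_tensor, src=src)
    cpu_tensor.copy_(gpu_tensor.cpu(), non_blocking=True)
    del gpu_tensor
    torch.cuda.empty_cache()
```

Because the copy back to CPU is non-blocking, callers should synchronize (or issue a blocking copy on pinned tensors, as the model.py hunk above does) before relying on the values.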
In `masks_like`, a Chinese inline comment ("generate a random number to decide whether to modify") is removed; the behavior (zero the first-frame mask with probability `p`) is unchanged:

```diff
@@ -388,7 +404,6 @@ def masks_like(tensor, zero=False, generator=None, p=0.2):
     out = torch.ones_like(tensor)
     if zero:
         if generator is not None:
-            # 生成随机数判断是否需要修改
             random_num = torch.rand(1, generator=generator, device=generator.device).item()
             if random_num < p:
                 out[:, 0] = torch.zeros_like(out[:, 0])
```
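For reference, a reduced, runnable sketch of the branch shown above (the wrapper name `make_mask` is illustrative; the repo's `masks_like` may handle more input types):

```python
import torch


def make_mask(tensor, zero=False, generator=None, p=0.2):
    out = torch.ones_like(tensor)
    if zero and generator is not None:
        # Draw one uniform sample; with probability p, zero the first-frame mask.
        random_num = torch.rand(1, generator=generator, device=generator.device).item()
        if random_num < p:
            out[:, 0] = torch.zeros_like(out[:, 0])
    return out


if __name__ == "__main__":
    g = torch.Generator().manual_seed(0)
    x = torch.randn(1, 4, 8)
    hits = sum(make_mask(x, zero=True, generator=g, p=0.2)[:, 0].sum().item() == 0 for _ in range(1000))
    print(f"first-frame mask zeroed in {hits}/1000 draws (expected ~200)")
```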