xuwx1 / LightX2V · Commits

Commit ec061565, authored Aug 11, 2025 by gushiqiao
Parent: 0cbd0544

    Fix distribute offload bug.

Showing 3 changed files with 48 additions and 20 deletions:

    lightx2v/models/networks/wan/audio_adapter.py   +2   -2
    lightx2v/models/networks/wan/model.py           +27  -14
    lightx2v/utils/utils.py                         +19  -4
lightx2v/models/networks/wan/audio_adapter.py  (+2 -2)

@@ -157,8 +157,8 @@ class PerceiverAttentionCA(nn.Module):
             v=v,
             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
-            max_seqlen_q=q_lens.max(),
-            max_seqlen_k=k_lens.max(),
+            max_seqlen_q=q_lens.max().item(),
+            max_seqlen_k=k_lens.max().item(),
             dropout_p=0.0,
             softmax_scale=None,
             causal=False,
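The only change in this file is passing max_seqlen_q / max_seqlen_k as plain Python ints (via .item()) rather than zero-dim tensors, which is what varlen-attention style interfaces generally expect for these arguments. Below is a minimal sketch of how such arguments are typically assembled; the helper name build_varlen_args is illustrative and not part of this repository.

import torch

def build_varlen_args(q_lens: torch.Tensor, k_lens: torch.Tensor, device):
    # Cumulative sequence offsets with one leading zero, in int32 as the kernels expect.
    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(device, non_blocking=True)
    cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(device, non_blocking=True)
    # .item() turns the 0-dim max tensors into plain Python ints, mirroring the fix above.
    return cu_seqlens_q, cu_seqlens_k, q_lens.max().item(), k_lens.max().item()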
lightx2v/models/networks/wan/model.py  (+27 -14)

@@ -215,7 +215,7 @@ class WanModel:
             weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
         if self.config.get("device_mesh") is not None:
-            weight_dict = self._distribute_weights_multi_gpu(weight_dict, is_weight_loader)
+            weight_dict = self._load_weights_distribute(weight_dict, is_weight_loader)
             self.original_weight_dict = weight_dict
         else:

@@ -234,48 +234,61 @@ class WanModel:
         del self.original_weight_dict
         torch.cuda.empty_cache()

-    def _distribute_weights_multi_gpu(self, weight_dict, is_weight_loader):
+    def _load_weights_distribute(self, weight_dict, is_weight_loader):
         """Distribute weights across multiple GPUs or CPUs based on offload config."""
         global_src_rank = 0
         # Determine target device for distribution
         target_device = "cpu" if self.cpu_offload else "cuda"
         if is_weight_loader:
             # Create metadata for broadcasting
             meta_dict = {}
             for key, tensor in weight_dict.items():
                 meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype}
             # Broadcast metadata to all ranks
             obj_list = [meta_dict]
             dist.broadcast_object_list(obj_list, src=global_src_rank)
             synced_meta_dict = obj_list[0]
         else:
             # Non-loader ranks receive metadata
             obj_list = [None]
             dist.broadcast_object_list(obj_list, src=global_src_rank)
             synced_meta_dict = obj_list[0]
         # Create empty tensors on target device for all ranks
         distributed_weight_dict = {}
         for key, meta in synced_meta_dict.items():
             distributed_weight_dict[key] = torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)
         # Synchronize before broadcasting
         if target_device == "cuda":
             dist.barrier(device_ids=[torch.cuda.current_device()])
         else:
             dist.barrier()
         # Broadcast weights from rank 0 to all ranks
         for key in sorted(synced_meta_dict.keys()):
             if is_weight_loader:
                 # Copy weights to broadcast tensor
                 distributed_weight_dict[key].copy_(weight_dict[key], non_blocking=True)
             # Broadcast to all ranks
-            dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
+            if target_device == "cpu":
+                if is_weight_loader:
+                    gpu_tensor = distributed_weight_dict[key].cuda()
+                    dist.broadcast(gpu_tensor, src=global_src_rank)
+                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
+                    del gpu_tensor
+                    torch.cuda.empty_cache()
+                else:
+                    gpu_tensor = torch.empty_like(distributed_weight_dict[key], device="cuda")
+                    dist.broadcast(gpu_tensor, src=global_src_rank)
+                    distributed_weight_dict[key].copy_(gpu_tensor.cpu(), non_blocking=True)
+                    del gpu_tensor
+                    torch.cuda.empty_cache()
+                if distributed_weight_dict[key].is_pinned():
+                    distributed_weight_dict[key].copy_(distributed_weight_dict[key], non_blocking=True)
+            else:
+                dist.broadcast(distributed_weight_dict[key], src=global_src_rank)
-        if target_device == "cuda":
-            torch.cuda.synchronize()
-        else:
-            for tensor in distributed_weight_dict.values():
-                if tensor.is_pinned():
-                    tensor.copy_(tensor, non_blocking=False)
         logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
         return distributed_weight_dict
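The core of the fix: when weights are offloaded to CPU, they cannot be broadcast directly over NCCL, so each tensor is staged on the GPU, broadcast there, and copied back to CPU; without offload the tensor is broadcast in place. A standalone sketch of that pattern, assuming torch.distributed has already been initialized with the NCCL backend; the function name broadcast_offloaded_tensor is illustrative, not part of the repository.

import torch
import torch.distributed as dist

def broadcast_offloaded_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast `tensor` from rank `src`, staging through the GPU if it lives on CPU."""
    if tensor.is_cuda:
        # GPU-resident weights: NCCL broadcasts them in place.
        dist.broadcast(tensor, src=src)
        return tensor
    # CPU-resident (offloaded) weights: stage on GPU, broadcast, copy back, free the staging buffer.
    gpu_staging = tensor.cuda() if dist.get_rank() == src else torch.empty_like(tensor, device="cuda")
    dist.broadcast(gpu_staging, src=src)
    tensor.copy_(gpu_staging.cpu(), non_blocking=True)
    del gpu_staging
    torch.cuda.empty_cache()
    return tensor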
lightx2v/utils/utils.py  (+19 -4)

@@ -360,12 +360,10 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
     synced_meta_dict = obj_list[0]
     if cpu_offload:
-        # Multi-GPU + offload: weights on CPU
         target_device = "cpu"
         distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
         dist.barrier()
     else:
-        # Multi-GPU + non-offload: weights on GPU
         target_device = torch.device(f"cuda:{current_rank}")
         distributed_weight_dict = {key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items()}
         dist.barrier(device_ids=[torch.cuda.current_device()])

@@ -374,11 +372,29 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
         tensor_to_broadcast = distributed_weight_dict[key]
         if is_weight_loader:
             tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True)
-        dist.broadcast(tensor_to_broadcast, src=src_global_rank)
+        if cpu_offload:
+            if is_weight_loader:
+                gpu_tensor = tensor_to_broadcast.cuda()
+                dist.broadcast(gpu_tensor, src=src_global_rank)
+                tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True)
+                del gpu_tensor
+                torch.cuda.empty_cache()
+            else:
+                gpu_tensor = torch.empty_like(tensor_to_broadcast, device="cuda")
+                dist.broadcast(gpu_tensor, src=src_global_rank)
+                tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True)
+                del gpu_tensor
+                torch.cuda.empty_cache()
+        else:
+            dist.broadcast(tensor_to_broadcast, src=src_global_rank)

     if is_weight_loader:
         del cpu_weight_dict
+    if cpu_offload:
+        torch.cuda.empty_cache()
     logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
     return distributed_weight_dict

@@ -388,7 +404,6 @@ def masks_like(tensor, zero=False, generator=None, p=0.2):
     out = torch.ones_like(tensor)
     if zero:
         if generator is not None:
-            # Generate a random number to decide whether to modify the mask
            random_num = torch.rand(1, generator=generator, device=generator.device).item()
            if random_num < p:
                out[:, 0] = torch.zeros_like(out[:, 0])
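Both load_weights here and _load_weights_distribute in model.py follow the same two-phase pattern: first broadcast lightweight metadata (shape and dtype per key) with broadcast_object_list, then allocate matching empty tensors on every rank before broadcasting the actual weights. A minimal sketch of the metadata phase, again assuming an initialized process group; sync_weight_metadata is an illustrative name, not an API of this repository.

import torch
import torch.distributed as dist

def sync_weight_metadata(weight_dict, is_loader: bool, target_device: str, src: int = 0):
    """Share {key: shape/dtype} from the loader rank and allocate matching empty tensors everywhere."""
    if is_loader:
        meta = {k: {"shape": t.shape, "dtype": t.dtype} for k, t in weight_dict.items()}
        obj_list = [meta]
    else:
        obj_list = [None]
    # broadcast_object_list pickles the metadata, so the weight tensors themselves are not moved here.
    dist.broadcast_object_list(obj_list, src=src)
    synced_meta = obj_list[0]
    return {k: torch.empty(m["shape"], dtype=m["dtype"], device=target_device) for k, m in synced_meta.items()}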