Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
3b896f9c
Commit
3b896f9c
authored
Sep 18, 2025
by
gushiqiao
Committed by
GitHub
Sep 18, 2025
Browse files
[Fix] Fix distribute load model bug (#315)
parent
f085ede3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
26 additions
and
13 deletions
+26
-13
lightx2v/common/ops/mm/mm_weight.py
lightx2v/common/ops/mm/mm_weight.py
+4
-0
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+16
-9
lightx2v/models/runners/default_runner.py
lightx2v/models/runners/default_runner.py
+2
-0
lightx2v/utils/utils.py
lightx2v/utils/utils.py
+4
-4
No files found.
lightx2v/common/ops/mm/mm_weight.py
View file @
3b896f9c
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
torch
import
torch
import
torch.distributed
as
dist
from
loguru
import
logger
from
loguru
import
logger
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
...
@@ -101,6 +102,7 @@ class MMWeight(MMWeightTemplate):
...
@@ -101,6 +102,7 @@ class MMWeight(MMWeightTemplate):
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
del
weight_dict
[
self
.
weight_name
]
else
:
else
:
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
...
@@ -203,6 +205,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -203,6 +205,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
weight_scale_dtype
=
torch
.
float
weight_scale_dtype
=
torch
.
float
self
.
weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
)
self
.
weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
)
self
.
weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
self
.
weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
if
dist
.
is_initialized
():
del
weight_dict
[
self
.
weight_name
]
else
:
else
:
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
...
...
lightx2v/models/networks/wan/model.py
View file @
3b896f9c
...
@@ -122,14 +122,21 @@ class WanModel(CompiledMethodsMixin):
...
@@ -122,14 +122,21 @@ class WanModel(CompiledMethodsMixin):
# Single GPU mode
# Single GPU mode
return
True
return
True
elif
dist
.
is_initialized
():
elif
dist
.
is_initialized
():
# Multi-GPU mode, only rank 0 loads
if
self
.
config
.
get
(
"load_from_rank0"
,
False
):
if
dist
.
get_rank
()
==
0
:
# Multi-GPU mode, only rank 0 loads
logger
.
info
(
f
"Loading weights from
{
self
.
model_path
}
"
)
if
dist
.
get_rank
()
==
0
:
logger
.
info
(
f
"Loading weights from
{
self
.
model_path
}
"
)
return
True
else
:
return
True
return
True
return
False
return
False
def
_load_safetensor_to_dict
(
self
,
file_path
,
unified_dtype
,
sensitive_layer
):
def
_load_safetensor_to_dict
(
self
,
file_path
,
unified_dtype
,
sensitive_layer
):
with
safe_open
(
file_path
,
framework
=
"pt"
,
device
=
str
(
self
.
device
))
as
f
:
if
self
.
device
.
type
==
"cuda"
and
dist
.
is_initialized
():
device
=
torch
.
device
(
"cuda:{}"
.
format
(
dist
.
get_rank
()))
else
:
device
=
self
.
device
with
safe_open
(
file_path
,
framework
=
"pt"
,
device
=
str
(
device
))
as
f
:
return
{
key
:
(
f
.
get_tensor
(
key
).
to
(
GET_DTYPE
())
if
unified_dtype
or
all
(
s
not
in
key
for
s
in
sensitive_layer
)
else
f
.
get_tensor
(
key
).
to
(
GET_SENSITIVE_DTYPE
()))
for
key
in
f
.
keys
()}
return
{
key
:
(
f
.
get_tensor
(
key
).
to
(
GET_DTYPE
())
if
unified_dtype
or
all
(
s
not
in
key
for
s
in
sensitive_layer
)
else
f
.
get_tensor
(
key
).
to
(
GET_SENSITIVE_DTYPE
()))
for
key
in
f
.
keys
()}
def
_load_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
def
_load_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
...
@@ -147,8 +154,6 @@ class WanModel(CompiledMethodsMixin):
...
@@ -147,8 +154,6 @@ class WanModel(CompiledMethodsMixin):
def
_load_quant_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
def
_load_quant_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
ckpt_path
=
self
.
dit_quantized_ckpt
ckpt_path
=
self
.
dit_quantized_ckpt
logger
.
info
(
f
"Loading quant dit model from
{
ckpt_path
}
"
)
index_files
=
[
f
for
f
in
os
.
listdir
(
ckpt_path
)
if
f
.
endswith
(
".index.json"
)]
index_files
=
[
f
for
f
in
os
.
listdir
(
ckpt_path
)
if
f
.
endswith
(
".index.json"
)]
if
not
index_files
:
if
not
index_files
:
raise
FileNotFoundError
(
f
"No *.index.json found in
{
ckpt_path
}
"
)
raise
FileNotFoundError
(
f
"No *.index.json found in
{
ckpt_path
}
"
)
...
@@ -236,8 +241,8 @@ class WanModel(CompiledMethodsMixin):
...
@@ -236,8 +241,8 @@ class WanModel(CompiledMethodsMixin):
else
:
else
:
weight_dict
=
self
.
_load_quant_split_ckpt
(
unified_dtype
,
sensitive_layer
)
weight_dict
=
self
.
_load_quant_split_ckpt
(
unified_dtype
,
sensitive_layer
)
if
self
.
config
.
get
(
"device_mesh"
)
is
not
None
:
if
self
.
config
.
get
(
"device_mesh"
)
is
not
None
and
self
.
config
.
get
(
"load_from_rank0"
,
False
)
:
weight_dict
=
self
.
_load_weights_distribute
(
weight_dict
,
is_weight_loader
)
weight_dict
=
self
.
_load_weights_from_rank0
(
weight_dict
,
is_weight_loader
)
if
hasattr
(
self
,
"adapter_weights_dict"
):
if
hasattr
(
self
,
"adapter_weights_dict"
):
weight_dict
.
update
(
self
.
adapter_weights_dict
)
weight_dict
.
update
(
self
.
adapter_weights_dict
)
...
@@ -258,7 +263,8 @@ class WanModel(CompiledMethodsMixin):
...
@@ -258,7 +263,8 @@ class WanModel(CompiledMethodsMixin):
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
gc
.
collect
()
def
_load_weights_distribute
(
self
,
weight_dict
,
is_weight_loader
):
def
_load_weights_from_rank0
(
self
,
weight_dict
,
is_weight_loader
):
logger
.
info
(
"Loading distributed weights"
)
global_src_rank
=
0
global_src_rank
=
0
target_device
=
"cpu"
if
self
.
cpu_offload
else
"cuda"
target_device
=
"cpu"
if
self
.
cpu_offload
else
"cuda"
...
@@ -313,6 +319,7 @@ class WanModel(CompiledMethodsMixin):
...
@@ -313,6 +319,7 @@ class WanModel(CompiledMethodsMixin):
tensor
.
copy_
(
tensor
,
non_blocking
=
False
)
tensor
.
copy_
(
tensor
,
non_blocking
=
False
)
logger
.
info
(
f
"Weights distributed across
{
dist
.
get_world_size
()
}
devices on
{
target_device
}
"
)
logger
.
info
(
f
"Weights distributed across
{
dist
.
get_world_size
()
}
devices on
{
target_device
}
"
)
return
distributed_weight_dict
return
distributed_weight_dict
def
_init_infer
(
self
):
def
_init_infer
(
self
):
...
...
lightx2v/models/runners/default_runner.py
View file @
3b896f9c
...
@@ -10,6 +10,7 @@ from requests.exceptions import RequestException
...
@@ -10,6 +10,7 @@ from requests.exceptions import RequestException
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.generate_task_id
import
generate_task_id
from
lightx2v.utils.generate_task_id
import
generate_task_id
from
lightx2v.utils.memory_profiler
import
peak_memory_decorator
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.utils
import
save_to_video
,
vae_to_comfyui_image
from
lightx2v.utils.utils
import
save_to_video
,
vae_to_comfyui_image
...
@@ -112,6 +113,7 @@ class DefaultRunner(BaseRunner):
...
@@ -112,6 +113,7 @@ class DefaultRunner(BaseRunner):
def
set_progress_callback
(
self
,
callback
):
def
set_progress_callback
(
self
,
callback
):
self
.
progress_callback
=
callback
self
.
progress_callback
=
callback
@
peak_memory_decorator
def
run_segment
(
self
,
total_steps
=
None
):
def
run_segment
(
self
,
total_steps
=
None
):
if
total_steps
is
None
:
if
total_steps
is
None
:
total_steps
=
self
.
model
.
scheduler
.
infer_steps
total_steps
=
self
.
model
.
scheduler
.
infer_steps
...
...
lightx2v/utils/utils.py
View file @
3b896f9c
...
@@ -363,7 +363,7 @@ def load_pt_safetensors(in_path, remove_key):
...
@@ -363,7 +363,7 @@ def load_pt_safetensors(in_path, remove_key):
return
state_dict
return
state_dict
def
load_weights
(
checkpoint_path
,
cpu_offload
=
False
,
remove_key
=
None
):
def
load_weights
(
checkpoint_path
,
cpu_offload
=
False
,
remove_key
=
None
,
load_from_rank0
=
False
):
if
not
dist
.
is_initialized
():
if
not
dist
.
is_initialized
():
# Single GPU mode
# Single GPU mode
logger
.
info
(
f
"Loading weights from
{
checkpoint_path
}
"
)
logger
.
info
(
f
"Loading weights from
{
checkpoint_path
}
"
)
...
@@ -371,10 +371,10 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
...
@@ -371,10 +371,10 @@ def load_weights(checkpoint_path, cpu_offload=False, remove_key=None):
return
cpu_weight_dict
return
cpu_weight_dict
# Multi-GPU mode
# Multi-GPU mode
is_weight_loader
=
False
is_weight_loader
=
True
current_rank
=
dist
.
get_rank
()
current_rank
=
dist
.
get_rank
()
if
current_rank
==
0
:
if
load_from_rank0
and
current_rank
!=
0
:
is_weight_loader
=
True
is_weight_loader
=
False
cpu_weight_dict
=
{}
cpu_weight_dict
=
{}
if
is_weight_loader
:
if
is_weight_loader
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment