Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
b469676c
Commit
b469676c
authored
Aug 14, 2025
by
helloyongyang
Browse files
update log
parent
9e3680b7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
10 deletions
+13
-10
lightx2v/infer.py
lightx2v/infer.py
+3
-5
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+0
-2
lightx2v/utils/set_config.py
lightx2v/utils/set_config.py
+10
-3
No files found.
lightx2v/infer.py
View file @
b469676c
import
argparse
import
argparse
import
json
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
loguru
import
logger
from
loguru
import
logger
...
@@ -16,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
...
@@ -16,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.profiler
import
ProfilingContext
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.set_config
import
set_config
,
set_parallel_config
from
lightx2v.utils.set_config
import
print_config
,
set_config
,
set_parallel_config
from
lightx2v.utils.utils
import
seed_all
from
lightx2v.utils.utils
import
seed_all
...
@@ -70,17 +69,16 @@ def main():
...
@@ -70,17 +69,16 @@ def main():
parser
.
add_argument
(
"--save_video_path"
,
type
=
str
,
default
=
"./output_lightx2v.mp4"
,
help
=
"The path to save video path/file"
)
parser
.
add_argument
(
"--save_video_path"
,
type
=
str
,
default
=
"./output_lightx2v.mp4"
,
help
=
"The path to save video path/file"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logger
.
info
(
f
"args:
{
args
}
"
)
# set config
# set config
config
=
set_config
(
args
)
config
=
set_config
(
args
)
logger
.
info
(
f
"config:
\n
{
json
.
dumps
(
config
,
ensure_ascii
=
False
,
indent
=
4
)
}
"
)
if
config
.
parallel
:
if
config
.
parallel
:
dist
.
init_process_group
(
backend
=
"nccl"
)
dist
.
init_process_group
(
backend
=
"nccl"
)
torch
.
cuda
.
set_device
(
dist
.
get_rank
())
torch
.
cuda
.
set_device
(
dist
.
get_rank
())
set_parallel_config
(
config
)
set_parallel_config
(
config
)
print_config
(
config
)
with
ProfilingContext
(
"Total Cost"
):
with
ProfilingContext
(
"Total Cost"
):
runner
=
init_runner
(
config
)
runner
=
init_runner
(
config
)
runner
.
run_pipeline
()
runner
.
run_pipeline
()
...
...
lightx2v/models/networks/wan/model.py
View file @
b469676c
...
@@ -256,8 +256,6 @@ class WanModel:
...
@@ -256,8 +256,6 @@ class WanModel:
if
target_device
==
"cuda"
:
if
target_device
==
"cuda"
:
dist
.
barrier
(
device_ids
=
[
torch
.
cuda
.
current_device
()])
dist
.
barrier
(
device_ids
=
[
torch
.
cuda
.
current_device
()])
else
:
dist
.
barrier
()
for
key
in
sorted
(
synced_meta_dict
.
keys
()):
for
key
in
sorted
(
synced_meta_dict
.
keys
()):
if
is_weight_loader
:
if
is_weight_loader
:
...
...
lightx2v/utils/set_config.py
View file @
b469676c
...
@@ -69,9 +69,6 @@ def set_config(args):
...
@@ -69,9 +69,6 @@ def set_config(args):
def
set_parallel_config
(
config
):
def
set_parallel_config
(
config
):
if
config
.
parallel
:
if
config
.
parallel
:
if
not
dist
.
is_initialized
():
dist
.
init_process_group
(
backend
=
"nccl"
)
cfg_p_size
=
config
.
parallel
.
get
(
"cfg_p_size"
,
1
)
cfg_p_size
=
config
.
parallel
.
get
(
"cfg_p_size"
,
1
)
seq_p_size
=
config
.
parallel
.
get
(
"seq_p_size"
,
1
)
seq_p_size
=
config
.
parallel
.
get
(
"seq_p_size"
,
1
)
assert
cfg_p_size
*
seq_p_size
==
dist
.
get_world_size
(),
f
"cfg_p_size * seq_p_size must be equal to world_size"
assert
cfg_p_size
*
seq_p_size
==
dist
.
get_world_size
(),
f
"cfg_p_size * seq_p_size must be equal to world_size"
...
@@ -82,3 +79,13 @@ def set_parallel_config(config):
...
@@ -82,3 +79,13 @@ def set_parallel_config(config):
if
config
.
get
(
"enable_cfg"
,
False
)
and
config
.
parallel
and
config
.
parallel
.
get
(
"cfg_p_size"
,
False
)
and
config
.
parallel
.
cfg_p_size
>
1
:
if
config
.
get
(
"enable_cfg"
,
False
)
and
config
.
parallel
and
config
.
parallel
.
get
(
"cfg_p_size"
,
False
)
and
config
.
parallel
.
cfg_p_size
>
1
:
config
[
"cfg_parallel"
]
=
True
config
[
"cfg_parallel"
]
=
True
def
print_config
(
config
):
config_to_print
=
config
.
copy
()
config_to_print
.
pop
(
"device_mesh"
,
None
)
if
config
.
parallel
:
if
dist
.
get_rank
()
==
0
:
logger
.
info
(
f
"config:
\n
{
json
.
dumps
(
config_to_print
,
ensure_ascii
=
False
,
indent
=
4
)
}
"
)
else
:
logger
.
info
(
f
"config:
\n
{
json
.
dumps
(
config_to_print
,
ensure_ascii
=
False
,
indent
=
4
)
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment