xuwx1 / LightX2V
"vscode:/vscode.git/clone" did not exist on "42982fc37f873d59cfabe9fab9058bb3f6bc8d69"
Commit d7206e69, authored Aug 05, 2025 by helloyongyang (parent c0b36010)

fix gpu mem not balanced bug
Showing 4 changed files with 21 additions and 24 deletions:

- lightx2v/infer.py (+10, -4)
- lightx2v/models/runners/default_runner.py (+0, -3)
- lightx2v/utils/profiler.py (+11, -9)
- lightx2v/utils/set_config.py (+0, -8)
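Read together, the diffs below make three related changes: `main()` in `infer.py` now initializes the process group and pins each rank to its own GPU with `torch.cuda.set_device(dist.get_rank())` before the runner (and whatever it allocates) is created; the later, per-runner device selection in `DefaultRunner.set_init_device()` is removed as redundant; and `set_parallel_config()` is called from `main()` only when the config has a `"parallel"` section, instead of unconditionally inside `set_config()`. The profiler is also updated to tag its log lines with the rank, so per-rank peak memory, and an imbalance like the one this commit fixes, is visible in the logs.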
lightx2v/infer.py
```diff
@@ -15,7 +15,7 @@ from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2D
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import ProfilingContext
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.set_config import print_config, set_config
+from lightx2v.utils.set_config import set_config, set_parallel_config
 from lightx2v.utils.utils import seed_all
@@ -58,11 +58,17 @@ def main():
     logger.info(f"args: {args}")

+    # set config
+    config = set_config(args)
+    logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+
+    if "parallel" in config:
+        dist.init_process_group(backend="nccl")
+        torch.cuda.set_device(dist.get_rank())
+        set_parallel_config(config)
+
     with ProfilingContext("Total Cost"):
-        config = set_config(args)
-        print_config(config)
         runner = init_runner(config)
         runner.run_pipeline()

     # Clean up distributed process group
```
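To make the ordering concern concrete, here is a minimal, self-contained sketch of per-rank device pinning. It is not LightX2V code; the torchrun launch and the tensor size are assumptions. The point is that until `torch.cuda.set_device()` runs, `device="cuda"` resolves to `cuda:0` in every process, so all ranks allocate on the same GPU.

```python
# Minimal sketch of per-rank GPU pinning, assuming the script is launched with
# torchrun (e.g. `torchrun --nproc_per_node=2 this_script.py`) so that the
# RANK/WORLD_SIZE/MASTER_ADDR environment variables are already set.
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")

    # Pin this process to its own GPU *before* creating any model or tensor.
    # Without this call, "cuda" means cuda:0 for every rank, so all ranks
    # allocate on GPU 0 and memory becomes unbalanced.
    torch.cuda.set_device(dist.get_rank())

    # This allocation now lands on cuda:<rank> instead of cuda:0.
    weights = torch.empty(1024, 1024, device="cuda")
    print(f"rank {dist.get_rank()} allocated on {weights.device}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```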
lightx2v/models/runners/default_runner.py
```diff
@@ -43,9 +43,6 @@ class DefaultRunner(BaseRunner):
             self.run_input_encoder = self._run_input_encoder_local_t2v

     def set_init_device(self):
-        if self.config.parallel:
-            cur_rank = dist.get_rank()
-            torch.cuda.set_device(cur_rank)
         if self.config.cpu_offload:
             self.init_device = torch.device("cpu")
         else:
```
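With the device already pinned in `main()` before `init_runner(config)` runs, the per-runner `torch.cuda.set_device()` call removed above is redundant; `set_init_device()` is left with just the CPU-offload check.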
lightx2v/utils/profiler.py
```diff
@@ -3,6 +3,7 @@ import time
 from functools import wraps

 import torch
+import torch.distributed as dist
 from loguru import logger

 from lightx2v.utils.envs import *
@@ -12,9 +13,10 @@ class _ProfilingContext:
     def __init__(self, name):
         self.name = name
-        self.rank_info = ""
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            rank = torch.distributed.get_rank()
-            self.rank_info = f"Rank {rank} - "
+        if dist.is_initialized():
+            self.rank_info = f"Rank {dist.get_rank()}"
+        else:
+            self.rank_info = "Single GPU"

     def __enter__(self):
         torch.cuda.synchronize()
@@ -27,11 +29,11 @@ class _ProfilingContext:
         torch.cuda.synchronize()
         if torch.cuda.is_available():
             peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # convert to GB
-            logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
+            logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
         else:
-            logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
+            logger.info(f"[Profile] {self.rank_info} - {self.name} executed without GPU.")
         elapsed = time.perf_counter() - self.start_time
-        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
+        logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
         return False

     async def __aenter__(self):
@@ -45,11 +47,11 @@ class _ProfilingContext:
         torch.cuda.synchronize()
         if torch.cuda.is_available():
             peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # convert to GB
-            logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
+            logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
         else:
-            logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
+            logger.info(f"[Profile] {self.rank_info} - {self.name} executed without GPU.")
         elapsed = time.perf_counter() - self.start_time
-        logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
+        logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
         return False

     def __call__(self, func):
```
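For context, here is a self-contained sketch of the rank-tagged profiling pattern the patched `_ProfilingContext` implements, so the new log format can be tried outside the repo. The class name `SimpleProfiler` and the CPU-only guard around `synchronize()` are illustrative assumptions, not LightX2V's API; it depends on `loguru`, as the real module does.

```python
# Self-contained sketch of the rank-tagged profiling pattern shown above.
import time

import torch
import torch.distributed as dist
from loguru import logger


class SimpleProfiler:
    def __init__(self, name):
        self.name = name
        # Tag log lines with the rank when a process group is initialized,
        # otherwise fall back to a "Single GPU" label.
        if dist.is_available() and dist.is_initialized():
            self.rank_info = f"Rank {dist.get_rank()}"
        else:
            self.rank_info = "Single GPU"

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # bytes -> GB
            logger.info(f"[Profile] {self.rank_info} - {self.name} Peak Memory: {peak_memory:.2f} GB")
        elapsed = time.perf_counter() - self.start_time
        logger.info(f"[Profile] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds")
        return False


# Usage: each rank logs its own peak memory, so an imbalance across GPUs is
# easy to spot in the combined log.
with SimpleProfiler("Total Cost"):
    pass  # run the pipeline here
```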
lightx2v/utils/set_config.py
```diff
@@ -61,8 +61,6 @@ def set_config(args):
         logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
         config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1

-    set_parallel_config(config)  # parallel config
-
     return config
@@ -83,9 +81,3 @@ def set_parallel_config(config):
     if config.get("enable_cfg", False) and config.parallel and config.parallel.get("cfg_p_size", False) and config.parallel.cfg_p_size > 1:
         config["cfg_parallel"] = True
-
-
-def print_config(config):
-    config_to_print = config.copy()
-    config_to_print.pop("device_mesh", None)  # Remove device_mesh if it exists
-    logger.info(f"config:\n{json.dumps(config_to_print, ensure_ascii=False, indent=4)}")
```
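Two things are removed here: the unconditional `set_parallel_config(config)` call inside `set_config()`, which is now invoked from `main()` in `infer.py` only when the config has a `"parallel"` section, and `print_config()`, since `main()` now logs the config with `json.dumps` directly before `set_parallel_config()` runs. As a hedged, self-contained illustration of the `cfg_parallel` switch that remains (the config values are made up, and plain-dict access is used where the real config object also allows attribute access):

```python
# Hypothetical config; in the repo this comes from set_config(args).
config = {"enable_cfg": True, "parallel": {"cfg_p_size": 2}}

# Same condition as set_parallel_config above, simplified: cfg_parallel is
# switched on only when CFG is enabled and more than one rank is given to it.
if config.get("enable_cfg", False) and config.get("parallel") and config["parallel"].get("cfg_p_size", 0) > 1:
    config["cfg_parallel"] = True

print(config.get("cfg_parallel", False))  # True for these example values
```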