Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
3efc43f5
Unverified
Commit
3efc43f5
authored
Nov 20, 2025
by
PengGao
Committed by
GitHub
Nov 20, 2025
Browse files
fix: progress_callback (#483)
parent
fcc2a411
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
24 deletions
+27
-24
lightx2v/deploy/worker/hub.py
lightx2v/deploy/worker/hub.py
+1
-1
lightx2v/models/runners/base_runner.py
lightx2v/models/runners/base_runner.py
+1
-1
lightx2v/models/runners/default_runner.py
lightx2v/models/runners/default_runner.py
+12
-10
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+3
-3
lightx2v/models/runners/wan/wan_matrix_game2_runner.py
lightx2v/models/runners/wan/wan_matrix_game2_runner.py
+2
-2
lightx2v/models/runners/wan/wan_sf_runner.py
lightx2v/models/runners/wan/wan_sf_runner.py
+8
-7
No files found.
lightx2v/deploy/worker/hub.py
View file @
3efc43f5
...
...
@@ -332,7 +332,7 @@ class DiTWorker(BaseWorker):
def
run_dit
(
self
):
self
.
runner
.
init_run
()
assert
self
.
runner
.
video_segment_num
==
1
,
"DiTWorker only support single segment"
latents
=
self
.
runner
.
run_segment
(
total_steps
=
None
)
latents
=
self
.
runner
.
run_segment
()
self
.
runner
.
end_run
()
return
latents
...
...
lightx2v/models/runners/base_runner.py
View file @
3efc43f5
...
...
@@ -121,7 +121,7 @@ class BaseRunner(ABC):
def
init_run_segment
(
self
,
segment_idx
):
self
.
segment_idx
=
segment_idx
def
run_segment
(
self
,
total_steps
=
None
):
def
run_segment
(
self
,
segment_idx
=
0
):
pass
def
end_run_segment
(
self
,
segment_idx
=
None
):
...
...
lightx2v/models/runners/default_runner.py
View file @
3efc43f5
...
...
@@ -133,20 +133,20 @@ class DefaultRunner(BaseRunner):
self
.
progress_callback
=
callback
@
peak_memory_decorator
def
run_segment
(
self
,
total_steps
=
None
):
i
f
total_steps
is
None
:
total_steps
=
self
.
model
.
scheduler
.
infer_steps
for
step_index
in
range
(
total
_steps
):
def
run_segment
(
self
,
segment_idx
=
0
):
i
nfer_steps
=
self
.
model
.
scheduler
.
infer_steps
for
step_index
in
range
(
infer
_steps
):
# only for single segment, check stop signal every step
with
ProfilingContext4DebugL1
(
f
"Run Dit every step"
,
recorder_mode
=
GET_RECORDER_MODE
(),
metrics_func
=
monitor_cli
.
lightx2v_run_per_step_dit_duration
,
metrics_labels
=
[
step_index
+
1
,
total
_steps
],
metrics_labels
=
[
step_index
+
1
,
infer
_steps
],
):
if
self
.
video_segment_num
==
1
:
self
.
check_stop
()
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
total
_steps
}
"
)
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
infer
_steps
}
"
)
with
ProfilingContext4DebugL1
(
"step_pre"
):
self
.
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
...
...
@@ -158,13 +158,15 @@ class DefaultRunner(BaseRunner):
self
.
model
.
scheduler
.
step_post
()
if
self
.
progress_callback
:
self
.
progress_callback
(((
step_index
+
1
)
/
total_steps
)
*
100
,
100
)
current_step
=
segment_idx
*
infer_steps
+
step_index
+
1
total_all_steps
=
self
.
video_segment_num
*
infer_steps
self
.
progress_callback
((
current_step
/
total_all_steps
)
*
100
,
100
)
return
self
.
model
.
scheduler
.
latents
def
run_step
(
self
):
self
.
inputs
=
self
.
run_input_encoder
()
self
.
run_main
(
total_steps
=
1
)
self
.
run_main
()
def
end_run
(
self
):
self
.
model
.
scheduler
.
clear
()
...
...
@@ -272,7 +274,7 @@ class DefaultRunner(BaseRunner):
self
.
inputs
[
"image_encoder_output"
][
"vae_encoder_out"
]
=
None
@
ProfilingContext4DebugL2
(
"Run DiT"
)
def
run_main
(
self
,
total_steps
=
None
):
def
run_main
(
self
):
self
.
init_run
()
if
self
.
config
.
get
(
"compile"
,
False
):
self
.
model
.
select_graph_for_compile
(
self
.
input_info
)
...
...
@@ -288,7 +290,7 @@ class DefaultRunner(BaseRunner):
# 1. default do nothing
self
.
init_run_segment
(
segment_idx
)
# 2. main inference loop
latents
=
self
.
run_segment
(
total_steps
=
total_steps
)
latents
=
self
.
run_segment
(
segment_idx
)
# 3. vae decoder
self
.
gen_video
=
self
.
run_vae_decoder
(
latents
)
# 4. default do nothing
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
3efc43f5
...
...
@@ -753,14 +753,14 @@ class WanAudioRunner(WanRunner): # type:ignore
target_rank
=
1
,
)
def
run_main
(
self
,
total_steps
=
None
):
def
run_main
(
self
):
try
:
self
.
init_va_recorder
()
self
.
init_va_reader
()
logger
.
info
(
f
"init va_recorder:
{
self
.
va_recorder
}
and va_reader:
{
self
.
va_reader
}
"
)
if
self
.
va_reader
is
None
:
return
super
().
run_main
(
total_steps
)
return
super
().
run_main
()
self
.
va_reader
.
start
()
rank
,
world_size
=
self
.
get_rank_and_world_size
()
...
...
@@ -794,7 +794,7 @@ class WanAudioRunner(WanRunner): # type:ignore
with
ProfilingContext4DebugL1
(
f
"stream segment end2end
{
segment_idx
}
"
):
fail_count
=
0
self
.
init_run_segment
(
segment_idx
,
audio_array
)
latents
=
self
.
run_segment
(
total_steps
=
None
)
latents
=
self
.
run_segment
(
segment_idx
)
self
.
gen_video
=
self
.
run_vae_decoder
(
latents
)
self
.
end_run_segment
(
segment_idx
)
segment_idx
+=
1
...
...
lightx2v/models/runners/wan/wan_matrix_game2_runner.py
View file @
3efc43f5
...
...
@@ -241,7 +241,7 @@ class WanSFMtxg2Runner(WanSFRunner):
self
.
inputs
[
"current_actions"
]
=
get_current_action
(
mode
=
self
.
config
[
"mode"
])
@
ProfilingContext4DebugL2
(
"Run DiT"
)
def
run_main
(
self
,
total_steps
=
None
):
def
run_main
(
self
):
self
.
init_run
()
if
self
.
config
.
get
(
"compile"
,
False
):
self
.
model
.
select_graph_for_compile
(
self
.
input_info
)
...
...
@@ -260,7 +260,7 @@ class WanSFMtxg2Runner(WanSFRunner):
# 1. default do nothing
self
.
init_run_segment
(
segment_idx
)
# 2. main inference loop
latents
=
self
.
run_segment
(
total_steps
=
total_steps
)
latents
=
self
.
run_segment
(
segment_idx
=
segment_idx
)
# 3. vae decoder
self
.
gen_video
=
self
.
run_vae_decoder
(
latents
)
# 4. default do nothing
...
...
lightx2v/models/runners/wan/wan_sf_runner.py
View file @
3efc43f5
...
...
@@ -70,17 +70,16 @@ class WanSFRunner(WanRunner):
self
.
gen_video_final
=
torch
.
cat
([
self
.
gen_video_final
,
self
.
gen_video
],
dim
=
0
)
if
self
.
gen_video_final
is
not
None
else
self
.
gen_video
@
peak_memory_decorator
def
run_segment
(
self
,
total_steps
=
None
):
if
total_steps
is
None
:
total_steps
=
self
.
model
.
scheduler
.
infer_steps
for
step_index
in
range
(
total_steps
):
def
run_segment
(
self
,
segment_idx
=
0
):
infer_steps
=
self
.
model
.
scheduler
.
infer_steps
for
step_index
in
range
(
infer_steps
):
# only for single segment, check stop signal every step
if
self
.
video_segment_num
==
1
:
self
.
check_stop
()
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
total
_steps
}
"
)
logger
.
info
(
f
"==> step_index:
{
step_index
+
1
}
/
{
infer
_steps
}
"
)
with
ProfilingContext4DebugL1
(
"step_pre"
):
self
.
model
.
scheduler
.
step_pre
(
seg_index
=
self
.
segment_idx
,
step_index
=
step_index
,
is_rerun
=
False
)
self
.
model
.
scheduler
.
step_pre
(
seg_index
=
segment_idx
,
step_index
=
step_index
,
is_rerun
=
False
)
with
ProfilingContext4DebugL1
(
"🚀 infer_main"
):
self
.
model
.
infer
(
self
.
inputs
)
...
...
@@ -89,6 +88,8 @@ class WanSFRunner(WanRunner):
self
.
model
.
scheduler
.
step_post
()
if
self
.
progress_callback
:
self
.
progress_callback
(((
step_index
+
1
)
/
total_steps
)
*
100
,
100
)
current_step
=
segment_idx
*
infer_steps
+
step_index
+
1
total_all_steps
=
self
.
video_segment_num
*
infer_steps
self
.
progress_callback
((
current_step
/
total_all_steps
)
*
100
,
100
)
return
self
.
model
.
scheduler
.
stream_output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment