Unverified commit 3efc43f5, authored by PengGao and committed by GitHub
Browse files

fix: progress_callback (#483)

parent fcc2a411
......@@ -332,7 +332,7 @@ class DiTWorker(BaseWorker):
def run_dit(self):
self.runner.init_run()
assert self.runner.video_segment_num == 1, "DiTWorker only support single segment"
latents = self.runner.run_segment(total_steps=None)
latents = self.runner.run_segment()
self.runner.end_run()
return latents
......
......@@ -121,7 +121,7 @@ class BaseRunner(ABC):
def init_run_segment(self, segment_idx):
self.segment_idx = segment_idx
def run_segment(self, total_steps=None):
def run_segment(self, segment_idx=0):
pass
def end_run_segment(self, segment_idx=None):
......
......@@ -133,20 +133,20 @@ class DefaultRunner(BaseRunner):
self.progress_callback = callback
@peak_memory_decorator
def run_segment(self, total_steps=None):
if total_steps is None:
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps):
def run_segment(self, segment_idx=0):
infer_steps = self.model.scheduler.infer_steps
for step_index in range(infer_steps):
# only for single segment, check stop signal every step
with ProfilingContext4DebugL1(
f"Run Dit every step",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
metrics_labels=[step_index + 1, total_steps],
metrics_labels=[step_index + 1, infer_steps],
):
if self.video_segment_num == 1:
self.check_stop()
logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")
with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
......@@ -158,13 +158,15 @@ class DefaultRunner(BaseRunner):
self.model.scheduler.step_post()
if self.progress_callback:
self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
current_step = segment_idx * infer_steps + step_index + 1
total_all_steps = self.video_segment_num * infer_steps
self.progress_callback((current_step / total_all_steps) * 100, 100)
return self.model.scheduler.latents
def run_step(self):
self.inputs = self.run_input_encoder()
self.run_main(total_steps=1)
self.run_main()
def end_run(self):
self.model.scheduler.clear()
......@@ -272,7 +274,7 @@ class DefaultRunner(BaseRunner):
self.inputs["image_encoder_output"]["vae_encoder_out"] = None
@ProfilingContext4DebugL2("Run DiT")
def run_main(self, total_steps=None):
def run_main(self):
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile(self.input_info)
......@@ -288,7 +290,7 @@ class DefaultRunner(BaseRunner):
# 1. default do nothing
self.init_run_segment(segment_idx)
# 2. main inference loop
latents = self.run_segment(total_steps=total_steps)
latents = self.run_segment(segment_idx)
# 3. vae decoder
self.gen_video = self.run_vae_decoder(latents)
# 4. default do nothing
......
......@@ -753,14 +753,14 @@ class WanAudioRunner(WanRunner): # type:ignore
target_rank=1,
)
def run_main(self, total_steps=None):
def run_main(self):
try:
self.init_va_recorder()
self.init_va_reader()
logger.info(f"init va_recorder: {self.va_recorder} and va_reader: {self.va_reader}")
if self.va_reader is None:
return super().run_main(total_steps)
return super().run_main()
self.va_reader.start()
rank, world_size = self.get_rank_and_world_size()
......@@ -794,7 +794,7 @@ class WanAudioRunner(WanRunner): # type:ignore
with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"):
fail_count = 0
self.init_run_segment(segment_idx, audio_array)
latents = self.run_segment(total_steps=None)
latents = self.run_segment(segment_idx)
self.gen_video = self.run_vae_decoder(latents)
self.end_run_segment(segment_idx)
segment_idx += 1
......
......@@ -241,7 +241,7 @@ class WanSFMtxg2Runner(WanSFRunner):
self.inputs["current_actions"] = get_current_action(mode=self.config["mode"])
@ProfilingContext4DebugL2("Run DiT")
def run_main(self, total_steps=None):
def run_main(self):
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile(self.input_info)
......@@ -260,7 +260,7 @@ class WanSFMtxg2Runner(WanSFRunner):
# 1. default do nothing
self.init_run_segment(segment_idx)
# 2. main inference loop
latents = self.run_segment(total_steps=total_steps)
latents = self.run_segment(segment_idx=segment_idx)
# 3. vae decoder
self.gen_video = self.run_vae_decoder(latents)
# 4. default do nothing
......
......@@ -70,17 +70,16 @@ class WanSFRunner(WanRunner):
self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
@peak_memory_decorator
def run_segment(self, total_steps=None):
if total_steps is None:
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps):
def run_segment(self, segment_idx=0):
infer_steps = self.model.scheduler.infer_steps
for step_index in range(infer_steps):
# only for single segment, check stop signal every step
if self.video_segment_num == 1:
self.check_stop()
logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")
with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(seg_index=self.segment_idx, step_index=step_index, is_rerun=False)
self.model.scheduler.step_pre(seg_index=segment_idx, step_index=step_index, is_rerun=False)
with ProfilingContext4DebugL1("🚀 infer_main"):
self.model.infer(self.inputs)
......@@ -89,6 +88,8 @@ class WanSFRunner(WanRunner):
self.model.scheduler.step_post()
if self.progress_callback:
self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
current_step = segment_idx * infer_steps + step_index + 1
total_all_steps = self.video_segment_num * infer_steps
self.progress_callback((current_step / total_all_steps) * 100, 100)
return self.model.scheduler.stream_output
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment