Unverified Commit 3efc43f5 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

fix: progress_callback (#483)

parent fcc2a411
...@@ -332,7 +332,7 @@ class DiTWorker(BaseWorker): ...@@ -332,7 +332,7 @@ class DiTWorker(BaseWorker):
def run_dit(self):
    """Run the DiT denoising pass for a single video segment.

    Initializes the runner, executes the one segment, tears down, and
    returns the resulting latents.

    Returns:
        The latent tensor produced by ``runner.run_segment()``.

    Raises:
        AssertionError: if the runner is configured for more than one
            video segment (this worker only handles the single-segment case).
    """
    self.runner.init_run()
    # DiTWorker handles exactly one segment; multi-segment runs go through
    # the full runner pipeline instead.
    assert self.runner.video_segment_num == 1, "DiTWorker only support single segment"
    # run_segment() defaults to segment_idx=0, which is the only segment here.
    latents = self.runner.run_segment()
    self.runner.end_run()
    return latents
......
...@@ -121,7 +121,7 @@ class BaseRunner(ABC): ...@@ -121,7 +121,7 @@ class BaseRunner(ABC):
def init_run_segment(self, segment_idx):
    """Prepare per-segment state before running segment ``segment_idx``.

    Base implementation only records the current segment index; subclasses
    may override to do additional per-segment setup.

    Args:
        segment_idx: zero-based index of the video segment about to run.
    """
    self.segment_idx = segment_idx
def run_segment(self, segment_idx=0):
    """Run the main inference loop for one video segment.

    Base-class stub: concrete runners override this to perform the
    per-step denoising loop and return the segment's latents.

    Args:
        segment_idx: zero-based index of the segment being run; used by
            overrides to report overall progress across all segments.
    """
    pass
def end_run_segment(self, segment_idx=None): def end_run_segment(self, segment_idx=None):
......
...@@ -133,20 +133,20 @@ class DefaultRunner(BaseRunner): ...@@ -133,20 +133,20 @@ class DefaultRunner(BaseRunner):
self.progress_callback = callback self.progress_callback = callback
@peak_memory_decorator @peak_memory_decorator
def run_segment(self, total_steps=None): def run_segment(self, segment_idx=0):
if total_steps is None: infer_steps = self.model.scheduler.infer_steps
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps): for step_index in range(infer_steps):
# only for single segment, check stop signal every step # only for single segment, check stop signal every step
with ProfilingContext4DebugL1( with ProfilingContext4DebugL1(
f"Run Dit every step", f"Run Dit every step",
recorder_mode=GET_RECORDER_MODE(), recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration, metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
metrics_labels=[step_index + 1, total_steps], metrics_labels=[step_index + 1, infer_steps],
): ):
if self.video_segment_num == 1: if self.video_segment_num == 1:
self.check_stop() self.check_stop()
logger.info(f"==> step_index: {step_index + 1} / {total_steps}") logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")
with ProfilingContext4DebugL1("step_pre"): with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(step_index=step_index) self.model.scheduler.step_pre(step_index=step_index)
...@@ -158,13 +158,15 @@ class DefaultRunner(BaseRunner): ...@@ -158,13 +158,15 @@ class DefaultRunner(BaseRunner):
self.model.scheduler.step_post() self.model.scheduler.step_post()
if self.progress_callback: if self.progress_callback:
self.progress_callback(((step_index + 1) / total_steps) * 100, 100) current_step = segment_idx * infer_steps + step_index + 1
total_all_steps = self.video_segment_num * infer_steps
self.progress_callback((current_step / total_all_steps) * 100, 100)
return self.model.scheduler.latents return self.model.scheduler.latents
def run_step(self):
    """Run a single end-to-end step: encode inputs, then the main DiT loop.

    Stores the encoder output on ``self.inputs`` before invoking
    ``run_main()`` (which no longer takes a step-count argument).
    """
    self.inputs = self.run_input_encoder()
    self.run_main()
def end_run(self): def end_run(self):
self.model.scheduler.clear() self.model.scheduler.clear()
...@@ -272,7 +274,7 @@ class DefaultRunner(BaseRunner): ...@@ -272,7 +274,7 @@ class DefaultRunner(BaseRunner):
self.inputs["image_encoder_output"]["vae_encoder_out"] = None self.inputs["image_encoder_output"]["vae_encoder_out"] = None
@ProfilingContext4DebugL2("Run DiT") @ProfilingContext4DebugL2("Run DiT")
def run_main(self, total_steps=None): def run_main(self):
self.init_run() self.init_run()
if self.config.get("compile", False): if self.config.get("compile", False):
self.model.select_graph_for_compile(self.input_info) self.model.select_graph_for_compile(self.input_info)
...@@ -288,7 +290,7 @@ class DefaultRunner(BaseRunner): ...@@ -288,7 +290,7 @@ class DefaultRunner(BaseRunner):
# 1. default do nothing # 1. default do nothing
self.init_run_segment(segment_idx) self.init_run_segment(segment_idx)
# 2. main inference loop # 2. main inference loop
latents = self.run_segment(total_steps=total_steps) latents = self.run_segment(segment_idx)
# 3. vae decoder # 3. vae decoder
self.gen_video = self.run_vae_decoder(latents) self.gen_video = self.run_vae_decoder(latents)
# 4. default do nothing # 4. default do nothing
......
...@@ -753,14 +753,14 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -753,14 +753,14 @@ class WanAudioRunner(WanRunner): # type:ignore
target_rank=1, target_rank=1,
) )
def run_main(self, total_steps=None): def run_main(self):
try: try:
self.init_va_recorder() self.init_va_recorder()
self.init_va_reader() self.init_va_reader()
logger.info(f"init va_recorder: {self.va_recorder} and va_reader: {self.va_reader}") logger.info(f"init va_recorder: {self.va_recorder} and va_reader: {self.va_reader}")
if self.va_reader is None: if self.va_reader is None:
return super().run_main(total_steps) return super().run_main()
self.va_reader.start() self.va_reader.start()
rank, world_size = self.get_rank_and_world_size() rank, world_size = self.get_rank_and_world_size()
...@@ -794,7 +794,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -794,7 +794,7 @@ class WanAudioRunner(WanRunner): # type:ignore
with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"): with ProfilingContext4DebugL1(f"stream segment end2end {segment_idx}"):
fail_count = 0 fail_count = 0
self.init_run_segment(segment_idx, audio_array) self.init_run_segment(segment_idx, audio_array)
latents = self.run_segment(total_steps=None) latents = self.run_segment(segment_idx)
self.gen_video = self.run_vae_decoder(latents) self.gen_video = self.run_vae_decoder(latents)
self.end_run_segment(segment_idx) self.end_run_segment(segment_idx)
segment_idx += 1 segment_idx += 1
......
...@@ -241,7 +241,7 @@ class WanSFMtxg2Runner(WanSFRunner): ...@@ -241,7 +241,7 @@ class WanSFMtxg2Runner(WanSFRunner):
self.inputs["current_actions"] = get_current_action(mode=self.config["mode"]) self.inputs["current_actions"] = get_current_action(mode=self.config["mode"])
@ProfilingContext4DebugL2("Run DiT") @ProfilingContext4DebugL2("Run DiT")
def run_main(self, total_steps=None): def run_main(self):
self.init_run() self.init_run()
if self.config.get("compile", False): if self.config.get("compile", False):
self.model.select_graph_for_compile(self.input_info) self.model.select_graph_for_compile(self.input_info)
...@@ -260,7 +260,7 @@ class WanSFMtxg2Runner(WanSFRunner): ...@@ -260,7 +260,7 @@ class WanSFMtxg2Runner(WanSFRunner):
# 1. default do nothing # 1. default do nothing
self.init_run_segment(segment_idx) self.init_run_segment(segment_idx)
# 2. main inference loop # 2. main inference loop
latents = self.run_segment(total_steps=total_steps) latents = self.run_segment(segment_idx=segment_idx)
# 3. vae decoder # 3. vae decoder
self.gen_video = self.run_vae_decoder(latents) self.gen_video = self.run_vae_decoder(latents)
# 4. default do nothing # 4. default do nothing
......
...@@ -70,17 +70,16 @@ class WanSFRunner(WanRunner): ...@@ -70,17 +70,16 @@ class WanSFRunner(WanRunner):
self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0) if self.gen_video_final is not None else self.gen_video
@peak_memory_decorator @peak_memory_decorator
def run_segment(self, total_steps=None): def run_segment(self, segment_idx=0):
if total_steps is None: infer_steps = self.model.scheduler.infer_steps
total_steps = self.model.scheduler.infer_steps for step_index in range(infer_steps):
for step_index in range(total_steps):
# only for single segment, check stop signal every step # only for single segment, check stop signal every step
if self.video_segment_num == 1: if self.video_segment_num == 1:
self.check_stop() self.check_stop()
logger.info(f"==> step_index: {step_index + 1} / {total_steps}") logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")
with ProfilingContext4DebugL1("step_pre"): with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(seg_index=self.segment_idx, step_index=step_index, is_rerun=False) self.model.scheduler.step_pre(seg_index=segment_idx, step_index=step_index, is_rerun=False)
with ProfilingContext4DebugL1("🚀 infer_main"): with ProfilingContext4DebugL1("🚀 infer_main"):
self.model.infer(self.inputs) self.model.infer(self.inputs)
...@@ -89,6 +88,8 @@ class WanSFRunner(WanRunner): ...@@ -89,6 +88,8 @@ class WanSFRunner(WanRunner):
self.model.scheduler.step_post() self.model.scheduler.step_post()
if self.progress_callback: if self.progress_callback:
self.progress_callback(((step_index + 1) / total_steps) * 100, 100) current_step = segment_idx * infer_steps + step_index + 1
total_all_steps = self.video_segment_num * infer_steps
self.progress_callback((current_step / total_all_steps) * 100, 100)
return self.model.scheduler.stream_output return self.model.scheduler.stream_output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment