Commit 8aee6ed9 authored by helloyongyang's avatar helloyongyang
Browse files

update code

parent ab1b2790
......@@ -345,9 +345,9 @@ class DiTWorker(BaseWorker):
def run_dit(self):
self.runner.init_run()
assert self.runner.video_segment_num == 1, "DiTWorker only support single segment"
latents, generator = self.runner.run_segment(total_steps=None)
latents = self.runner.run_segment(total_steps=None)
self.runner.end_run()
return latents, generator
return latents
class VaeDecoderWorker(BaseWorker):
......
......@@ -172,6 +172,5 @@ class BaseRunner(ABC):
dist.broadcast(t, src=signal_rank)
stopped = t.item()
print(f"rank {rank} recv stopped: {stopped}")
if stopped == 1:
raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior")
......@@ -128,7 +128,7 @@ class DefaultRunner(BaseRunner):
if self.progress_callback:
self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
return self.model.scheduler.latents, self.model.scheduler.generator
return self.model.scheduler.latents
def run_step(self):
self.inputs = self.run_input_encoder()
......@@ -224,12 +224,12 @@ class DefaultRunner(BaseRunner):
self.init_run()
for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext(f"segment end2end {segment_idx}"):
with ProfilingContext(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
self.check_stop()
# 1. default do nothing
self.init_run_segment(segment_idx)
# 2. main inference loop
latents, generator = self.run_segment(total_steps=total_steps)
latents = self.run_segment(total_steps=total_steps)
# 3. vae decoder
self.gen_video = self.run_vae_decoder(latents)
# 4. default do nothing
......
......@@ -595,7 +595,7 @@ class WanAudioRunner(WanRunner): # type:ignore
with ProfilingContext4Debug(f"stream segment end2end {segment_idx}"):
fail_count = 0
self.init_run_segment(segment_idx, audio_array)
latents, generator = self.run_segment(total_steps=None)
latents = self.run_segment(total_steps=None)
self.gen_video = self.run_vae_decoder(latents)
self.end_run_segment()
segment_idx += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment