"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "34f142797aff17af2a1c22d68529269d231cc8d4"
Commit 8aee6ed9 authored by helloyongyang's avatar helloyongyang
Browse files

update code

parent ab1b2790
...@@ -345,9 +345,9 @@ class DiTWorker(BaseWorker): ...@@ -345,9 +345,9 @@ class DiTWorker(BaseWorker):
def run_dit(self): def run_dit(self):
self.runner.init_run() self.runner.init_run()
assert self.runner.video_segment_num == 1, "DiTWorker only support single segment" assert self.runner.video_segment_num == 1, "DiTWorker only support single segment"
latents, generator = self.runner.run_segment(total_steps=None) latents = self.runner.run_segment(total_steps=None)
self.runner.end_run() self.runner.end_run()
return latents, generator return latents
class VaeDecoderWorker(BaseWorker): class VaeDecoderWorker(BaseWorker):
......
...@@ -172,6 +172,5 @@ class BaseRunner(ABC): ...@@ -172,6 +172,5 @@ class BaseRunner(ABC):
dist.broadcast(t, src=signal_rank) dist.broadcast(t, src=signal_rank)
stopped = t.item() stopped = t.item()
print(f"rank {rank} recv stopped: {stopped}")
if stopped == 1: if stopped == 1:
raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior") raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior")
...@@ -128,7 +128,7 @@ class DefaultRunner(BaseRunner): ...@@ -128,7 +128,7 @@ class DefaultRunner(BaseRunner):
if self.progress_callback: if self.progress_callback:
self.progress_callback(((step_index + 1) / total_steps) * 100, 100) self.progress_callback(((step_index + 1) / total_steps) * 100, 100)
return self.model.scheduler.latents, self.model.scheduler.generator return self.model.scheduler.latents
def run_step(self): def run_step(self):
self.inputs = self.run_input_encoder() self.inputs = self.run_input_encoder()
...@@ -224,12 +224,12 @@ class DefaultRunner(BaseRunner): ...@@ -224,12 +224,12 @@ class DefaultRunner(BaseRunner):
self.init_run() self.init_run()
for segment_idx in range(self.video_segment_num): for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}") logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext(f"segment end2end {segment_idx}"): with ProfilingContext(f"segment end2end {segment_idx + 1}/{self.video_segment_num}"):
self.check_stop() self.check_stop()
# 1. default do nothing # 1. default do nothing
self.init_run_segment(segment_idx) self.init_run_segment(segment_idx)
# 2. main inference loop # 2. main inference loop
latents, generator = self.run_segment(total_steps=total_steps) latents = self.run_segment(total_steps=total_steps)
# 3. vae decoder # 3. vae decoder
self.gen_video = self.run_vae_decoder(latents) self.gen_video = self.run_vae_decoder(latents)
# 4. default do nothing # 4. default do nothing
......
...@@ -595,7 +595,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -595,7 +595,7 @@ class WanAudioRunner(WanRunner): # type:ignore
with ProfilingContext4Debug(f"stream segment end2end {segment_idx}"): with ProfilingContext4Debug(f"stream segment end2end {segment_idx}"):
fail_count = 0 fail_count = 0
self.init_run_segment(segment_idx, audio_array) self.init_run_segment(segment_idx, audio_array)
latents, generator = self.run_segment(total_steps=None) latents = self.run_segment(total_steps=None)
self.gen_video = self.run_vae_decoder(latents) self.gen_video = self.run_vae_decoder(latents)
self.end_run_segment() self.end_run_segment()
segment_idx += 1 segment_idx += 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment