Unverified commit 0ad8ada3 authored by LiangLiu, committed by GitHub

Fix reader import error (#585)

parent 9a765f9b
```diff
@@ -5,10 +5,6 @@
 import torch
 import torch.distributed as dist
 from loguru import logger
-from lightx2v.deploy.common.va_reader import VAReader
-from lightx2v.deploy.common.va_reader_omni import OmniVAReader
-from lightx2v.deploy.common.va_recorder import VARecorder
-from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
 from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
 from lightx2v_platform.base.global_var import AI_DEVICE
@@ -65,7 +61,7 @@ class VAController:
         )
         # how many frames to publish stream as a batch
-        self.slice_frame = config.get("slice_frame", 1)
+        self.slice_frame = config.get("slice_frame", self.prev_frame_length)
         # estimate the max infer seconds, for immediate switch with local omni
         slice_interval = self.slice_frame / self.record_fps
         est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
@@ -78,6 +74,8 @@ class VAController:
         logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
         whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
         if whip_shared_path and self.output_video_path.startswith("http"):
+            from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
+
             self.recorder = X264VARecorder(
                 whip_shared_path=whip_shared_path,
                 livestream_url=self.output_video_path,
@@ -87,6 +85,8 @@ class VAController:
                 prev_frame=self.prev_frame_length,
             )
         else:
+            from lightx2v.deploy.common.va_recorder import VARecorder
+
             self.recorder = VARecorder(
                 livestream_url=self.output_video_path,
                 fps=self.record_fps,
@@ -103,6 +103,8 @@ class VAController:
         prev_duration = self.prev_frame_length / self.target_fps
         omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
         if omni_work_dir:
+            from lightx2v.deploy.common.va_reader_omni import OmniVAReader
+
             self.reader = OmniVAReader(
                 rank=self.rank,
                 world_size=self.world_size,
@@ -115,6 +117,8 @@ class VAController:
                 huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
             )
         else:
+            from lightx2v.deploy.common.va_reader import VAReader
+
             self.reader = VAReader(
                 rank=self.rank,
                 world_size=self.world_size,
```
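The change above moves the reader and recorder imports from module level into the branches that construct them, so importing the controller module no longer fails when only one backend's dependencies are installed; any ImportError now surfaces only when the missing backend is actually selected, which appears to be what the commit title ("Fix reader import error") is after. A minimal, self-contained sketch of this deferred-import pattern follows; `bz2` and `gzip` merely stand in for the optional and default recorder backends, and the function name and environment variable are illustrative rather than taken from the repository.

```python
import os


def make_recorder(output_path: str):
    """Sketch of the deferred-import pattern applied in the hunks above.

    Each backend is imported inside the branch that selects it, so this module
    always imports cleanly; bz2/gzip stand in for the X264 and default
    recorders, and every name here is illustrative, not the repository's.
    """
    if os.getenv("FAST_RECORDER"):
        import bz2  # optional backend, imported only when this branch runs

        return bz2.open(output_path + ".bz2", "wb")

    import gzip  # default backend

    return gzip.open(output_path + ".gz", "wb")
```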
```diff
@@ -721,6 +721,7 @@ class WanAudioRunner(WanRunner): # type:ignore
     def run_main(self):
         try:
+            self.va_controller = None
             self.va_controller = VAController(self)
             logger.info(f"init va_recorder: {self.va_controller.recorder} and va_reader: {self.va_controller.reader}")
@@ -776,7 +777,9 @@ class WanAudioRunner(WanRunner): # type:ignore
         finally:
             if hasattr(self.model, "inputs"):
                 self.end_run()
-            self.va_controller.clear()
+            if self.va_controller is not None:
+                self.va_controller.clear()
+                self.va_controller = None
 
     @ProfilingContext4DebugL1("Process after vae decoder")
     def process_images_after_vae_decoder(self):
```
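The runner change guards the cleanup path: `va_controller` is set to `None` before construction and `clear()` is only called on a controller that was actually created, so an exception raised while building `VAController` no longer causes a second error inside `finally`. A small, self-contained sketch of the same pattern; everything except the `run_main`/`clear` shape is illustrative and not taken from the repository.

```python
class Runner:
    """Illustrative runner mirroring the guarded-cleanup hunk above."""

    def run_main(self):
        try:
            self.va_controller = None  # ensure the attribute exists for `finally`
            self.va_controller = self._build_controller()
            self._generate()
        finally:
            # Clear only a controller that was actually constructed, then drop
            # the reference so a later failure cannot double-clear it.
            if self.va_controller is not None:
                self.va_controller.clear()
                self.va_controller = None

    def _build_controller(self):
        raise RuntimeError("simulated failure during controller init")

    def _generate(self):
        pass


runner = Runner()
try:
    runner.run_main()
except RuntimeError as exc:
    print(f"init failed ({exc}), but cleanup in `finally` did not raise")
```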
```diff
@@ -35,3 +35,6 @@ alibabacloud_dypnsapi20170525==1.2.2
 redis==6.4.0
 tos
 decord
+zmq
+jsonschema
+pymongo
```