Unverified commit 0ad8ada3, authored by LiangLiu and committed by GitHub
Browse files

Fix reader import error (#585)

parent 9a765f9b
......@@ -5,10 +5,6 @@ import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE
......@@ -65,7 +61,7 @@ class VAController:
)
# how many frames to publish stream as a batch
self.slice_frame = config.get("slice_frame", 1)
self.slice_frame = config.get("slice_frame", self.prev_frame_length)
# estimate the max infer seconds, for immediate switch with local omni
slice_interval = self.slice_frame / self.record_fps
est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
......@@ -78,6 +74,8 @@ class VAController:
logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and self.output_video_path.startswith("http"):
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
self.recorder = X264VARecorder(
whip_shared_path=whip_shared_path,
livestream_url=self.output_video_path,
......@@ -87,6 +85,8 @@ class VAController:
prev_frame=self.prev_frame_length,
)
else:
from lightx2v.deploy.common.va_recorder import VARecorder
self.recorder = VARecorder(
livestream_url=self.output_video_path,
fps=self.record_fps,
......@@ -103,6 +103,8 @@ class VAController:
prev_duration = self.prev_frame_length / self.target_fps
omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
if omni_work_dir:
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
self.reader = OmniVAReader(
rank=self.rank,
world_size=self.world_size,
......@@ -115,6 +117,8 @@ class VAController:
huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
)
else:
from lightx2v.deploy.common.va_reader import VAReader
self.reader = VAReader(
rank=self.rank,
world_size=self.world_size,
......
......@@ -721,6 +721,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def run_main(self):
try:
self.va_controller = None
self.va_controller = VAController(self)
logger.info(f"init va_recorder: {self.va_controller.recorder} and va_reader: {self.va_controller.reader}")
......@@ -776,7 +777,9 @@ class WanAudioRunner(WanRunner): # type:ignore
finally:
if hasattr(self.model, "inputs"):
self.end_run()
self.va_controller.clear()
if self.va_controller is not None:
self.va_controller.clear()
self.va_controller = None
@ProfilingContext4DebugL1("Process after vae decoder")
def process_images_after_vae_decoder(self):
......
......@@ -35,3 +35,6 @@ alibabacloud_dypnsapi20170525==1.2.2
redis==6.4.0
tos
decord
zmq
jsonschema
pymongo
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment