# va_controller.py

import math
import os
from typing import Any

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE


class NextControl:
    def __init__(self, action: str, data: Any = None):
        """Control message for the runner loop.

        action "switch": restart generation; data is the previous-video tensor.
        action "wait":   output buffer is full enough, skip inference; data is None.
        action "fetch":  fetch the next segment and run inference; data is None.
        """
        self.action = action
        self.data = data


class VAController:
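    """Coordinate a streaming audio reader and an audio/video recorder across ranks.

    The reader is constructed on every rank (reading happens on READER_RANK),
    while the recorder is constructed only on RECORDER_RANK; buffer state and
    switch signals are broadcast from the recorder rank so all ranks stay in
    lockstep. next_control() tells the runner whether to fetch the next
    segment, wait because the output buffer is already full enough, or switch
    to generation from a previous-video tail.
    """
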
    def __init__(self, model_runner):
        self.reader = None
        self.recorder = None
        self.rank = 0
        self.world_size = 1
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
        self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
        self.init_base(
            model_runner.config,
            model_runner.input_info,
            model_runner.vfi_model is not None,
            model_runner.vsr_model is not None,
        )
        self.init_recorder()
        self.init_reader(model_runner)

    def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
        self.audio_path = input_info.audio_path
        self.output_video_path = input_info.save_result_path
        if isinstance(self.output_video_path, dict):
            self.output_video_path = self.output_video_path["data"]

        self.audio_sr = config.get("audio_sr", 16000)
        self.target_fps = config.get("target_fps", 16)
        self.max_num_frames = config.get("target_video_length", 81)
        self.prev_frame_length = config.get("prev_frame_length", 5)

        self.record_fps = config.get("target_fps", 16)
        if "video_frame_interpolation" in config and has_vfi_model:
            self.record_fps = config["video_frame_interpolation"]["target_fps"]
        self.record_fps = config.get("record_fps", self.record_fps)

        self.tgt_h = input_info.target_shape[0]
        self.tgt_w = input_info.target_shape[1]
        self.record_h, self.record_w = self.tgt_h, self.tgt_w
        if "video_super_resolution" in config and has_vsr_model:
            _, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
                self.record_w,
                self.record_h,
                scale=config["video_super_resolution"]["scale"],
                multiple=128,
            )

        # number of frames to publish to the stream as one batch
        self.slice_frame = config.get("slice_frame", 1)
        # estimate the maximum inference time in seconds, used on an immediate
        # switch with local omni to decide how much buffered stream to keep
        slice_interval = self.slice_frame / self.record_fps
        est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
        self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
        self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1
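        # Worked example with the defaults above: slice_frame=1 at
        # record_fps=16 gives slice_interval = 1/16 = 0.0625 s; with
        # est_max_infer_secs=0.6, est_infer_end_idx = ceil(0.6 / 0.0625) = 10,
        # so min_stay_queue_num = 2 * 10 + 1 = 21. Once at least 21 slices are
        # buffered, omni_reader_next_control() returns "wait" instead of "fetch".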

    def init_recorder(self):
        if not self.output_video_path or self.rank != self.target_recorder_rank:
            return
        logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
        whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
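        # prefer the native x264/WHIP recorder when a WHIP shared library is
        # provided and the sink is an HTTP(S) URL; otherwise fall back to the
        # default VARecorder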
        if whip_shared_path and self.output_video_path.startswith("http"):
            self.recorder = X264VARecorder(
                whip_shared_path=whip_shared_path,
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )
        else:
            self.recorder = VARecorder(
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )

    def init_reader(self, model_runner=None):
        if not isinstance(self.audio_path, dict):
            return
        assert self.audio_path["type"] == "stream", f"unexpected audio_path: {self.audio_path}"
        segment_duration = self.max_num_frames / self.target_fps
        prev_duration = self.prev_frame_length / self.target_fps
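        # OMNI_WORK_DIR selects the Omni reader, which also carries the model
        # runner and an optional Huoshan TTS voice type, and provides the
        # immediate-switch signal consumed by omni_reader_next_control()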
        omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
        if omni_work_dir:
            self.reader = OmniVAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
                model_runner=model_runner,
                huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
            )
        else:
            self.reader = VAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
            )

    def start(self):
        # the reader only exists for stream audio input (see init_reader)
        assert self.reader is not None, f"reader is required to start streaming for rank {self.rank}"
        self.reader.start()
        if self.rank == self.target_recorder_rank:
            assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}"
            self.recorder.start(self.record_w, self.record_h)
        if self.world_size > 1:
            dist.barrier()

    def next_control(self):
        if isinstance(self.reader, OmniVAReader):
            return self.omni_reader_next_control()
        return NextControl(action="fetch")
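
    # A hypothetical runner loop consuming these controls (sketch only;
    # read_segment() and infer() are placeholder names, not real APIs):
    #
    #   va = VAController(model_runner)
    #   va.start()
    #   va.before_control()
    #   while streaming:
    #       ctrl = va.next_control()
    #       if ctrl.action == "wait":
    #           continue  # output buffer is full enough, skip this step
    #       if ctrl.action == "switch":
    #           prev_video = ctrl.data  # restart from the broadcast tail
    #       images, audios, gen_video = infer(read_segment(va.reader))
    #       va.pub_livestream(images, audios, gen_video)
    #   va.clear()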

    def before_control(self):
        if isinstance(self.reader, OmniVAReader):
            # pre-allocate the tensors exchanged by the broadcast protocol in
            # omni_reader_next_control: stream-buffer length, switch flag, and
            # the previous-video tail used to restart generation after a switch
            self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
            self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
            self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)

    def omni_reader_next_control(self):
        immediate_switch = self.reader.get_immediate_switch()
        if immediate_switch == 1:
            # truncate the stream buffer so at most the estimated max inference
            # length remains, then broadcast the previous-video tensor to all ranks
            if self.rank == self.target_recorder_rank:
                logger.warning("runner received immediate switch, truncating stream buffer")
                video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
                if video_tensor is not None:
                    self.flag_tensor.fill_(1)
                    self.prev_tensor.copy_(video_tensor)
                else:
                    self.flag_tensor.fill_(0)
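            # every rank must join the same broadcasts in the same order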
            dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
            if self.flag_tensor.item() == 1:
                dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
                return NextControl(action="switch", data=self.prev_tensor)
        else:
            # read the stream buffer length on the recorder rank and broadcast it to all ranks
            if self.rank == self.target_recorder_rank:
                stream_buffer_length = self.recorder.get_buffer_stream_size()
                self.len_tensor.fill_(stream_buffer_length)  # fill_ accepts a Python int, unlike copy_
            dist.broadcast(self.len_tensor, src=self.target_recorder_rank)
            buffer_length = self.len_tensor.item()
            # the stream buffer already holds enough slices, skip inference
            if buffer_length >= self.min_stay_queue_num:
                return NextControl(action="wait")
        return NextControl(action="fetch")

    def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
        # in realtime mode, slices are buffered in the recorder so they can be
        # truncated later on an immediate switch; otherwise publish directly
        if self.recorder.realtime:
            self.recorder.buffer_stream(images, audios, gen_video)
        else:
            self.recorder.pub_livestream(images, audios)

    def clear(self):
        self.len_tensor = None
        self.flag_tensor = None
        self.prev_tensor = None
        if self.reader is not None:
            self.reader.stop()
            self.reader = None
        if self.recorder is not None:
            self.recorder.stop()
            self.recorder = None

    def __del__(self):
        self.clear()