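"""Controller coupling a streaming audio reader with a livestream video recorder.

VAController binds a model runner to a live audio source (VAReader or
OmniVAReader) and a livestream recorder (VARecorder or X264VARecorder),
coordinates them across distributed ranks, and decides per generation cycle
whether to fetch new audio, wait, or switch to a truncated video prefix.
"""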
import math
import os
from typing import Any

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE


class NextControl:
    def __init__(self, action: str, data: Any = None):
        # action: "switch" -> data is the previous video tensor
        # action: "wait"   -> data is None
        # action: "fetch"  -> data is None
        self.action = action
        self.data = data


class VAController:
    def __init__(self, model_runner):
        self.reader = None
        self.recorder = None
        self.rank = 0
        self.world_size = 1
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
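        # READER_RANK / RECORDER_RANK pin the stream reader and recorder to one
        # rank each; the modulo keeps the values valid for any world size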
        self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
        self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
        self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
        self.init_recorder()
        self.init_reader(model_runner)

    def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
        self.audio_path = input_info.audio_path
        self.output_video_path = input_info.save_result_path
        if isinstance(self.output_video_path, dict):
            self.output_video_path = self.output_video_path["data"]

        self.audio_sr = config.get("audio_sr", 16000)
        self.target_fps = config.get("target_fps", 16)
        self.max_num_frames = config.get("target_video_length", 81)
        self.prev_frame_length = config.get("prev_frame_length", 5)

        self.record_fps = config.get("target_fps", 16)
        if "video_frame_interpolation" in config and has_vfi_model:
            self.record_fps = config["video_frame_interpolation"]["target_fps"]
        self.record_fps = config.get("record_fps", self.record_fps)

        self.tgt_h = input_info.target_shape[0]
        self.tgt_w = input_info.target_shape[1]
        self.record_h, self.record_w = self.tgt_h, self.tgt_w
        if "video_super_resolution" in config and has_vsr_model:
            _, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
                self.record_w,
                self.record_h,
                scale=config["video_super_resolution"]["scale"],
                multiple=128,
            )

        # number of frames to publish to the stream as one batch
        self.slice_frame = config.get("slice_frame", self.prev_frame_length)
        # estimate how many buffered slices one inference pass may span, used for
        # immediate switch with the local omni reader
        slice_interval = self.slice_frame / self.record_fps
        est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
        self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
        self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1
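        # only skip inference once at least two estimated inference windows (plus
        # one slice) are queued, so playback does not starve mid-generation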

    def init_recorder(self):
        if not self.output_video_path or self.rank != self.target_recorder_rank:
            return
        logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
        whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
        if whip_shared_path and self.output_video_path.startswith("http"):
            from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder

            self.recorder = X264VARecorder(
                whip_shared_path=whip_shared_path,
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )
        else:
            from lightx2v.deploy.common.va_recorder import VARecorder

            self.recorder = VARecorder(
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )

    def init_reader(self, model_runner=None):
        # track whether the omni reader is used; the reader classes are imported
        # lazily below, so their names are not visible to other methods
        self.is_omni_reader = False
        if not isinstance(self.audio_path, dict):
            return
        assert self.audio_path["type"] == "stream", f"unexpected audio_path: {self.audio_path}"
        segment_duration = self.max_num_frames / self.target_fps
        prev_duration = self.prev_frame_length / self.target_fps
        omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
        if omni_work_dir:
            self.is_omni_reader = True
            from lightx2v.deploy.common.va_reader_omni import OmniVAReader

            self.reader = OmniVAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
                model_runner=model_runner,
                huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
            )
        else:
            from lightx2v.deploy.common.va_reader import VAReader

            self.reader = VAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
            )

    def start(self):
        self.reader.start()
        if self.rank == self.target_recorder_rank:
            assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}"
            self.recorder.start(self.record_w, self.record_h)
        if self.world_size > 1:
            dist.barrier()

    def next_control(self):
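        """Decide the next cycle action: "switch", "wait", or "fetch"."""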
        # OmniVAReader is imported lazily inside init_reader, so test the flag
        # set there instead of referencing the class name here
        if self.is_omni_reader:
            return self.omni_reader_next_control()
        return NextControl(action="fetch")

    def before_control(self):
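        # pre-allocate the tensors reused by the broadcasts in
        # omni_reader_next_control so every rank holds identical shapes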
        if self.is_omni_reader:
            self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
            self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
            self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)

    def omni_reader_next_control(self):
        immediate_switch = self.reader.get_immediate_switch()
        if immediate_switch == 1:
            # truncate the stream buffer down to the estimated inference window and
            # broadcast the previous video tensor to all ranks
            if self.rank == self.target_recorder_rank:
                logger.warning("runner received immediate switch, truncating stream buffer")
                video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
                if video_tensor is not None:
                    self.flag_tensor.fill_(1)
                    self.prev_tensor.copy_(video_tensor)
                else:
                    self.flag_tensor.fill_(0)
            dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
            if self.flag_tensor.item() == 1:
                dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
                return NextControl(action="switch", data=self.prev_tensor)
        else:
            # read the stream buffer length on the recorder rank and broadcast it
            if self.rank == self.target_recorder_rank:
                stream_buffer_length = self.recorder.get_buffer_stream_size()
                # fill_ takes the scalar size directly (copy_ would need a tensor)
                self.len_tensor.fill_(stream_buffer_length)
            dist.broadcast(self.len_tensor, src=self.target_recorder_rank)
            buffer_length = self.len_tensor.item()
            # stream buffer is enough, skip infer
            if buffer_length >= self.min_stay_queue_num:
                return NextControl(action="wait")
        return NextControl(action="fetch")

    def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
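        # realtime mode stages frames in the recorder's stream buffer (so an
        # immediate switch can still truncate them); otherwise publish directly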
        if self.recorder.realtime:
            self.recorder.buffer_stream(images, audios, gen_video)
        else:
            self.recorder.pub_livestream(images, audios)

    def clear(self):
        self.len_tensor = None
        self.flag_tensor = None
        self.prev_tensor = None
        if self.reader is not None:
            self.reader.stop()
            self.reader = None
        if self.recorder is not None:
            self.recorder.stop()
            self.recorder = None

    def __del__(self):
        self.clear()
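

# Minimal driving-loop sketch (illustrative only; `model_runner` and the
# generate/publish steps stand in for the real runner's logic):
#
#   controller = VAController(model_runner)
#   controller.before_control()
#   controller.start()
#   while running:
#       ctl = controller.next_control()
#       if ctl.action == "wait":
#           continue               # enough slices buffered, skip this cycle
#       if ctl.action == "switch":
#           prev_video = ctl.data  # resume generation from the truncated prefix
#       images, audios, gen_video = ...  # fetch audio segment and run inference
#       controller.pub_livestream(images, audios, gen_video)
#   controller.clear()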