# va_reader.py
# Last committed by LiangLiu.
import os
import queue
import signal
import subprocess
import threading
import time
import traceback

import numpy as np
import torch
import torch.distributed as dist
from loguru import logger


class VAReader:
    """Pull raw 16-bit PCM audio from a live stream and share it across ranks.

    Only ``target_rank`` spawns the external reader process (ffmpeg for
    ``rtmp://`` URLs, GStreamer for ``http``/``https`` WHEP URLs) and a worker
    thread that slices its stdout into fixed-size chunks. Every call to
    :meth:`get_audio_segment` must be made on all ranks, because the chunk is
    broadcast from ``target_rank`` with ``torch.distributed``.

    Consecutive chunks overlap: every chunk after the first is prefixed with
    the last ``prev_duration`` seconds of the previous chunk, so all emitted
    chunks have the same byte length as the first one.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        stream_url: str,
        segment_duration: float = 5.0,
        sample_rate: int = 16000,
        audio_channels: int = 1,
        buffer_size: int = 1,
        prev_duration: float = 0.3125,
        target_rank: int = 0,
    ):
        """Configure the reader; no process or thread is started here.

        Args:
            rank: Rank of the current process.
            world_size: Total number of processes.
            stream_url: ``rtmp://...`` or ``http(s)://...`` (WHEP) stream URL.
            segment_duration: Length in seconds of every emitted chunk.
            sample_rate: Output sample rate requested from the reader process.
            audio_channels: Output channel count requested from the reader process.
            buffer_size: Maximum number of pending chunks; the oldest is
                dropped when the queue is full.
            prev_duration: Seconds of audio repeated from the previous chunk.
            target_rank: Rank that reads the stream (taken modulo ``world_size``).

        Raises:
            ValueError: If the overlap is negative or not strictly shorter
                than the segment, which would make the read size non-positive.
        """
        self.rank = rank
        self.world_size = world_size
        self.stream_url = stream_url
        self.segment_duration = segment_duration
        self.sample_rate = sample_rate
        self.audio_channels = audio_channels
        self.prev_duration = prev_duration
        # int16 = 2 bytes
        # NOTE(review): the byte sizes below do not account for audio_channels;
        # with more than one channel a chunk covers less than segment_duration.
        self.chunk_size = int(self.segment_duration * self.sample_rate) * 2
        self.prev_size = int(self.prev_duration * self.sample_rate) * 2
        if self.chunk_size <= 0 or not 0 <= self.prev_size < self.chunk_size:
            raise ValueError(
                f"Invalid durations: segment_duration={segment_duration}, prev_duration={prev_duration}; "
                "expected 0 <= prev_duration < segment_duration"
            )
        self.prev_chunk = None
        self.buffer_size = buffer_size

        self.audio_queue = queue.Queue(maxsize=self.buffer_size)
        self.audio_thread = None
        self.ffmpeg_process = None
        self.bytes_buffer = bytearray()

        self.target_rank = target_rank % self.world_size

        # Broadcast buffers: a 0/1 "has data" flag and one full-size chunk.
        self.flag_tensor = torch.tensor([0], dtype=torch.int32, device="cuda")
        self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda")

        logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
        logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")

    def start(self):
        """Start the reader process and worker thread on ``target_rank``.

        All ranks then meet at a barrier when ``world_size > 1``.

        Raises:
            ValueError: If ``stream_url`` is neither ``rtmp://`` nor ``http``-prefixed.
        """
        if self.rank == self.target_rank:
            if self.stream_url.startswith("rtmp://"):
                self.start_ffmpeg_process_rtmp()
            elif self.stream_url.startswith("http"):
                self.start_ffmpeg_process_whep()
            else:
                # ValueError is still an Exception, so existing handlers keep working.
                raise ValueError(f"Unsupported stream URL: {self.stream_url}")
            self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True)
            self.audio_thread.start()
            logger.info(f"VAReader {self.rank}/{self.world_size} started successfully")
        else:
            logger.info(f"VAReader {self.rank}/{self.world_size} wait only")
        if self.world_size > 1:
            logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier")
            dist.barrier()
            logger.info(f"VAReader {self.rank}/{self.world_size} end barrier")

    def start_ffmpeg_process_rtmp(self):
        """Start an ffmpeg process that writes the stream's audio as s16le PCM to stdout."""
        ffmpeg_cmd = [
            "/opt/conda/bin/ffmpeg",
            "-i",
            self.stream_url,
            "-vn",
            "-ar",
            str(self.sample_rate),
            "-ac",
            str(self.audio_channels),
            "-f",
            "s16le",
            "-",
        ]
        try:
            # stderr is deliberately NOT piped: nothing ever reads it, and a full
            # pipe buffer would block ffmpeg and stall the audio stream. It is
            # inherited instead, matching the WHEP reader.
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, bufsize=0)
            logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg process: {e}")
            raise

    def start_ffmpeg_process_whep(self):
        """Start a GStreamer (``gst-launch-1.0``) pipeline that writes S16LE PCM to stdout.

        The handle is stored in ``self.ffmpeg_process`` so the rest of the class
        treats both reader kinds the same way.
        """
        gst_cmd = [
            "gst-launch-1.0",
            "-q",
            "whepsrc",
            f"whep-endpoint={self.stream_url}",
            "video-caps=none",
            "!rtpopusdepay",
            "!opusdec",
            "plc=false",
            "!audioconvert",
            "!audioresample",
            f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}",
            "!fdsink",
            "fd=1",
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(
                gst_cmd,
                stdout=subprocess.PIPE,
                bufsize=0,
            )
            logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(gst_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg process: {e}")
            raise

    def audio_worker(self):
        """Worker-thread loop: keep fetching audio until the reader process exits."""
        logger.info("Audio pull worker thread started")
        try:
            while True:
                proc = self.ffmpeg_process
                if proc is None or proc.poll() is not None:
                    logger.warning("FFmpeg process exited, audio worker thread stopped")
                    break
                self.fetch_audio_data()
                time.sleep(0.01)
        except Exception:
            logger.error(f"Audio pull worker error: {traceback.format_exc()}")
        finally:
            logger.warning("Audio pull worker thread stopped")

    def fetch_audio_data(self):
        """Read from the reader's stdout and enqueue one chunk when enough bytes are buffered.

        Errors are logged and swallowed so that one failed read does not kill
        the worker loop.
        """
        try:
            audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size)
            if not audio_bytes:
                return
            self.bytes_buffer.extend(audio_bytes)

            if len(self.bytes_buffer) < self.chunk_size:
                return

            audio_data = self.bytes_buffer[: self.chunk_size]
            self.bytes_buffer = self.bytes_buffer[self.chunk_size :]

            # The first chunk is read at full size. Afterwards only
            # (chunk - overlap) new bytes are read and prefixed with the tail
            # of the previous chunk, e.g. 81 frames first, then 76 new + 5 old.
            if self.prev_chunk is None:
                logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}")
                self.chunk_size -= self.prev_size
            else:
                audio_data = self.prev_chunk + audio_data
            # Slice from an explicit start index: ``audio_data[-0:]`` would keep
            # the WHOLE chunk when prev_size is 0 and double every later chunk.
            self.prev_chunk = audio_data[len(audio_data) - self.prev_size :]

            try:
                self.audio_queue.put_nowait(audio_data)
            except queue.Full:
                logger.warning(f"Audio queue full:{self.audio_queue.qsize()}, discarded oldest chunk")
                try:
                    self.audio_queue.get_nowait()
                except queue.Empty:
                    # The consumer drained the queue in between; nothing to drop.
                    pass
                self.audio_queue.put_nowait(audio_data)
            logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size:{self.chunk_size}")

        except Exception:
            logger.error(f"Fetch audio data error: {traceback.format_exc()}")

    def braodcast_audio_data(self, audio_data):
        """Broadcast one chunk from ``target_rank`` to every rank.

        Must be called by all ranks. A flag is broadcast first so that every
        rank agrees on whether a payload follows.

        Args:
            audio_data: Chunk bytes on ``target_rank`` (or ``None`` when no
                chunk is available); ignored on other ranks.

        Returns:
            The chunk bytes, or ``None`` when ``target_rank`` had no data.
        """
        if self.rank == self.target_rank:
            if audio_data is None:
                self.flag_tensor.fill_(0)
            else:
                self.flag_tensor.fill_(1)
                self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
                logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")

        dist.broadcast(self.flag_tensor, src=self.target_rank)
        if self.flag_tensor.item() == 0:
            return None

        dist.broadcast(self.audio_tensor, src=self.target_rank)
        if self.rank != self.target_rank:
            logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
            audio_data = self.audio_tensor.cpu().numpy().tobytes()
        return audio_data

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    broadcast_audio_data = braodcast_audio_data

    def bytes_to_ndarray(self, audio_data):
        """Convert little-endian int16 PCM bytes to float32 samples scaled by 1/32768.

        Returns ``None`` when ``audio_data`` is ``None``.
        """
        if audio_data is None:
            return None
        audio_data = np.frombuffer(audio_data, dtype=np.int16)
        audio_data = audio_data.astype(np.float32) / 32768.0
        logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
        return audio_data

    def get_audio_segment(self, timeout: float = 1.0):
        """Return the next chunk as a float32 array, or ``None`` if none arrived in time.

        Must be called on every rank when ``world_size > 1``.

        Args:
            timeout: Seconds ``target_rank`` waits for a chunk from the worker.
        """
        audio_data = None
        if self.rank == self.target_rank:
            try:
                audio_data = self.audio_queue.get(timeout=timeout)
            except Exception:
                # Covers queue.Empty (timeout) and a queue already released by
                # stop(); KeyboardInterrupt/SystemExit are no longer swallowed.
                logger.warning(f"Failed to get audio segment: {traceback.format_exc()}")
        if self.world_size > 1:
            audio_data = self.braodcast_audio_data(audio_data)
        audio_data = self.bytes_to_ndarray(audio_data)
        return audio_data

    def stop(self):
        """Stop the reader process and worker thread and drop pending chunks.

        Safe to call more than once, and on a partially constructed instance.
        """
        # Stop the reader process (SIGINT first, SIGKILL after 5 seconds).
        proc = getattr(self, "ffmpeg_process", None)
        if proc is not None:
            if proc.poll() is None:
                proc.send_signal(signal.SIGINT)
                try:
                    proc.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait()  # reap the child so it does not linger as a zombie
            logger.warning("FFmpeg reader process stopped")

        # Wait for the worker thread; it exits once it sees the process is gone.
        worker = getattr(self, "audio_thread", None)
        if worker is not None and worker.is_alive():
            worker.join(timeout=5)
            if worker.is_alive():
                logger.error("Audio pull thread did not stop gracefully")

        if proc is not None:
            # Close the pipes only when no thread can still be reading from them.
            if worker is None or not worker.is_alive():
                for pipe in (proc.stdout, proc.stderr):
                    if pipe is not None:
                        pipe.close()
            self.ffmpeg_process = None

        pending = getattr(self, "audio_queue", None)
        if pending is not None:
            try:
                while True:
                    pending.get_nowait()
            except queue.Empty:
                pass
            self.audio_queue = None
            logger.warning("Audio pull queue cleaned")

    def __del__(self):
        # Best-effort cleanup: a finalizer must never raise, and module globals
        # may already be torn down at interpreter shutdown.
        try:
            self.stop()
        except Exception:
            pass


if __name__ == "__main__":
    # Manual smoke test: pull audio from a stream and log each received chunk.
    # Usage: python va_reader.py [stream_url]   (also works under torchrun)
    import sys

    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
    RANK = int(os.environ.get("RANK", 0))
    if WORLD_SIZE > 1:
        dist.init_process_group(backend="nccl")
        # Pick the GPU by LOCAL_RANK when the launcher provides it: the global
        # rank exceeds the per-node device count on multi-node jobs. Falls back
        # to the global rank, which is what this script used before.
        local_rank_env = os.environ.get("LOCAL_RANK")
        device_index = int(local_rank_env) if local_rank_env is not None else dist.get_rank()
        torch.cuda.set_device(device_index)
        logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")

    # Optional first CLI argument overrides the default test stream,
    # e.g. "rtmp://localhost/live/test_audio".
    default_stream_url = "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000"
    stream_url = sys.argv[1] if len(sys.argv) > 1 else default_stream_url

    reader = VAReader(
        RANK,
        WORLD_SIZE,
        stream_url,
        segment_duration=1.0,
        sample_rate=16000,
        audio_channels=1,
        prev_duration=1 / 16,
    )
    reader.start()
    fail_count = 0
    max_fail_count = 2

    try:
        while True:
            audio_data = reader.get_audio_segment(timeout=2)
            if audio_data is not None:
                fail_count = 0
            else:
                fail_count += 1
                if fail_count > max_fail_count:
                    logger.warning("Failed to get audio chunk, stop reader")
                    # The reader is stopped exactly once, in the finally clause.
                    break
            time.sleep(0.95)
    finally:
        reader.stop()
        if WORLD_SIZE > 1 and dist.is_initialized():
            dist.destroy_process_group()