# -*- coding: utf-8 -*-
"""
Audio Source Separation Module
Separates the voice tracks of different speakers in an audio recording; supports multi-speaker separation
"""

import base64
import io
import os
import tempfile
import traceback
from collections import defaultdict
from typing import Dict, Optional, Union

import torch
import torchaudio
from loguru import logger

# Import pyannote.audio for speaker diarization
from pyannote.audio import Audio, Pipeline

# torch.load defaults to weights_only=True on recent PyTorch releases (>= 2.6),
# which can break loading of pyannote checkpoints that pickle Python objects.
# Keep a reference to the original loader so it can be temporarily swapped out
# while the pipeline is constructed (and restored afterwards in a finally block).
_origin_torch_load = torch.load


def our_torch_load(checkpoint_file, *args, **kwargs):
    kwargs["weights_only"] = False
    return _origin_torch_load(checkpoint_file, *args, **kwargs)


class AudioSeparator:
    """
    Audio separator that splits the voice tracks of different speakers in a recording using pyannote.audio.
    Supports multi-speaker conversations; each returned track keeps the original duration, with samples
    belonging to other speakers left as silence (zeros).
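
    Example (illustrative; "meeting.wav" is a placeholder path, and a HUGGINGFACE_TOKEN
    may be required if the diarization model is gated):

        separator = AudioSeparator(device="cpu")
        result = separator.separate_speakers("meeting.wav", max_speakers=3)
        for spk in result["speakers"]:
            print(spk["speaker_id"], len(spk["segments"]))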
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: Optional[str] = None,
        sample_rate: int = 16000,
    ):
        """
        Initialize audio separator

        Args:
            model_path: Custom model path; defaults to pyannote/speaker-diarization-community-1 when None
            device: Device ('cpu', 'cuda', etc.), None for auto selection
            sample_rate: Target sample rate, default 16000
        """
        self.sample_rate = sample_rate
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self._init_pyannote(model_path)

    def _init_pyannote(self, model_path: Optional[str] = None):
        """Initialize pyannote.audio pipeline"""
        try:
            huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
            model_name = model_path or "pyannote/speaker-diarization-community-1"

            try:
                torch.load = our_torch_load
                # Try loading with token if available
                if huggingface_token:
                    self.pipeline = Pipeline.from_pretrained(model_name, token=huggingface_token)
                else:
                    # Try without token (may work for public models)
                    self.pipeline = Pipeline.from_pretrained(model_name)
            except Exception as e:
                if "gated" in str(e).lower() or "token" in str(e).lower():
                    raise RuntimeError(f"Model requires authentication. Set HUGGINGFACE_TOKEN or HF_TOKEN environment variable: {e}")
                raise RuntimeError(f"Failed to load pyannote model: {e}")
            finally:
                torch.load = _origin_torch_load

            # Move pipeline to specified device
            if self.device:
                self.pipeline.to(torch.device(self.device))

            # Initialize Audio helper for waveform loading
            self.pyannote_audio = Audio()

            logger.info("Initialized pyannote.audio speaker diarization pipeline")
        except Exception as e:
            logger.error(f"Failed to initialize pyannote: {e}")
            raise RuntimeError(f"Failed to initialize pyannote.audio pipeline: {e}")

    def separate_speakers(
        self,
        audio_path: Union[str, bytes],
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """
        Separate different speakers in audio

        Args:
            audio_path: Audio file path or bytes data
            num_speakers: Specified number of speakers, None for auto detection
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Returns:
            Dict containing:
                - speakers: List of speaker audio segments, each containing:
                    - speaker_id: Speaker label from the diarization output (e.g. "SPEAKER_00")
                    - audio: torch.Tensor audio data [channels, samples]
                    - segments: List of (start_time, end_time) tuples
                    - sample_rate: Sample rate
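
        Example (illustrative; bytes input is written to a temporary file internally,
        and "call.mp3" is a placeholder path):

            with open("call.mp3", "rb") as f:
                result = separator.separate_speakers(f.read(), max_speakers=3)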
        """
        try:
            # Load audio
            if isinstance(audio_path, bytes):
                # Try to infer the audio format from the byte data
                # Check for WAV format (RIFF header)
                is_wav = audio_path[:4] == b"RIFF" and audio_path[8:12] == b"WAVE"
                # Check for MP3 format (ID3 tag or MPEG frame sync)
                is_mp3 = audio_path[:3] == b"ID3" or audio_path[:2] == b"\xff\xfb" or audio_path[:2] == b"\xff\xf3"

                # Choose the temp-file suffix based on the detected format
                if is_wav:
                    suffix = ".wav"
                elif is_mp3:
                    suffix = ".mp3"
                else:
                    # Default to WAV; loading will raise an error if the format is wrong
                    suffix = ".wav"

                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
                    tmp_file.write(audio_path)
                    tmp_audio_path = tmp_file.name
                try:
                    result = self._separate_speakers_internal(tmp_audio_path, num_speakers, min_speakers, max_speakers)
                finally:
                    # Make sure the temporary file is removed
                    try:
                        os.unlink(tmp_audio_path)
                    except Exception as e:
                        logger.warning(f"Failed to delete temp file {tmp_audio_path}: {e}")
                return result
            else:
                return self._separate_speakers_internal(audio_path, num_speakers, min_speakers, max_speakers)

        except Exception as e:
            logger.error(f"Speaker separation failed: {traceback.format_exc()}")
            raise RuntimeError(f"Audio separation error: {e}")

    def _separate_speakers_internal(
        self,
        audio_path: str,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """Internal method: execute speaker separation"""

        # Load audio
        waveform, original_sr = torchaudio.load(audio_path)
        if original_sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(original_sr, self.sample_rate)
            waveform = resampler(waveform)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Ensure waveform is float32 and normalized (pyannote expects this format)
        if waveform.dtype != torch.float32:
            waveform = waveform.float()

        # Ensure waveform is in range [-1, 1] (normalize if needed)
        if waveform.abs().max() > 1.0:
            waveform = waveform / waveform.abs().max()

        if self.pipeline is None:
            raise RuntimeError("Pyannote pipeline not initialized")

        return self._separate_with_pyannote(audio_path, waveform, num_speakers, min_speakers, max_speakers)

    def _separate_with_pyannote(
        self,
        audio_path: str,
        waveform: torch.Tensor,
        num_speakers: Optional[int],
        min_speakers: int,
        max_speakers: int,
    ) -> Dict:
        """Use pyannote.audio for speaker diarization"""
        try:
            # Use waveform dict to avoid AudioDecoder dependency issues
            # Pipeline can accept either file path or waveform dict
            # Using waveform dict is more reliable when torchcodec is not properly installed
            audio_input = {
                "waveform": waveform,
                "sample_rate": self.sample_rate,
            }

            # Run speaker diarization; when an exact speaker count is given,
            # pin both the lower and upper bounds to it
            output = self.pipeline(
                audio_input,
                min_speakers=min_speakers if num_speakers is None else num_speakers,
                max_speakers=max_speakers if num_speakers is None else num_speakers,
            )

            # Extract audio segments for each speaker
            speakers_dict = defaultdict(list)
            for turn, speaker in output.speaker_diarization:
                logger.debug(f"Speaker: {speaker}, Start time: {turn.start}, End time: {turn.end}")
                start_time = turn.start
                end_time = turn.end
                start_sample = int(start_time * self.sample_rate)
                end_sample = int(end_time * self.sample_rate)

                # Extract audio segment for this time period
                segment_audio = waveform[:, start_sample:end_sample]
                speakers_dict[speaker].append((start_time, end_time, segment_audio))

            # Generate complete audio for each speaker (other speakers' segments are empty)
            speakers = []
            audio_duration = waveform.shape[1] / self.sample_rate
            num_samples = waveform.shape[1]

            for speaker_id, segments in speakers_dict.items():
                # Create zero-filled audio
                speaker_audio = torch.zeros_like(waveform)

                # Fill in this speaker's segments
                for start_time, end_time, segment_audio in segments:
                    start_sample = int(start_time * self.sample_rate)
                    end_sample = int(end_time * self.sample_rate)
                    # Ensure no out-of-bounds
                    end_sample = min(end_sample, num_samples)
                    segment_len = end_sample - start_sample
                    if segment_len > 0 and segment_audio.shape[1] > 0:
                        actual_len = min(segment_len, segment_audio.shape[1])
                        speaker_audio[:, start_sample : start_sample + actual_len] = segment_audio[:, :actual_len]

                speakers.append(
                    {
                        "speaker_id": speaker_id,
                        "audio": speaker_audio,
                        "segments": [(s[0], s[1]) for s in segments],
                        "sample_rate": self.sample_rate,
                    }
                )

            logger.info(f"Separated audio into {len(speakers)} speakers using pyannote")
            return {"speakers": speakers, "method": "pyannote"}

        except Exception as e:
            logger.error(f"Pyannote separation failed: {e}")
            raise RuntimeError(f"Audio separation failed: {e}")

    def save_speaker_audio(self, speaker_audio: torch.Tensor, output_path: str, sample_rate: Optional[int] = None):
        """
        Save speaker audio to file

        Args:
            speaker_audio: Audio tensor [channels, samples]
            output_path: Output path
            sample_rate: Sample rate, if None uses self.sample_rate
        """
        sr = sample_rate if sample_rate else self.sample_rate
        torchaudio.save(output_path, speaker_audio, sr)
        logger.info(f"Saved speaker audio to {output_path}")

    def speaker_audio_to_base64(self, speaker_audio: torch.Tensor, sample_rate: Optional[int] = None, format: str = "wav") -> str:
        """
        Convert speaker audio tensor to base64 encoded string without saving to file

        Args:
            speaker_audio: Audio tensor [channels, samples]
            sample_rate: Sample rate, if None uses self.sample_rate
            format: Audio format (default: "wav")

        Returns:
            Base64 encoded audio string
        """
        sr = sample_rate if sample_rate else self.sample_rate

        # Use BytesIO to save audio to memory
        buffer = io.BytesIO()
        torchaudio.save(buffer, speaker_audio, sr, format=format)

        # Get the audio bytes
        audio_bytes = buffer.getvalue()

        # Encode to base64
        audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")

        logger.debug(f"Converted speaker audio to base64, size: {len(audio_bytes)} bytes")
        return audio_base64

    def separate_and_save(
        self,
        audio_path: Union[str, bytes],
        output_dir: str,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """
        Separate audio and save to files

        Args:
            audio_path: Input audio path or bytes data
            output_dir: Output directory
            num_speakers: Specified number of speakers
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Returns:
            Separation result dictionary, containing output file paths
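
        Example (illustrative; writes one WAV file per detected speaker,
        "meeting.wav" is a placeholder path):

            result = separator.separate_and_save("meeting.wav", "./separated", max_speakers=3)
            print(result["output_paths"])  # e.g. ["./separated/SPEAKER_00.wav", ...]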
        """
        os.makedirs(output_dir, exist_ok=True)

        result = self.separate_speakers(audio_path, num_speakers, min_speakers, max_speakers)

        output_paths = []
        for speaker in result["speakers"]:
            speaker_id = speaker["speaker_id"]
            output_path = os.path.join(output_dir, f"{speaker_id}.wav")
            self.save_speaker_audio(speaker["audio"], output_path, speaker["sample_rate"])
            output_paths.append(output_path)
            speaker["output_path"] = output_path

        result["output_paths"] = output_paths
        return result


def separate_audio_tracks(
    audio_path: str,
    output_dir: Optional[str] = None,
    num_speakers: Optional[int] = None,
    model_path: Optional[str] = None,
) -> Dict:
    """
    Convenience function: separate different audio tracks

    Args:
        audio_path: Audio file path
        output_dir: Output directory, if None does not save files
        num_speakers: Number of speakers
        model_path: Model path (optional)

    Returns:
        Separation result dictionary
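
    Example (illustrative; "interview.wav" is a placeholder path):

        result = separate_audio_tracks("interview.wav", output_dir="./tracks", num_speakers=2)
        print(result["output_paths"])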
    """
    separator = AudioSeparator(model_path=model_path)

    if output_dir:
        return separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
    else:
        return separator.separate_speakers(audio_path, num_speakers=num_speakers)


if __name__ == "__main__":
    # Test code
    import sys

    if len(sys.argv) < 2:
        print("Usage: python audio_separator.py <audio_path> [output_dir] [num_speakers]")
        sys.exit(1)

    audio_path = sys.argv[1]
    output_dir = sys.argv[2] if len(sys.argv) > 2 else "./separated_audio"
    num_speakers = int(sys.argv[3]) if len(sys.argv) > 3 else None

    separator = AudioSeparator()
    result = separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)

    print(f"Separated audio into {len(result['speakers'])} speakers:")
    for speaker in result["speakers"]:
        print(f"  Speaker {speaker['speaker_id']}: {len(speaker['segments'])} segments")
        if "output_path" in speaker:
            print(f"    Saved to: {speaker['output_path']}")
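
    # Illustrative extra step: convert the first separated track to a base64 WAV
    # string in memory (e.g. for embedding in an API response) without writing a file.
    if result["speakers"]:
        b64_audio = separator.speaker_audio_to_base64(result["speakers"][0]["audio"])
        print(f"First speaker track as base64 WAV: {len(b64_audio)} characters")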