utils.py 3.66 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Audio decoding utilities for the speech-to-text endpoints."""

import io

import numpy as np
import torchaudio

from vllm.logger import init_logger
from vllm.utils.import_utils import PlaceholderModule

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

try:
    import soundfile as sf
except ImportError:
    sf = PlaceholderModule("soundfile")  # type: ignore[assignment]

logger = init_logger(__name__)

# Public libsndfile error codes exposed via ``soundfile.LibsndfileError.code``.
# soundfile is librosa's primary backend.  These codes indicate that the audio
# data itself is problematic (unrecognised container, corrupt file, or
# unsupported encoding) rather than a transient server error.
# 1 = unrecognised format, 3 = malformed file, 4 = unsupported encoding
# Stored as a frozenset so the shared constant cannot be mutated at runtime.
_BAD_SF_CODES = frozenset({1, 3, 4})


def _decode_audio_bytes_torchaudio(
    audio_data: bytes,
    sr: int,
) -> tuple[np.ndarray, int]:
    """Decode audio bytes to mono float32 PCM via torchaudio, in-process.

    ``torchaudio.load`` (backed by TorchCodec / FFmpeg) can decode
    container formats (MP4, M4A, WebM) directly from a ``BytesIO``
    buffer without spawning a subprocess.  The decoded waveform is
    down-mixed to mono and resampled to *sr* Hz, matching the return
    convention of ``librosa.load``.

    Args:
        audio_data: Encoded audio file contents.
        sr: Target sample rate in Hz.

    Returns:
        A ``(y, sr)`` tuple where ``y`` is the waveform as a 1-D numpy
        array and ``sr`` is the target rate that was passed in.

    Raises:
        RuntimeError: If decoding yields no audio samples.  Exceptions
            raised by torchaudio itself propagate unchanged.
    """
    empty_msg = "torchaudio produced no audio samples (file may be empty or corrupt)"

    buf = io.BytesIO(audio_data)
    waveform, orig_sr = torchaudio.load(buf)

    # Reject an empty decode before any further processing: resampling a
    # zero-length tensor can fail inside torchaudio with an unrelated shape
    # error, which would hide the clearer message raised here.
    if waveform.numel() == 0:
        raise RuntimeError(empty_msg)

    # Down-mix to mono (average across channels).
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample to the target sample rate when necessary.
    if orig_sr != sr:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=orig_sr, new_freq=sr
        )

    # Squeeze channel dim → 1-D float32 numpy array (same as librosa.load).
    y = waveform.squeeze(0).numpy()
    # Kept as a final guard in case processing itself produced no samples.
    if y.size == 0:
        raise RuntimeError(empty_msg)
    return y, sr


def load_audio_bytes(
    audio_data: bytes,
    sr: int | float,
) -> tuple[np.ndarray, int]:
    """Decode raw audio bytes, preferring librosa and falling back to torchaudio.

    The bytes are first handed to ``librosa.load`` through an in-memory
    buffer.  If that raises ``soundfile.LibsndfileError`` with a code
    listed in ``_BAD_SF_CODES``, the same bytes are decoded in-process by
    ``_decode_audio_bytes_torchaudio`` instead.

    Args:
        audio_data: Encoded audio file contents.
        sr: Target sample rate; converted with ``int()`` before use.

    Returns:
        The ``(samples, sample_rate)`` pair produced by whichever decoder
        succeeded.

    Raises:
        soundfile.LibsndfileError: librosa failed with a code that is not
            in ``_BAD_SF_CODES``; the error is propagated unchanged.
        ValueError: The torchaudio fallback failed as well; the underlying
            exception is chained as the cause.
    """
    target_sr = int(sr)

    # Primary decoder: librosa reading from an in-memory stream.
    try:
        with io.BytesIO(audio_data) as stream:
            return librosa.load(stream, sr=target_sr)  # type: ignore[return-value]
    except sf.LibsndfileError as sf_err:
        if sf_err.code in _BAD_SF_CODES:
            # The data itself was rejected by soundfile; give torchaudio a try.
            logger.debug(
                "librosa/soundfile could not decode audio from BytesIO "
                "(code=%s: %s); falling back to torchaudio in-process decode",
                sf_err.code,
                sf_err,
            )
        else:
            # Any other libsndfile failure is not handled by the fallback.
            raise

    # Secondary decoder: torchaudio, still in-process.
    try:
        decoded = _decode_audio_bytes_torchaudio(audio_data, target_sr)
    except Exception as ta_exc:
        logger.debug("torchaudio fallback also failed: %s", ta_exc)
        raise ValueError("Invalid or unsupported audio file.") from ta_exc
    return decoded