audio.py 3.97 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import base64
from io import BytesIO
from pathlib import Path
6
from typing import Literal
7

8
9
import numpy as np
import numpy.typing as npt
10
11
import pybase64
import torch
12

13
from vllm.utils.import_utils import PlaceholderModule
14
from vllm.utils.serial_utils import tensor2base64
15

16
from .base import MediaIO
17

18
19
20
21
22
try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

23
24
25
26
27
try:
    import soundfile
except ImportError:
    soundfile = PlaceholderModule("soundfile")  # type: ignore[assignment]

28

29
def resample_audio_librosa(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` via librosa.

    Thin wrapper around ``librosa.resample``; all arguments are forwarded
    unchanged.
    """
    return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
36
37


38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def resample_audio_scipy(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` with polyphase filtering.

    The rate change is expressed as a reduced fraction ``up / down`` and passed
    to ``scipy.signal.resample_poly``. Non-integer ratios therefore produce the
    requested output rate: 44100 Hz -> 16000 Hz becomes 160/441, rather than
    decimating by ``44100 // 16000 == 2``.

    Args:
        audio: Input samples; ``resample_poly`` operates along its default
            axis (0).
        orig_sr: Sample rate of ``audio``.
        target_sr: Desired sample rate.

    Returns:
        The resampled samples, or ``audio`` itself (not a copy) when the two
        rates are equal.

    Raises:
        ValueError: If either sample rate is not strictly positive.
    """
    import math
    from fractions import Fraction

    # lazy import scipy.signal, otherwise it will crash doc build.
    import scipy.signal

    if orig_sr == target_sr:
        return audio
    # `not (a > 0 and b > 0)` also rejects NaN rates.
    if not (orig_sr > 0 and target_sr > 0):
        raise ValueError(
            f"Sample rates must be positive, got orig_sr={orig_sr} and "
            f"target_sr={target_sr}"
        )

    ratio = Fraction(float(target_sr)) / Fraction(float(orig_sr))
    # For integer rates the reduced denominator divides orig_sr, so this limit
    # leaves the ratio exact. It only approximates non-integral rates, which
    # keeps the polyphase filter (its length grows with max(up, down)) bounded.
    ratio = ratio.limit_denominator(max(1, math.ceil(orig_sr)))
    # NOTE(review): librosa.resample works along the last axis, resample_poly
    # along axis 0 -- identical for 1-D audio; confirm multi-channel layout.
    return scipy.signal.resample_poly(audio, ratio.numerator, ratio.denominator)


class AudioResampler:
    """Resample audio data to a target sample rate."""

    def __init__(
        self,
59
        target_sr: float | None = None,
60
61
62
63
64
65
66
67
68
69
70
71
        method: Literal["librosa", "scipy"] = "librosa",
    ):
        self.target_sr = target_sr
        self.method = method

    def resample(
        self,
        audio: npt.NDArray[np.floating],
        *,
        orig_sr: float,
    ) -> npt.NDArray[np.floating]:
        if self.target_sr is None:
72
73
74
            raise RuntimeError(
                "Audio resampling is not supported when `target_sr` is not provided"
            )
75
        if self.method == "librosa":
76
77
78
            return resample_audio_librosa(
                audio, orig_sr=orig_sr, target_sr=self.target_sr
            )
79
        elif self.method == "scipy":
80
81
82
            return resample_audio_scipy(
                audio, orig_sr=orig_sr, target_sr=self.target_sr
            )
83
        else:
84
85
86
87
            raise ValueError(
                f"Invalid resampling method: {self.method}. "
                "Supported methods are 'librosa' and 'scipy'."
            )
88
89


90
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Load and encode audio clips as ``(samples, sample_rate)`` tuples.

    Decoding goes through ``librosa.load`` with ``sr=None`` (no resampling at
    load time); encoding writes a WAV file with ``soundfile``.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

        # `kwargs` contains custom arguments from
        # --media-io-kwargs for this modality.
        # They can be passed to the underlying
        # media loaders (e.g. custom implementations)
        # for flexible control. The loaders in this class do not read them.
        self.kwargs = kwargs

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode an in-memory audio file into ``(samples, sample_rate)``."""
        return librosa.load(BytesIO(data), sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode base64-encoded audio file bytes.

        ``media_type`` is part of the ``MediaIO`` interface and unused here.
        """
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Decode the audio file at ``filepath`` into ``(samples, sample_rate)``."""
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, int]) -> str:
        """Serialize ``(samples, sample_rate)`` as a base64-encoded WAV file."""
        audio, sr = media

        with BytesIO() as buffer:
            soundfile.write(buffer, audio, sr, format="WAV")
            data = buffer.getvalue()

        return base64.b64encode(data).decode("utf-8")
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138


class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Load and encode pre-computed audio embeddings stored via ``torch.load``."""

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _load_tensor(source: BytesIO | Path) -> torch.Tensor:
        """Deserialize a tensor from (possibly untrusted) input with safety checks."""
        # `weights_only=True` restricts unpickling to tensors and plain
        # containers. Enabling sparse-tensor invariant checks makes torch.load
        # validate any sparse tensor it rebuilds (e.g. out-of-range indices),
        # which recent PyTorch versions may otherwise skip for speed.
        with torch.sparse.check_sparse_tensor_invariants():
            return torch.load(source, weights_only=True)

    def load_bytes(self, data: bytes) -> torch.Tensor:
        return self._load_tensor(BytesIO(data))

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        # Strict decoding: invalid base64 characters raise instead of being
        # silently dropped.
        return self.load_bytes(pybase64.b64decode(data, validate=True))

    def load_file(self, filepath: Path) -> torch.Tensor:
        return self._load_tensor(filepath)

    def encode_base64(self, media: torch.Tensor) -> str:
        return tensor2base64(media)