audio.py 2.04 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
import base64
from io import BytesIO
from pathlib import Path

7
8
9
import numpy as np
import numpy.typing as npt

10
from vllm.inputs.registry import InputContext
11
from vllm.utils import PlaceholderModule
12

13
from .base import MediaIO, MultiModalPlugin
14
from .inputs import AudioItem, ModalityData, MultiModalKwargs
15

16
17
18
19
20
try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

21
22
23
24
25
try:
    import soundfile
except ImportError:
    soundfile = PlaceholderModule("soundfile")  # type: ignore[assignment]

26
27
28
29
30
31
32

class AudioPlugin(MultiModalPlugin):
    """Plugin for audio data."""

    def get_data_key(self) -> str:
        """Return the key under which audio data is registered."""
        return "audio"

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: ModalityData[AudioItem],
        **mm_processor_kwargs,
    ) -> MultiModalKwargs:
        """Always fail: audio has no default input mapper.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("There is no default audio input mapper")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """Always fail: audio has no default maximum token count.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "There is no default maximum multimodal tokens")
44
45
46
47
48
49
50
51
52


def resample_audio(
    audio: npt.NDArray[np.floating],
    *,
    orig_sr: float,
    target_sr: float,
) -> npt.NDArray[np.floating]:
    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` using librosa.

    Args:
        audio: Floating-point audio samples.
        orig_sr: Sampling rate of ``audio`` (keyword-only).
        target_sr: Sampling rate requested for the result (keyword-only).

    Returns:
        Whatever ``librosa.resample`` returns for the given arguments.
    """
    resampled = librosa.resample(
        audio,
        orig_sr=orig_sr,
        target_sr=target_sr,
    )
    return resampled
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77


class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Load and encode audio as ``(samples, sampling_rate)`` tuples."""

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode an in-memory audio file with ``librosa.load``."""
        stream = BytesIO(data)
        # sr=None is forwarded as-is; per librosa's docs this keeps the
        # file's native sampling rate instead of resampling.
        return librosa.load(stream, sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode base64 text and load the resulting bytes as audio.

        ``media_type`` is accepted but not consulted here.
        """
        raw_bytes = base64.b64decode(data)
        return self.load_bytes(raw_bytes)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Load audio from ``filepath`` with ``librosa.load``."""
        # Same sr=None convention as load_bytes.
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
        """Write ``media`` as WAV and return it as a base64 string."""
        samples, sample_rate = media

        with BytesIO() as wav_stream:
            soundfile.write(wav_stream, samples, sample_rate, format="WAV")
            # Copy the bytes out before the buffer is closed on exit.
            wav_bytes = wav_stream.getvalue()

        encoded = base64.b64encode(wav_bytes)
        return encoded.decode('utf-8')