Unverified Commit 84868e47 authored by seanmamasde's avatar seanmamasde Committed by GitHub
Browse files

[Bugfix][Frontend] Fix audio transcription for MP4, M4A, and WebM formats (#35109)


Signed-off-by: default avatarseanmamasde <seanmamasde@gmail.com>
parent a8e8d62d
...@@ -976,6 +976,7 @@ setup( ...@@ -976,6 +976,7 @@ setup(
"soundfile", "soundfile",
"mistral_common[audio]", "mistral_common[audio]",
"av", "av",
"torchcodec",
], # Required for audio processing ], # Required for audio processing
"video": [], # Kept for backwards compatibility "video": [], # Kept for backwards compatibility
"flashinfer": [], # Kept for backwards compatibility "flashinfer": [], # Kept for backwards compatibility
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import io
import math import math
import time import time
import zlib import zlib
...@@ -11,7 +10,6 @@ from typing import Final, Literal, TypeAlias, TypeVar, cast ...@@ -11,7 +10,6 @@ from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np import numpy as np
from fastapi import Request from fastapi import Request
from soundfile import LibsndfileError
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
import vllm.envs as envs import vllm.envs as envs
...@@ -37,6 +35,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ...@@ -37,6 +35,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationSegment, TranslationSegment,
TranslationStreamResponse, TranslationStreamResponse,
) )
from vllm.entrypoints.openai.speech_to_text.utils import load_audio_bytes
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs import EncoderDecoderInputs, ProcessorInputs from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
...@@ -56,14 +55,6 @@ try: ...@@ -56,14 +55,6 @@ try:
except ImportError: except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment] librosa = PlaceholderModule("librosa") # type: ignore[assignment]
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
# being librosa's main backend. Used to validate if an audio loading error is due to a
# server error vs a client error (invalid audio file).
# 1 = unrecognised format (file is not a supported audio container)
# 3 = malformed file (corrupt or structurally invalid audio)
# 4 = unsupported encoding (codec not supported by this libsndfile build)
_BAD_SF_CODES = {1, 3, 4}
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = ( SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose TranscriptionResponseVerbose | TranslationResponseVerbose
...@@ -202,16 +193,12 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -202,16 +193,12 @@ class OpenAISpeechToText(OpenAIServing):
value=len(audio_data) / 1024**2, value=len(audio_data) / 1024**2,
) )
with io.BytesIO(audio_data) as bytes_: # Decode audio bytes. For container formats (MP4, M4A, WebM) that
try: # soundfile cannot detect from a BytesIO stream, _load_audio_bytes
# transparently falls back to ffmpeg via an in-memory fd.
# NOTE resample to model SR here for efficiency. This is also a # NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR. # pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) y, sr = load_audio_bytes(audio_data, sr=self.asr_config.sample_rate)
except LibsndfileError as exc:
# Distinguish client errors (invalid audio) from server errors
if exc.code in _BAD_SF_CODES:
raise ValueError("Invalid or unsupported audio file.") from exc
raise
duration = librosa.get_duration(y=y, sr=sr) duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = ( do_split_audio = (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Audio decoding utilities for the speech-to-text endpoints."""
import io
import numpy as np
import torchaudio
from vllm.logger import init_logger
from vllm.utils.import_utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile as sf
except ImportError:
sf = PlaceholderModule("soundfile") # type: ignore[assignment]
logger = init_logger(__name__)
# Public libsndfile error codes exposed via ``soundfile.LibsndfileError.code``.
# soundfile is librosa's primary backend. These codes indicate that the audio
# data itself is problematic (unrecognised container, corrupt file, or
# unsupported encoding) rather than a transient server error.
# 1 = unrecognised format, 3 = malformed file, 4 = unsupported encoding
_BAD_SF_CODES = {1, 3, 4}
def _decode_audio_bytes_torchaudio(
audio_data: bytes,
sr: int,
) -> tuple[np.ndarray, int]:
"""Decode audio bytes to mono float32 PCM via torchaudio, in-process.
``torchaudio.load`` (backed by TorchCodec / FFmpeg) can decode
container formats (MP4, M4A, WebM) directly from a ``BytesIO``
buffer without spawning a subprocess. The decoded waveform is
down-mixed to mono and resampled to *sr* Hz, matching the return
convention of ``librosa.load``.
"""
buf = io.BytesIO(audio_data)
waveform, orig_sr = torchaudio.load(buf)
# Down-mix to mono (average across channels).
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# Resample to the target sample rate when necessary.
if orig_sr != sr:
waveform = torchaudio.functional.resample(
waveform, orig_freq=orig_sr, new_freq=sr
)
# Squeeze channel dim → 1-D float32 numpy array (same as librosa.load).
y = waveform.squeeze(0).numpy()
if y.size == 0:
raise RuntimeError(
"torchaudio produced no audio samples (file may be empty or corrupt)"
)
return y, sr
def load_audio_bytes(
audio_data: bytes,
sr: int | float,
) -> tuple[np.ndarray, int]:
"""Load audio from raw bytes, with an in-process torchaudio fallback.
First tries ``librosa.load(BytesIO(...))`` which works for formats
that *soundfile* can auto-detect (WAV, FLAC, MP3, OGG, ...). If
that fails with a ``LibsndfileError`` indicating an unrecognised or
unsupported format (typically container formats like MP4/M4A/WebM),
the bytes are decoded in-process via ``torchaudio`` (backed by
TorchCodec / FFmpeg) which handles these containers natively.
"""
sr = int(sr)
# Fast path: librosa + soundfile (works for most formats).
try:
with io.BytesIO(audio_data) as buf:
return librosa.load(buf, sr=sr) # type: ignore[return-value]
except sf.LibsndfileError as exc:
# Only fall back for known format-detection failures.
# Re-raise anything else (e.g. corrupt but recognised format).
if exc.code not in _BAD_SF_CODES:
raise
logger.debug(
"librosa/soundfile could not decode audio from BytesIO "
"(code=%s: %s); falling back to torchaudio in-process decode",
exc.code,
exc,
)
# Fallback: torchaudio in-process decode (no subprocess overhead).
try:
return _decode_audio_bytes_torchaudio(audio_data, sr)
except Exception as ta_exc:
logger.debug(
"torchaudio fallback also failed: %s",
ta_exc,
)
raise ValueError("Invalid or unsupported audio file.") from ta_exc
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment