Unverified Commit 5d199ac8 authored by Andrii Skliar's avatar Andrii Skliar Committed by GitHub
Browse files
parent 9e0f44be
......@@ -1056,6 +1056,7 @@ setup(
"scipy",
"soundfile",
"mistral_common[audio]",
"av",
], # Required for audio processing
"video": [], # Kept for backwards compatibility
"flashinfer": [], # Kept for backwards compatibility
......
......@@ -622,6 +622,15 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
class NemotronHNanoVLV2Config(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
mm_config = model_config.multimodal_config
if mm_config is not None:
video_kwargs = mm_config.media_io_kwargs.setdefault("video", {})
video_kwargs.setdefault("video_backend", "nemotron_vl")
class Qwen3_5ForConditionalGenerationConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
......@@ -661,6 +670,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteNewModel": GteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig,
"NemotronH_Nano_VL_V2": NemotronHNanoVLV2Config,
"LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
"LlamaBidirectionalModel": LlamaBidirectionalConfig,
"LlamaNemotronVLModel": LlamaNemotronVLConfig,
......
......@@ -59,9 +59,11 @@ from vllm.multimodal.inputs import (
AudioItem,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalInputs,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.media.audio import extract_audio_from_video_bytes
from vllm.multimodal.parse import (
AudioProcessorItems,
ImageEmbeddingItems,
......@@ -69,8 +71,13 @@ from vllm.multimodal.parse import (
ImageSize,
MultiModalDataItems,
MultiModalDataParser,
VideoProcessorItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
ProcessorInputs,
TimingContext,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
......@@ -1381,6 +1388,127 @@ class NanoNemotronVLMultiModalProcessor(
):
"""MultiModalProcessor extended for video support"""
def _extract_audio_from_videos(
self,
mm_items: MultiModalDataItems,
) -> tuple[MultiModalDataItems, list[AudioItem]]:
"""Extract audio tracks from video bytes in *mm_items*.
Returns:
The augmented *mm_items* (with audio added) and the list of
extracted audio items.
"""
videos = mm_items.get_items("video", VideoProcessorItems)
assert isinstance(videos.metadata, list)
metadata_list = videos.metadata
audio_items: list[AudioItem] = []
for metadata in metadata_list:
video_bytes = metadata.get("original_video_bytes")
if video_bytes is None or len(video_bytes) == 0:
raise ValueError(
"Cannot extract audio from video: original_video_bytes is "
"missing or empty. When using use_audio_in_video=True, "
"video must be loaded with keep_video_bytes=True (e.g. via "
"the chat API with a model that sets use_audio_in_video)."
)
audio_items.append(extract_audio_from_video_bytes(video_bytes))
# Create a new VideoProcessorItems with metadata that does not contain
# the large video bytes, to avoid modifying the input `mm_items`.
new_metadata_list = [
{k: v for k, v in meta.items() if k != "original_video_bytes"}
for meta in metadata_list
]
new_videos = VideoProcessorItems(data=videos.data, metadata=new_metadata_list)
audio_parsed = self.data_parser.parse_mm_data({"audio": audio_items})
# Create a new MultiModalDataItems with the new video and audio items.
new_mm_items_dict = {**mm_items, **audio_parsed, "video": new_videos}
mm_items = MultiModalDataItems(new_mm_items_dict)
return mm_items, audio_items
def apply(
self,
processor_inputs: ProcessorInputs,
timing_ctx: TimingContext | None = None,
) -> MultiModalInputs:
if (hf_processor_mm_kwargs := processor_inputs.hf_processor_mm_kwargs) is None:
hf_processor_mm_kwargs = {}
use_audio_in_video = bool(
hf_processor_mm_kwargs.get("use_audio_in_video", False)
)
hf_processor_mm_kwargs = {
k: v for k, v in hf_processor_mm_kwargs.items() if k != "use_audio_in_video"
}
processor_inputs.hf_processor_mm_kwargs = hf_processor_mm_kwargs
if not (
use_audio_in_video
and "video" in processor_inputs.mm_data_items
and "audio" not in processor_inputs.mm_data_items
):
return super().apply(
processor_inputs,
timing_ctx,
)
mm_items, audio_items = self._extract_audio_from_videos(
processor_inputs.mm_data_items
)
processor_inputs.mm_data_items = mm_items
prompt = processor_inputs.prompt
tokenizer = self.info.get_tokenizer()
if not isinstance(prompt, str):
prompt = tokenizer.decode(prompt, skip_special_tokens=False)
for _ in audio_items:
prompt = prompt.replace("<video>", "<video>" + AUDIO_CONTEXT, 1)
processor_inputs.prompt = tokenizer.encode(prompt, add_special_tokens=False)
if processor_inputs.tokenization_kwargs is None:
processor_inputs.tokenization_kwargs = {}
# Bypass the cached path: the HF processor must receive the
# prompt (with injected <so_embedding>) and the audio data
# together so it can perform audio-token replacement natively.
(
prompt_ids,
mm_info,
is_update_applied,
) = self._apply_hf_processor(
processor_inputs,
timing_ctx=timing_ctx,
)
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
is_update_applied=is_update_applied,
)
mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]
for modality, placeholders in mm_placeholders.items()
}
return MultiModalInputs(
type="multimodal",
prompt_token_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_hashes=mm_info.hashes,
mm_placeholders=mm_placeholder_ranges,
)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
......
......@@ -4,6 +4,7 @@ import base64
from io import BytesIO
from pathlib import Path
import numpy as np
import numpy.typing as npt
import pybase64
import torch
......@@ -23,6 +24,63 @@ try:
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
try:
import av
except ImportError:
av = PlaceholderModule("av") # type: ignore[assignment]
def extract_audio_from_video_bytes(
data: bytes,
) -> tuple[npt.NDArray, float]:
"""Extract the audio track from raw video bytes using PyAV.
PyAV wraps FFmpeg's C libraries in-process — no subprocess is
spawned, which is critical to avoid crashing CUDA-active vLLM
worker processes.
The returned waveform is at the native sample rate of the video's
audio stream. Resampling to a model-specific rate is left to the
downstream :class:`AudioResampler` in the parsing pipeline.
Args:
data: Raw video file bytes (e.g. from an mp4 file).
Returns:
A tuple of ``(waveform, sample_rate)`` suitable for use as an
:class:`AudioItem`.
"""
if data is None or len(data) == 0:
raise ValueError(
"Cannot extract audio: video bytes are missing or empty. "
"Ensure video was loaded with keep_video_bytes=True for "
"audio-in-video extraction."
)
try:
with av.open(BytesIO(data)) as container:
if not container.streams.audio:
raise ValueError("No audio stream found in the video.")
stream = container.streams.audio[0]
native_sr = stream.rate
chunks: list[npt.NDArray] = []
for frame in container.decode(audio=0):
arr = frame.to_ndarray()
chunks.append(arr.mean(axis=0) if arr.ndim > 1 else arr)
except ValueError:
raise
except Exception as e:
raise ValueError(
"Invalid or corrupted video data when extracting audio. "
"Ensure the input is valid video bytes (e.g. a complete MP4)."
) from e
if not chunks:
raise ValueError("No audio found in the video.")
audio = np.concatenate(chunks).astype(np.float32)
return audio, float(native_sr)
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
def __init__(self, **kwargs) -> None:
......
......@@ -749,6 +749,33 @@ class Molmo2VideoBackend(VideoLoader):
return out
@VIDEO_LOADER_REGISTRY.register("nemotron_vl")
class NemotronVLVideoBackend(OpenCVVideoBackend):
@classmethod
def load_bytes(
cls,
data: bytes,
num_frames: int = -1,
fps: int = -1,
max_duration: int = 300,
frame_recovery: bool = False,
**kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
frames, metadata = OpenCVVideoBackend.load_bytes(
data,
num_frames=num_frames,
fps=fps,
max_duration=max_duration,
frame_recovery=frame_recovery,
**kwargs,
)
metadata = dict(metadata)
metadata["original_video_bytes"] = data
return frames, metadata
@VIDEO_LOADER_REGISTRY.register("openpangu")
class OpenCVDynamicOpenPanguVideoBackend(OpenCVVideoBackend):
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment