Unverified Commit c46b0cd0 authored by Wang Haoyu's avatar Wang Haoyu Committed by GitHub
Browse files

[Model][Multimodal] Add explicit MusicFlamingo adapter (#32696)


Signed-off-by: default avatarWangHaoyuuu <mailwhaoyu@gmail.com>
parent 13376576
......@@ -657,7 +657,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A<sup>+</sup> | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-2601-hf` | ✅︎ | ✅︎ |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
| `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
......
......@@ -70,6 +70,34 @@ def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
)
# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
model_name = "nvidia/music-flamingo-2601-hf"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
# MusicFlamingo uses <sound> token for audio
audio_placeholder = "<sound>" * audio_count
prompt = (
"<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_placeholder}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
model_name = "google/gemma-3n-E2B-it"
......@@ -452,6 +480,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"audioflamingo3": run_audioflamingo3,
"musicflamingo": run_musicflamingo,
"gemma3n": run_gemma3n,
"glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
......
......@@ -657,6 +657,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0"
),
"MusicFlamingoForConditionalGeneration": _HfExamplesInfo(
"nvidia/music-flamingo-2601-hf", min_transformers_version="5.0.0.dev"
),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
"BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"),
"BeeForConditionalGeneration": _HfExamplesInfo(
......
......@@ -128,6 +128,12 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
super().__init__(config)
self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
# self.layer_norm is already initialized in super().__init__
# Keep a dummy freqs parameter for MusicFlamingo checkpoints.
self.pos_emb = nn.Module()
freqs = torch.empty(getattr(config, "num_mel_bins", 128))
self.pos_emb.register_parameter(
"freqs", nn.Parameter(freqs, requires_grad=False)
)
def forward(
self,
......@@ -146,7 +152,8 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
).to(hidden_states.dtype)
for layer in self.layers:
layer_outputs = layer(hidden_states, attention_mask)
# Qwen2AudioEncoderLayer expects layer_head_mask as third arg.
layer_outputs = layer(hidden_states, attention_mask, None)
hidden_states = layer_outputs[0]
# AvgPool (time/2) + LayerNorm
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""MusicFlamingo model adapter.
MusicFlamingo shares the AudioFlamingo3 architecture, so we reuse the same
implementation and multimodal processor, while accepting MusicFlamingo config
and processor classes when available.
"""
from collections.abc import Mapping
from transformers.models.audioflamingo3 import (
AudioFlamingo3Config,
AudioFlamingo3Processor,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import BaseProcessingInfo
from .audioflamingo3 import (
AudioFlamingo3DummyInputsBuilder,
AudioFlamingo3ForConditionalGeneration,
AudioFlamingo3MultiModalProcessor,
)
try:
# Optional dependency: use MusicFlamingo classes when transformers provides them.
from transformers.models.musicflamingo import (
MusicFlamingoConfig,
MusicFlamingoProcessor,
)
except Exception: # pragma: no cover - optional dependency
MusicFlamingoConfig = None
MusicFlamingoProcessor = None
class MusicFlamingoProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
if MusicFlamingoConfig is None:
return self.ctx.get_hf_config(AudioFlamingo3Config)
return self.ctx.get_hf_config((MusicFlamingoConfig, AudioFlamingo3Config))
def get_hf_processor(self, **kwargs: object):
if MusicFlamingoProcessor is None:
return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)
# Tuple triggers AutoProcessor path and accepts either processor class.
return self.ctx.get_hf_processor(
(MusicFlamingoProcessor, AudioFlamingo3Processor), **kwargs
)
def get_feature_extractor(self, **kwargs: object):
hf_processor = self.get_hf_processor(**kwargs)
return hf_processor.feature_extractor
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):
pass
@MULTIMODAL_REGISTRY.register_processor(
AudioFlamingo3MultiModalProcessor,
info=MusicFlamingoProcessingInfo,
dummy_inputs=MusicFlamingoDummyInputsBuilder,
)
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
"""MusicFlamingo model for conditional generation."""
......@@ -286,6 +286,10 @@ _MULTIMODAL_MODELS = {
"audioflamingo3",
"AudioFlamingo3ForConditionalGeneration",
),
"MusicFlamingoForConditionalGeneration": (
"musicflamingo",
"MusicFlamingoForConditionalGeneration",
),
"AyaVisionForConditionalGeneration": (
"aya_vision",
"AyaVisionForConditionalGeneration",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment