Unverified commit e7767ecc, authored by Lasha Koroshinadze and committed by GitHub
Browse files

Fix AudioFlamingo3/MusicFlamingo HF parity and RoTE handling (#37643)


Signed-off-by: Lasha <26011196+lashahub@users.noreply.github.com>
parent 43877a62
...@@ -535,7 +535,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -535,7 +535,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
| ------------ | ------ | ------ | ----------------- | -------------------- | ------------------------- | | ------------ | ------ | ------ | ----------------- | -------------------- | ------------------------- |
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | | `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | |
| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-2601-hf` | ✅︎ | ✅︎ | | `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ | | `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
| `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ | | `BagelForConditionalGeneration` | BAGEL | T + I<sup>+</sup> | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ | | `BeeForConditionalGeneration` | Bee-8B | T + I<sup>E+</sup> | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
...@@ -586,6 +586,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen ...@@ -586,6 +586,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ |
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | | `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ |
| `Molmo2ForConditionalGeneration` | Molmo2 | T + I<sup>+</sup> / V | `allenai/Molmo2-4B`, `allenai/Molmo2-8B`, `allenai/Molmo2-O-7B` | ✅︎ | ✅︎ | | `Molmo2ForConditionalGeneration` | Molmo2 | T + I<sup>+</sup> / V | `allenai/Molmo2-4B`, `allenai/Molmo2-8B`, `allenai/Molmo2-O-7B` | ✅︎ | ✅︎ |
| `MusicFlamingoForConditionalGeneration` | MusicFlamingo | T + A | `nvidia/music-flamingo-2601-hf`, `nvidia/music-flamingo-think-2601-hf` | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
| `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + I<sup>E+</sup> | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ | | `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + I<sup>E+</sup> | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ |
| `OpenPanguVLForConditionalGeneration` | openpangu-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `FreedomIntelligence/openPangu-VL-7B` | ✅︎ | ✅︎ | | `OpenPanguVLForConditionalGeneration` | openpangu-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `FreedomIntelligence/openPangu-VL-7B` | ✅︎ | ✅︎ |
......
...@@ -104,12 +104,22 @@ def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData: ...@@ -104,12 +104,22 @@ def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
enforce_eager=True, enforce_eager=True,
) )
# MusicFlamingo uses <sound> token for audio # MusicFlamingo prompt placeholders use <sound>; vLLM's MusicFlamingo
# multimodal processor expands each one into <|sound_bos|> + audio tokens +
# <|sound_eos|> based on extracted audio feature lengths.
audio_placeholder = "<sound>" * audio_count audio_placeholder = "<sound>" * audio_count
system_prompt = (
"You are Music Flamingo, a multimodal assistant for language and music. "
"On each turn you receive an audio clip which contains music and optional "
"text, you will receive at least one or both; use your world knowledge and "
"reasoning to help the user with any task. Interpret the entirety of the "
"content any input music--regardlenss of whether the user calls it audio, "
"music, or sound."
)
prompt = ( prompt = (
"<|im_start|>system\n" "<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n" f"{system_prompt}<|im_end|>\n"
"<|im_start|>user\n" "<|im_start|>user\n"
f"{audio_placeholder}{question}<|im_end|>\n" f"{audio_placeholder}{question}<|im_end|>\n"
"<|im_start|>assistant\n" "<|im_start|>assistant\n"
......
{"transcriptions": ["The content of the input audio is 'you can ask why over and over and over again forever even if one day we explain every physical interaction and scientific law and hope and dream and regret with a single elegant equation'."], "token_ids": [[785, 2213, 315, 279, 1946, 7699, 374, 364, 9330, 646, 2548, 3170, 916, 323, 916, 323, 916, 1549, 15683, 1496, 421, 825, 1899, 582, 10339, 1449, 6961, 16230, 323, 12344, 2329, 323, 3900, 323, 7904, 323, 22231, 448, 264, 3175, 25777, 23606, 4427, 151645]]} {"transcriptions": ["There is no clear relationship between the barking and the music, as they seem to be independent of each other."], "token_ids": [[3862, 374, 902, 2797, 5025, 1948, 279, 293, 33452, 323, 279, 4627, 11, 438, 807, 2803, 311, 387, 9489, 315, 1817, 1008, 13, 151645]]}
\ No newline at end of file
{"transcriptions": ["This track is an energetic Eurodance / Dance‑Pop anthem that blends the bright, melodic sensibilities of mainstream pop with the driving, club‑ready pulse of classic Eurodance. The duration of the piece is ", "**Verse 1**\nMidnight cravings in bloom, lights flicker in the room, pepperoni dreams arise, pizza party on your skies\n\n**Verse 2**\nCheese melts on the crust, in flavor we trust, boxes stacked to the"], "token_ids": [[1986, 3754, 374, 458, 44855, 19461, 98875, 378, 107, 14, 378, 107, 35, 681, 55964, 11598, 55564, 429, 57843, 279, 9906, 11, 10581, 52760, 6097, 13450, 315, 20729, 2420, 448, 279, 9842, 11, 6335, 55964, 2307, 27235, 315, 11416, 19461, 98875, 13, 220, 576, 8090, 315, 279, 6573, 374, 220], [334, 68043, 220, 16, 1019, 33648, 9287, 88828, 304, 51454, 11, 12711, 28347, 261, 304, 279, 3054, 11, 24353, 20783, 18707, 30789, 11, 22502, 4614, 389, 697, 49293, 271, 334, 68043, 220, 17, 1019, 26843, 2367, 98091, 389, 279, 39612, 11, 304, 17172, 582, 6950, 11, 14697, 41315, 311, 279]]}
{"transcriptions": ["This track is an energetic Eurodance / Dance‑Pop anthem that blends the bright, melodic sensibilities of mainstream pop with the driving, club‑ready pulse of classic Eurodance. The duration of the piece is "], "token_ids": [[1986, 3754, 374, 458, 44855, 19461, 98875, 378, 107, 14, 378, 107, 35, 681, 55964, 11598, 55564, 429, 57843, 279, 9906, 11, 10581, 52760, 6097, 13450, 315, 20729, 2420, 448, 279, 9842, 11, 6335, 55964, 2307, 27235, 315, 11416, 19461, 98875, 13, 220, 576, 8090, 315, 279, 6573, 374, 220]]}
...@@ -26,6 +26,54 @@ from tests.models.registry import HF_EXAMPLE_MODELS ...@@ -26,6 +26,54 @@ from tests.models.registry import HF_EXAMPLE_MODELS
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
MODEL_NAME = "nvidia/audio-flamingo-3-hf" MODEL_NAME = "nvidia/audio-flamingo-3-hf"
def _user_turn(text, audio_url):
    """Return one user message: a text part followed by an audio URL part."""
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": text},
            {"type": "audio_url", "audio_url": {"url": audio_url}},
        ],
    }


# One-turn conversation; it is also reused as the first entry of
# BATCHED_CONVERSATIONS below.
SINGLE_CONVERSATION = [
    _user_turn(
        "What is surprising about the relationship between the barking and "
        "the music?",
        "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/"
        "assets/dogs_barking_in_sync_with_the_music.wav",
    )
]

# Two independent conversations submitted together in one chat call.
BATCHED_CONVERSATIONS = [
    SINGLE_CONVERSATION,
    [
        _user_turn(
            "Why is the philosopher's name mentioned in the lyrics? "
            "(A) To express a sense of nostalgia "
            "(B) To indicate that language cannot express clearly, "
            "satirizing the inversion of black and white in the world "
            "(C) To add depth and complexity to the lyrics "
            "(D) To showcase the wisdom and influence of the philosopher",
            "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/"
            "assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav",
        )
    ],
]
def get_fixture_path(filename): def get_fixture_path(filename):
...@@ -34,21 +82,29 @@ def get_fixture_path(filename): ...@@ -34,21 +82,29 @@ def get_fixture_path(filename):
) )
def assert_output_matches(output, expected_text, expected_token_ids):
    """Check the first completion of ``output`` against recorded references.

    The completion text is compared after stripping surrounding whitespace.
    The token ids must equal ``expected_token_ids`` exactly, or differ from
    them by a single extra trailing token on either side (presumably an
    end-of-sequence token -- confirm against the fixture generator).

    Raises:
        AssertionError: if the text or the token ids do not match.
    """
    completion = output.outputs[0]
    assert completion.text.strip() == expected_text

    token_ids = list(completion.token_ids)
    if token_ids == expected_token_ids:
        return
    # Generated sequence is missing the reference's final token.
    if token_ids == expected_token_ids[:-1]:
        return
    # Generated sequence carries one token more than the reference.
    assert token_ids[:-1] == expected_token_ids
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
# Check if the model is supported by the current transformers version
model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration") model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
try: try:
llm = LLM( return LLM(
model=MODEL_NAME, model=MODEL_NAME,
trust_remote_code=True,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=True, enforce_eager=True,
limit_mm_per_prompt={"audio": 1}, limit_mm_per_prompt={"audio": 1},
) )
return llm
except Exception as e: except Exception as e:
pytest.skip(f"Failed to load model {MODEL_NAME}: {e}") pytest.skip(f"Failed to load model {MODEL_NAME}: {e}")
...@@ -61,29 +117,17 @@ def test_single_generation(llm): ...@@ -61,29 +117,17 @@ def test_single_generation(llm):
with open(fixture_path) as f: with open(fixture_path) as f:
expected = json.load(f) expected = json.load(f)
audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Why_do_we_ask_questions_converted.wav"
messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": audio_url}},
{"type": "text", "text": "Transcribe the input speech."},
],
}
]
sampling_params = SamplingParams(temperature=0.0, max_tokens=128) sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.chat( outputs = llm.chat(
messages=messages, messages=SINGLE_CONVERSATION,
sampling_params=sampling_params, sampling_params=sampling_params,
) )
generated_text = outputs[0].outputs[0].text.strip() assert_output_matches(
outputs[0],
expected_text = expected["transcriptions"][0] expected["transcriptions"][0],
expected["token_ids"][0],
assert expected_text in generated_text or generated_text in expected_text )
def test_batched_generation(llm): def test_batched_generation(llm):
...@@ -94,49 +138,34 @@ def test_batched_generation(llm): ...@@ -94,49 +138,34 @@ def test_batched_generation(llm):
with open(fixture_path) as f: with open(fixture_path) as f:
expected = json.load(f) expected = json.load(f)
items = [
{
"audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav",
"question": "What is surprising about the relationship "
"between the barking and the music?",
"expected_idx": 0,
},
{
"audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav",
"question": (
"Why is the philosopher's name mentioned in the lyrics? "
"(A) To express a sense of nostalgia "
"(B) To indicate that language cannot express clearly, "
"satirizing the inversion of black and white in the world "
"(C) To add depth and complexity to the lyrics "
"(D) To showcase the wisdom and influence of the philosopher"
),
"expected_idx": 1,
},
]
conversations = []
for item in items:
messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": item["audio_url"]}},
{"type": "text", "text": item["question"]},
],
}
]
conversations.append(messages)
sampling_params = SamplingParams(temperature=0.0, max_tokens=128) sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.chat( outputs = llm.chat(
messages=conversations, messages=BATCHED_CONVERSATIONS,
sampling_params=sampling_params, sampling_params=sampling_params,
) )
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
generated_text = output.outputs[0].text.strip() assert_output_matches(
expected_text = expected["transcriptions"][i] output,
expected["transcriptions"][i],
expected["token_ids"][i],
)
def test_single_and_batched_generation_match(llm):
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
assert expected_text in generated_text or generated_text in expected_text single_output = llm.chat(
messages=SINGLE_CONVERSATION,
sampling_params=sampling_params,
)[0]
batched_output = llm.chat(
messages=BATCHED_CONVERSATIONS,
sampling_params=sampling_params,
)[0]
assert single_output.outputs[0].text == batched_output.outputs[0].text
assert list(single_output.outputs[0].token_ids) == list(
batched_output.outputs[0].token_ids
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import pytest
from tests.models.registry import HF_EXAMPLE_MODELS
from vllm import LLM, SamplingParams
# Checkpoint loaded by the module-scoped engine fixture.
MODEL_NAME = "nvidia/music-flamingo-2601-hf"


def _user_turn(text, audio_url):
    """Return one user message: a text part followed by an audio URL part."""
    return {
        "role": "user",
        "content": [
            {"type": "text", "text": text},
            {"type": "audio_url", "audio_url": {"url": audio_url}},
        ],
    }


# One-turn conversation; it is also reused as the first entry of
# BATCHED_CONVERSATIONS below.
SINGLE_CONVERSATION = [
    _user_turn(
        "Describe this track in full detail - tell me the genre, tempo, and "
        "key, then dive into the instruments, production style, and overall "
        "mood it creates.",
        "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/"
        "assets/song_1.mp3",
    )
]

# Two independent conversations submitted together in one chat call.
BATCHED_CONVERSATIONS = [
    SINGLE_CONVERSATION,
    [
        _user_turn(
            "Generate a structured lyric sheet from the input music.",
            "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/"
            "assets/song_2.mp3",
        )
    ],
]
def get_fixture_path(filename):
    """Return the path of ``filename`` inside the MusicFlamingo fixture folder.

    The folder is addressed relative to this module's directory; the path is
    joined but not normalised, so the ``../..`` segments are kept verbatim.
    """
    module_dir = os.path.dirname(__file__)
    fixtures_dir = os.path.join(module_dir, "../../fixtures/musicflamingo")
    return os.path.join(fixtures_dir, filename)
def assert_output_matches(output, expected_text, expected_token_ids):
    """Check the first completion of ``output`` against recorded references.

    The completion text must equal ``expected_text`` exactly (no whitespace
    stripping). The token ids must equal ``expected_token_ids``, or differ
    from them by a single extra trailing token on either side.

    Raises:
        AssertionError: if the text or the token ids do not match.
    """
    sample = output.outputs[0]
    assert sample.text == expected_text

    produced = list(sample.token_ids)
    # Each pair is (generated side, reference side) for one accepted form:
    # exact match, reference one token longer, generation one token longer.
    accepted_pairs = (
        (produced, expected_token_ids),
        (produced, expected_token_ids[:-1]),
        (produced[:-1], expected_token_ids),
    )
    assert any(got == want for got, want in accepted_pairs)
@pytest.fixture(scope="module")
def llm():
    """Module-scoped engine for ``MODEL_NAME``.

    The registry entry's transformers version check runs first with
    ``on_fail="skip"``. Any exception raised while constructing the engine is
    turned into a skip rather than a test failure.
    """
    HF_EXAMPLE_MODELS.get_hf_info(
        "MusicFlamingoForConditionalGeneration"
    ).check_transformers_version(on_fail="skip")

    engine_kwargs = {
        "model": MODEL_NAME,
        "dtype": "bfloat16",
        "enforce_eager": True,
        "max_model_len": 8192,
        "limit_mm_per_prompt": {"audio": 1},
    }
    try:
        engine = LLM(**engine_kwargs)
    except Exception as e:
        # Deliberately broad: a model that cannot be loaded skips the module.
        pytest.skip(f"Failed to load model {MODEL_NAME}: {e}")
    return engine
def test_single_generation(llm):
    """Compare greedy output for ``SINGLE_CONVERSATION`` with the fixture.

    Skips when ``expected_results_single.json`` is not present on disk.
    """
    fixture_path = get_fixture_path("expected_results_single.json")
    if not os.path.exists(fixture_path):
        pytest.skip(f"Fixture not found: {fixture_path}")

    # Recorded transcriptions can contain non-ASCII punctuation, so decode
    # explicitly as UTF-8 instead of relying on the platform default encoding.
    with open(fixture_path, encoding="utf-8") as f:
        expected = json.load(f)

    outputs = llm.chat(
        messages=SINGLE_CONVERSATION,
        sampling_params=SamplingParams(temperature=0.0, max_tokens=50),
    )

    assert_output_matches(
        outputs[0],
        expected["transcriptions"][0],
        expected["token_ids"][0],
    )
def test_batched_generation(llm):
    """Compare greedy output for ``BATCHED_CONVERSATIONS`` with the fixture.

    Skips when ``expected_results_batched.json`` is not present on disk.
    """
    fixture_path = get_fixture_path("expected_results_batched.json")
    if not os.path.exists(fixture_path):
        pytest.skip(f"Fixture not found: {fixture_path}")

    # Recorded transcriptions can contain non-ASCII punctuation, so decode
    # explicitly as UTF-8 instead of relying on the platform default encoding.
    with open(fixture_path, encoding="utf-8") as f:
        expected = json.load(f)

    outputs = llm.chat(
        messages=BATCHED_CONVERSATIONS,
        sampling_params=SamplingParams(temperature=0.0, max_tokens=50),
    )

    # Without this check a short (or empty) result list would make the loop
    # below pass without comparing every conversation.
    assert len(outputs) == len(BATCHED_CONVERSATIONS)

    for i, output in enumerate(outputs):
        assert_output_matches(
            output,
            expected["transcriptions"][i],
            expected["token_ids"][i],
        )
def test_single_and_batched_generation_match(llm):
    """Greedy output for the shared conversation must not depend on batching.

    The same conversation is generated once alone and once as the first
    entry of the batch; text and token ids of the two results must be equal.
    """
    greedy = SamplingParams(temperature=0.0, max_tokens=50)

    alone = llm.chat(
        messages=SINGLE_CONVERSATION,
        sampling_params=greedy,
    )[0].outputs[0]
    in_batch = llm.chat(
        messages=BATCHED_CONVERSATIONS,
        sampling_params=greedy,
    )[0].outputs[0]

    assert alone.text == in_batch.text
    assert list(alone.token_ids) == list(in_batch.token_ids)
...@@ -40,6 +40,7 @@ class MockAudioFlamingo3Processor: ...@@ -40,6 +40,7 @@ class MockAudioFlamingo3Processor:
def __init__(self): def __init__(self):
self.audio_token = "<sound>" self.audio_token = "<sound>"
self.audio_token_id = 12345 self.audio_token_id = 12345
self.max_audio_len = 60
self.feature_extractor = MockFeatureExtractor() self.feature_extractor = MockFeatureExtractor()
def __call__(self, text=None, audios=None, **kwargs): def __call__(self, text=None, audios=None, **kwargs):
...@@ -65,7 +66,6 @@ def mock_ctx(): ...@@ -65,7 +66,6 @@ def mock_ctx():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def check_transformers_version(): def check_transformers_version():
# Check if the model is supported by the current transformers version
model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration") model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
model_info.check_transformers_version(on_fail="skip") model_info.check_transformers_version(on_fail="skip")
...@@ -84,7 +84,7 @@ def test_audio_chunk_counting(mock_ctx): ...@@ -84,7 +84,7 @@ def test_audio_chunk_counting(mock_ctx):
sr = 16000 sr = 16000
audio_1 = np.zeros(30 * sr) audio_1 = np.zeros(30 * sr)
audio_2 = np.zeros(45 * sr) audio_2 = np.zeros(75 * sr)
mm_data = {"audio": [audio_1, audio_2]} mm_data = {"audio": [audio_1, audio_2]}
prompt = "<|user|>Listen.<|end|>" prompt = "<|user|>Listen.<|end|>"
...@@ -121,5 +121,107 @@ def test_dummy_data_generation(mock_ctx): ...@@ -121,5 +121,107 @@ def test_dummy_data_generation(mock_ctx):
assert "audio" in dummy_data assert "audio" in dummy_data
assert len(dummy_data["audio"]) == 2 assert len(dummy_data["audio"]) == 2
expected_len = 600 * 16000 expected_len = 60 * 16000
assert len(dummy_data["audio"][0]) == expected_len assert len(dummy_data["audio"][0]) == expected_len
def test_audio_token_count_matches_hf_processor_math():
    """Pin ``_count_audio_tokens_from_mask`` on a hand-built mask.

    Three mask rows are grouped by ``chunk_counts`` into two audios (two rows
    for the first, one for the second); the expected token counts are the
    literal values asserted below.
    """
    from vllm.model_executor.models.audioflamingo3 import (
        _count_audio_tokens_from_mask,
    )

    # Number of leading positions set to 1 in each row of the mask.
    valid_lengths = (2999, 2999, 1500)
    feature_attention_mask = torch.zeros(
        (len(valid_lengths), 3000), dtype=torch.long
    )
    for row, valid in enumerate(valid_lengths):
        feature_attention_mask[row, :valid] = 1

    chunk_counts = torch.tensor([2, 1], dtype=torch.long)

    expected_by_audio = {0: 1499, 1: 375}
    for audio_idx, expected_tokens in expected_by_audio.items():
        actual_tokens = _count_audio_tokens_from_mask(
            feature_attention_mask, chunk_counts, audio_idx
        )
        assert actual_tokens == expected_tokens
def test_audio_feature_pipeline_matches_hf_small_config():
    """Compare vLLM's audio feature path with the HF model on a tiny config.

    A small, randomly initialised HF ``AudioFlamingo3ForConditionalGeneration``
    is created, the state dicts of its audio tower and projector are loaded
    into the vLLM modules, and both are run on the same padded input features.
    The flattened vLLM embeddings must be close to the HF ``pooler_output``.
    """
    from transformers.models.audioflamingo3 import (
        modeling_audioflamingo3 as hf_audioflamingo3_modeling,
    )
    from transformers.models.audioflamingo3.configuration_audioflamingo3 import (
        AudioFlamingo3Config,
    )

    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3Encoder,
        AudioFlamingo3MultiModalProjector,
        _build_audio_encoder_attention_mask,
        _flatten_valid_audio_embeddings,
    )

    tiny_text_config = {
        "model_type": "qwen2",
        "intermediate_size": 64,
        "initializer_range": 0.02,
        "hidden_size": 32,
        "max_position_embeddings": 1024,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "vocab_size": 128,
        "pad_token_id": 1,
        "use_mrope": False,
    }
    tiny_audio_config = {
        "hidden_size": 16,
        "num_attention_heads": 4,
        "intermediate_size": 32,
        "num_hidden_layers": 2,
        "num_mel_bins": 80,
        "max_source_positions": 1500,
        "dropout": 0.0,
        "attention_dropout": 0.0,
        "activation_dropout": 0.0,
        "encoder_layerdrop": 0.0,
    }

    # Seed before any module is constructed so the run is reproducible.
    torch.manual_seed(0)
    config = AudioFlamingo3Config(
        text_config=tiny_text_config,
        audio_config=tiny_audio_config,
        audio_token_id=0,
    )
    reference = hf_audioflamingo3_modeling.AudioFlamingo3ForConditionalGeneration(
        config
    ).eval()

    # vLLM modules receive the reference weights so only the code paths differ.
    encoder = AudioFlamingo3Encoder(config.audio_config).eval()
    encoder.load_state_dict(reference.audio_tower.state_dict())
    projector = AudioFlamingo3MultiModalProjector(config).eval()
    projector.load_state_dict(reference.multi_modal_projector.state_dict())

    input_features = torch.randn(3, 80, 3000)
    # One fully valid row and two rows with trailing padding.
    valid_lengths = (3000, 2500, 1500)
    feature_attention_mask = torch.zeros(3, 3000, dtype=torch.bool)
    for row, valid in enumerate(valid_lengths):
        feature_attention_mask[row, :valid] = True

    hf_output = reference.get_audio_features(
        input_features,
        feature_attention_mask,
        return_dict=True,
    ).pooler_output

    conv_weight = encoder.conv1.weight
    encoder_attention_mask = _build_audio_encoder_attention_mask(
        feature_attention_mask,
        dtype=conv_weight.dtype,
        device=conv_weight.device,
    )
    hidden_states = encoder(
        input_features,
        attention_mask=encoder_attention_mask,
    )
    projected = projector(hidden_states)
    vllm_output, _ = _flatten_valid_audio_embeddings(
        projected,
        feature_attention_mask,
    )

    torch.testing.assert_close(vllm_output, hf_output)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2026 The vLLM team.
# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
from transformers import PretrainedConfig
from tests.models.registry import HF_EXAMPLE_MODELS
class MockMusicFlamingoConfig(PretrainedConfig):
    """Stand-in config carrying empty audio and text sub-configs."""

    model_type = "musicflamingo"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Both sub-configs are bare PretrainedConfig instances.
        for sub_config_name in ("audio_config", "text_config"):
            setattr(self, sub_config_name, PretrainedConfig())
class MockMusicFlamingoProcessor:
    """Stand-in for the HF processor, exposing fixed token attributes."""

    def __init__(self):
        # Placeholder token and the begin/end markers, each with an
        # arbitrary test-only id.
        self.audio_token, self.audio_token_id = "<sound>", 12345
        self.audio_bos_token, self.audio_bos_token_id = "<|sound_bos|>", 12346
        self.audio_eos_token, self.audio_eos_token_id = "<|sound_eos|>", 12347
        # Presumably expressed in seconds -- confirm against the real processor.
        self.max_audio_len = 1200
        self.feature_extractor = MockFeatureExtractor()
class MockFeatureExtractor:
    """Stand-in feature extractor exposing only the two attributes set here."""

    def __init__(self):
        # Fixed test constants; presumably Hz and seconds -- confirm against
        # the real feature extractor.
        self.sampling_rate, self.chunk_length = 16000, 30
@pytest.fixture
def mock_ctx():
    """Processing context mock wired to the mock config and mock processor.

    The same config instance is returned by ``get_hf_config()`` and exposed
    as ``model_config.hf_config``.
    """
    hf_config = MockMusicFlamingoConfig()
    context = MagicMock()
    context.model_config.hf_config = hf_config
    context.get_hf_config.return_value = hf_config
    context.get_hf_processor.return_value = MockMusicFlamingoProcessor()
    return context
@pytest.fixture(autouse=True)
def check_transformers_version():
    """Run the registry's transformers version check before every test.

    The check is invoked with ``on_fail="skip"``.
    """
    HF_EXAMPLE_MODELS.get_hf_info(
        "MusicFlamingoForConditionalGeneration"
    ).check_transformers_version(on_fail="skip")
def test_musicflamingo_chunk_counting_uses_rote_timestamps(mock_ctx, monkeypatch):
    """Check per-audio chunk counts and that ``rote_timestamps`` is kept.

    ``BaseMultiModalProcessor._call_hf_processor`` is replaced by a stub that
    returns canned tensors, so the result asserted here comes from whatever
    the MusicFlamingo processor adds on top of that call (presumably its own
    override of ``_call_hf_processor`` -- confirm in the model module).
    """
    from vllm.model_executor.models.musicflamingo import (
        MusicFlamingoDummyInputsBuilder,
        MusicFlamingoMultiModalProcessor,
        MusicFlamingoProcessingInfo,
    )
    from vllm.multimodal.processing import BaseMultiModalProcessor

    def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs):
        # All arguments are ignored; the stub only supplies canned outputs.
        return {
            "input_ids": [1, 2, 3],
            "input_features": torch.randn(3, 80, 3000),
            "rote_timestamps": torch.randn(3, 750),
        }

    processing_info = MusicFlamingoProcessingInfo(mock_ctx)
    processor = MusicFlamingoMultiModalProcessor(
        processing_info, MusicFlamingoDummyInputsBuilder(processing_info)
    )

    sampling_rate = 16000
    # Two silent clips of 30x and 45x the sampling rate in samples.
    clips = [np.zeros(multiple * sampling_rate) for multiple in (30, 45)]
    mm_data = {"audio": clips}
    prompt = "<|user|>Listen.<|end|>"

    monkeypatch.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call)
    processed = processor._call_hf_processor(prompt, mm_data, {}, {})

    assert processed["chunk_counts"].tolist() == [1, 2]
    assert "rote_timestamps" in processed
def test_musicflamingo_dummy_text_uses_plain_audio_tokens(mock_ctx):
    """Dummy text for two audios is the audio placeholder repeated twice."""
    from vllm.model_executor.models.musicflamingo import (
        MusicFlamingoDummyInputsBuilder,
        MusicFlamingoProcessingInfo,
    )

    dummy_builder = MusicFlamingoDummyInputsBuilder(
        MusicFlamingoProcessingInfo(mock_ctx)
    )
    dummy_text = dummy_builder.get_dummy_text({"audio": 2})

    assert dummy_text == "<sound><sound>"
def test_musicflamingo_audio_feature_pipeline_matches_hf_small_config():
    """Compare vLLM's audio feature path with the HF model on a tiny config.

    A small, randomly initialised HF ``MusicFlamingoForConditionalGeneration``
    is created and the state dicts of its audio tower, projector and
    positional embedding are loaded into the vLLM modules. Both sides are run
    on the same padded features and timestamps; the flattened vLLM embeddings
    must be close to the HF ``pooler_output``.
    """
    from transformers.models.musicflamingo import (
        modeling_musicflamingo as hf_musicflamingo_modeling,
    )
    from transformers.models.musicflamingo.configuration_musicflamingo import (
        MusicFlamingoConfig,
    )

    from vllm.model_executor.models.audioflamingo3 import (
        _build_audio_encoder_attention_mask,
        _flatten_valid_audio_embeddings,
    )
    from vllm.model_executor.models.musicflamingo import (
        MusicFlamingoEncoder,
        MusicFlamingoMultiModalProjector,
        MusicFlamingoRotaryEmbedding,
        apply_rotary_time_emb,
    )

    tiny_text_config = {
        "model_type": "qwen2",
        "intermediate_size": 64,
        "initializer_range": 0.02,
        "hidden_size": 32,
        "max_position_embeddings": 1024,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "vocab_size": 128,
        "pad_token_id": 1,
        "use_mrope": False,
    }
    tiny_audio_config = {
        "hidden_size": 16,
        "num_attention_heads": 4,
        "intermediate_size": 32,
        "num_hidden_layers": 2,
        "num_mel_bins": 80,
        "max_source_positions": 1500,
        "dropout": 0.0,
        "attention_dropout": 0.0,
        "activation_dropout": 0.0,
        "encoder_layerdrop": 0.0,
    }

    # Seed before any module is constructed so the run is reproducible.
    torch.manual_seed(0)
    config = MusicFlamingoConfig(
        text_config=tiny_text_config,
        audio_config=tiny_audio_config,
        audio_token_id=0,
        head_dim=8,
        rope_parameters={"rope_type": "default", "rope_theta": 2048},
    )
    reference = hf_musicflamingo_modeling.MusicFlamingoForConditionalGeneration(
        config
    ).eval()

    # vLLM modules receive the reference weights so only the code paths differ.
    encoder = MusicFlamingoEncoder(config.audio_config).eval()
    encoder.load_state_dict(reference.audio_tower.state_dict())
    projector = MusicFlamingoMultiModalProjector(config).eval()
    projector.load_state_dict(reference.multi_modal_projector.state_dict())
    rotary = MusicFlamingoRotaryEmbedding(config).eval()
    # Non-strict load: key mismatches between the two modules are tolerated.
    rotary.load_state_dict(reference.pos_emb.state_dict(), strict=False)

    input_features = torch.randn(3, 80, 3000)
    # One fully valid row and two rows with trailing padding.
    valid_lengths = (3000, 2500, 1500)
    feature_attention_mask = torch.zeros(3, 3000, dtype=torch.bool)
    for row, valid in enumerate(valid_lengths):
        feature_attention_mask[row, :valid] = True

    # Identical timestamps for every row: 750 steps spaced 0.04 apart.
    step_times = torch.arange(750, dtype=torch.float32) * 0.04
    rote_timestamps = step_times.unsqueeze(0).repeat(3, 1)

    hf_output = reference.get_audio_features(
        input_features,
        feature_attention_mask,
        rote_timestamps=rote_timestamps,
        return_dict=True,
    ).pooler_output

    conv_weight = encoder.conv1.weight
    encoder_attention_mask = _build_audio_encoder_attention_mask(
        feature_attention_mask,
        dtype=conv_weight.dtype,
        device=conv_weight.device,
    )
    hidden_states = encoder(
        input_features,
        attention_mask=encoder_attention_mask,
    )

    # Rotate the encoder output by the time embedding before projecting.
    cos, sin = rotary(rote_timestamps, seq_len=hidden_states.shape[-2])
    hidden_states = apply_rotary_time_emb(hidden_states, cos, sin)

    projected = projector(hidden_states)
    vllm_output, _ = _flatten_valid_audio_embeddings(
        projected,
        feature_attention_mask,
    )

    torch.testing.assert_close(vllm_output, hf_output)
...@@ -752,7 +752,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -752,7 +752,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0" "nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0"
), ),
"MusicFlamingoForConditionalGeneration": _HfExamplesInfo( "MusicFlamingoForConditionalGeneration": _HfExamplesInfo(
"nvidia/music-flamingo-2601-hf", min_transformers_version="5.0.0.dev" "nvidia/music-flamingo-2601-hf", min_transformers_version="5.3.0"
), ),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"), "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
"BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"), "BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"),
......
...@@ -69,10 +69,7 @@ from .utils import ( ...@@ -69,10 +69,7 @@ from .utils import (
maybe_prefix, maybe_prefix,
) )
MAX_AUDIO_LEN = 10 * 60
# === Audio Inputs === #
class AudioFlamingo3FeatureInputs(TensorSchema): class AudioFlamingo3FeatureInputs(TensorSchema):
""" """
Dimensions: Dimensions:
...@@ -127,14 +124,12 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder): ...@@ -127,14 +124,12 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
): ):
super().__init__(config) super().__init__(config)
self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2) self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
# self.layer_norm is already initialized in super().__init__
def forward( def forward(
self, self,
input_features: torch.Tensor | list[torch.Tensor], input_features: torch.Tensor | list[torch.Tensor],
attention_mask: torch.Tensor = None, attention_mask: torch.Tensor = None,
): ):
# input_features: (batch, num_mel_bins, seq_len)
if isinstance(input_features, list): if isinstance(input_features, list):
input_features = torch.stack(input_features) input_features = torch.stack(input_features)
...@@ -146,17 +141,14 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder): ...@@ -146,17 +141,14 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
).to(hidden_states.dtype) ).to(hidden_states.dtype)
for layer in self.layers: for layer in self.layers:
# Qwen2AudioEncoderLayer expects layer_head_mask as third arg. layer_outputs = layer(hidden_states, attention_mask)
layer_outputs = layer(hidden_states, attention_mask, None) hidden_states = (
hidden_states = layer_outputs[0] layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
)
# AvgPool (time/2) + LayerNorm hidden_states = hidden_states.permute(0, 2, 1)
# hidden_states: (batch, seq_len, hidden_size)
hidden_states = hidden_states.permute(0, 2, 1) # (batch, hidden_size, seq_len)
hidden_states = self.avg_pooler(hidden_states) hidden_states = self.avg_pooler(hidden_states)
hidden_states = hidden_states.permute( hidden_states = hidden_states.permute(0, 2, 1)
0, 2, 1
) # (batch, seq_len/2, hidden_size)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
return hidden_states return hidden_states
...@@ -193,22 +185,6 @@ class AudioFlamingo3MultiModalProjector(nn.Module): ...@@ -193,22 +185,6 @@ class AudioFlamingo3MultiModalProjector(nn.Module):
return hidden_states return hidden_states
class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[Any],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"audio_embeds"},
fields_factory=_audioflamingo3_field_config,
)
return super()._parse_audio_data(data)
class AudioFlamingo3ProcessingInfo(BaseProcessingInfo): class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(AudioFlamingo3Config) return self.ctx.get_hf_config(AudioFlamingo3Config)
...@@ -217,20 +193,17 @@ class AudioFlamingo3ProcessingInfo(BaseProcessingInfo): ...@@ -217,20 +193,17 @@ class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs) return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)
def get_feature_extractor(self, **kwargs: object): def get_feature_extractor(self, **kwargs: object):
hf_processor = self.get_hf_processor(**kwargs) return self.get_hf_processor(**kwargs).feature_extractor
feature_extractor = hf_processor.feature_extractor
return feature_extractor
def get_data_parser(self): def get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
return AudioFlamingo3MultiModalDataParser( return AudioFlamingo3MultiModalDataParser(
target_sr=feature_extractor.sampling_rate, target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(), expected_hidden_size=self._get_expected_hidden_size(),
) )
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1} return {"audio": None}
class AudioFlamingo3DummyInputsBuilder( class AudioFlamingo3DummyInputsBuilder(
...@@ -248,9 +221,10 @@ class AudioFlamingo3DummyInputsBuilder( ...@@ -248,9 +221,10 @@ class AudioFlamingo3DummyInputsBuilder(
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions], mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict: ) -> MultiModalDataDict:
hf_processor = self.info.get_hf_processor()
feature_extractor = self.info.get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate sampling_rate = feature_extractor.sampling_rate
audio_len = MAX_AUDIO_LEN * sampling_rate audio_len = int(hf_processor.max_audio_len * sampling_rate)
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") audio_overrides = mm_options.get("audio")
...@@ -284,6 +258,118 @@ def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]): ...@@ -284,6 +258,118 @@ def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]):
) )
def _get_audio_post_pool_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
conv_lengths = (input_lengths - 1) // 2 + 1
return (conv_lengths - 2) // 2 + 1
def _build_audio_encoder_attention_mask(
feature_attention_mask: torch.Tensor,
*,
dtype: torch.dtype,
device: torch.device,
) -> torch.Tensor:
input_lengths = feature_attention_mask.sum(-1).to(torch.long)
conv_lengths = (input_lengths - 1) // 2 + 1
batch_size, max_mel_seq_len = feature_attention_mask.shape
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
seq_range = (
torch.arange(
max_seq_len,
dtype=conv_lengths.dtype,
device=conv_lengths.device,
)
.unsqueeze(0)
.expand(batch_size, max_seq_len)
)
padding_mask = seq_range >= conv_lengths[:, None]
attention_mask = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
batch_size, 1, max_seq_len, max_seq_len
)
attention_mask = attention_mask.to(dtype=dtype, device=device)
attention_mask.masked_fill_(padding_mask[:, None, None, :], float("-inf"))
return attention_mask
def _flatten_valid_audio_embeddings(
    audio_embeddings: torch.Tensor,
    feature_attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the non-padded embedding rows of every chunk.

    Returns:
        A pair ``(selected, output_lengths)``: the rows of
        ``audio_embeddings`` picked by a boolean mask over its first two
        dimensions, and the per-chunk number of rows that were kept.
    """
    frames_per_chunk = feature_attention_mask.sum(-1).to(torch.long)
    output_lengths = _get_audio_post_pool_output_lengths(frames_per_chunk)

    max_tokens = audio_embeddings.shape[1]
    token_index = torch.arange(max_tokens, device=output_lengths.device)
    # Position t of chunk i is kept iff t < output_lengths[i].
    keep = token_index[None, :] < output_lengths[:, None]
    return audio_embeddings[keep], output_lengths
def _count_audio_tokens_from_mask(
    feature_attention_mask: torch.Tensor | list[torch.Tensor],
    chunk_counts: torch.Tensor | list[torch.Tensor] | list[int] | None,
    item_idx: int,
) -> int:
    """Return the number of placeholder tokens for one audio item.

    Args:
        feature_attention_mask: Per-chunk feature masks, either stacked in a
            tensor or given as a list.
        chunk_counts: Number of chunks owned by each audio item. When
            ``None``, entry ``item_idx`` of ``feature_attention_mask`` is
            taken as the whole mask of that item.
        item_idx: Index of the audio item to count tokens for.

    Returns:
        The post-pooling length computed from the *total* number of set mask
        entries of the item (summed over all of its chunks).

    Raises:
        ValueError: If ``feature_attention_mask`` is a list and the item's
            chunk slice is empty.
    """
    if chunk_counts is None:
        sample_mask = feature_attention_mask[item_idx]
    else:
        if isinstance(chunk_counts, torch.Tensor):
            counts = chunk_counts.tolist()
        else:
            counts = [
                count.item() if isinstance(count, torch.Tensor) else count
                for count in chunk_counts
            ]

        # Chunks are laid out item after item, so the item's slice starts
        # after the chunks of all preceding items.
        start_idx = sum(counts[:item_idx])
        end_idx = start_idx + counts[item_idx]
        sample_mask = feature_attention_mask[start_idx:end_idx]

        if isinstance(sample_mask, (list, tuple)):
            if len(sample_mask) == 0:
                raise ValueError("Expected non-empty audio mask slice.")
            if isinstance(sample_mask[0], torch.Tensor):
                sample_mask = torch.stack(list(sample_mask))

    # Tensors pass through untouched; plain (nested) lists are converted so
    # both code paths share the reduction below.
    sample_mask = torch.as_tensor(sample_mask)

    # Match the HF processor, which derives placeholder lengths from the
    # total pre-encoder feature length for each original audio sample.
    total_frames = sample_mask.sum().reshape(1).to(torch.long)
    post_lengths = _get_audio_post_pool_output_lengths(total_frames)
    return int(post_lengths[0].item())
class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
    """Audio data parser that additionally accepts dict payloads."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse one audio payload.

        A ``dict`` is wrapped in ``DictEmbeddingItems`` (which is told that
        ``audio_embeds`` is a required field); every other payload is handed
        to the base-class parser unchanged.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_audioflamingo3_field_config,
        )
class AudioFlamingo3MultiModalProcessor( class AudioFlamingo3MultiModalProcessor(
BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo] BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo]
): ):
...@@ -303,13 +389,13 @@ class AudioFlamingo3MultiModalProcessor( ...@@ -303,13 +389,13 @@ class AudioFlamingo3MultiModalProcessor(
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
feature_extractor = self.info.get_feature_extractor(**mm_kwargs) processor = self.info.get_hf_processor(**mm_kwargs)
feature_extractor = processor.feature_extractor
mm_kwargs = dict( mm_kwargs = dict(
**mm_kwargs, **mm_kwargs,
sampling_rate=feature_extractor.sampling_rate, sampling_rate=feature_extractor.sampling_rate,
) )
# Calculate chunk counts
audio_list = mm_data.get("audio") audio_list = mm_data.get("audio")
if not isinstance(audio_list, list): if not isinstance(audio_list, list):
audio_list = [audio_list] audio_list = [audio_list]
...@@ -318,8 +404,7 @@ class AudioFlamingo3MultiModalProcessor( ...@@ -318,8 +404,7 @@ class AudioFlamingo3MultiModalProcessor(
sampling_rate = feature_extractor.sampling_rate sampling_rate = feature_extractor.sampling_rate
chunk_length = feature_extractor.chunk_length chunk_length = feature_extractor.chunk_length
window_size = int(sampling_rate * chunk_length) window_size = int(sampling_rate * chunk_length)
# MAX_AUDIO_LEN is 10 * 60 in HF processor. max_windows = int(processor.max_audio_len // chunk_length)
max_windows = int(MAX_AUDIO_LEN // chunk_length)
for audio in audio_list: for audio in audio_list:
# audio is numpy array or list # audio is numpy array or list
...@@ -364,7 +449,6 @@ class AudioFlamingo3MultiModalProcessor( ...@@ -364,7 +449,6 @@ class AudioFlamingo3MultiModalProcessor(
audio_token = getattr(processor, "audio_token", "<sound>") audio_token = getattr(processor, "audio_token", "<sound>")
audio_token_id = vocab.get(audio_token) audio_token_id = vocab.get(audio_token)
if audio_token_id is None: if audio_token_id is None:
# Fallback if not found, though it should be there
audio_token_id = processor.audio_token_id audio_token_id = processor.audio_token_id
out_mm_data = out_mm_kwargs.get_data() out_mm_data = out_mm_kwargs.get_data()
...@@ -373,38 +457,11 @@ class AudioFlamingo3MultiModalProcessor( ...@@ -373,38 +457,11 @@ class AudioFlamingo3MultiModalProcessor(
def get_replacement_audioflamingo3(item_idx: int): def get_replacement_audioflamingo3(item_idx: int):
if feature_attention_mask is not None: if feature_attention_mask is not None:
if chunk_counts is not None: num_features = _count_audio_tokens_from_mask(
counts = ( feature_attention_mask,
chunk_counts.tolist() chunk_counts,
if isinstance(chunk_counts, torch.Tensor) item_idx,
else chunk_counts )
)
start_idx = sum(counts[:item_idx])
count = counts[item_idx]
end_idx = start_idx + count
if isinstance(feature_attention_mask, list):
mask_list = feature_attention_mask[start_idx:end_idx]
if len(mask_list) > 0 and isinstance(
mask_list[0], torch.Tensor
):
mask = torch.stack(mask_list)
else:
mask = torch.tensor(mask_list)
else:
mask = feature_attention_mask[start_idx:end_idx]
else:
# feature_attention_mask is list[Tensor] or Tensor
if isinstance(feature_attention_mask, list):
mask = feature_attention_mask[item_idx]
else:
mask = feature_attention_mask[item_idx].unsqueeze(0)
# mask shape: (num_chunks, 3000)
input_lengths = mask.sum(-1)
conv_lengths = (input_lengths - 1) // 2 + 1
audio_output_lengths = (conv_lengths - 2) // 2 + 1
num_features = audio_output_lengths.sum().item()
else: else:
audio_embeds = out_mm_data["audio_embeds"][item_idx] audio_embeds = out_mm_data["audio_embeds"][item_idx]
num_features = audio_embeds.shape[0] num_features = audio_embeds.shape[0]
...@@ -435,13 +492,6 @@ class AudioFlamingo3MultiModalProcessor( ...@@ -435,13 +492,6 @@ class AudioFlamingo3MultiModalProcessor(
class AudioFlamingo3ForConditionalGeneration( class AudioFlamingo3ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
): ):
"""
AudioFlamingo3 model for conditional generation.
This model integrates a Whisper-based audio encoder with a Qwen2 language model.
It supports multi-chunk audio processing.
"""
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
...@@ -517,6 +567,25 @@ class AudioFlamingo3ForConditionalGeneration( ...@@ -517,6 +567,25 @@ class AudioFlamingo3ForConditionalGeneration(
audio_embeds = audio_input["audio_embeds"] audio_embeds = audio_input["audio_embeds"]
return tuple(audio_embeds) return tuple(audio_embeds)
(
input_features,
feature_attention_mask,
chunk_counts,
) = self._normalize_audio_feature_inputs(audio_input)
audio_hidden_states = self._encode_audio_features(
input_features,
feature_attention_mask,
)
audio_features = self.multi_modal_projector(audio_hidden_states)
return self._group_audio_embeddings(
audio_features,
feature_attention_mask,
chunk_counts,
)
def _normalize_audio_feature_inputs(
self, audio_input: AudioFlamingo3FeatureInputs
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
input_features = audio_input["input_features"] input_features = audio_input["input_features"]
feature_attention_mask = audio_input["feature_attention_mask"] feature_attention_mask = audio_input["feature_attention_mask"]
chunk_counts = audio_input.get("chunk_counts") chunk_counts = audio_input.get("chunk_counts")
...@@ -534,66 +603,36 @@ class AudioFlamingo3ForConditionalGeneration( ...@@ -534,66 +603,36 @@ class AudioFlamingo3ForConditionalGeneration(
and chunk_counts and chunk_counts
and isinstance(chunk_counts[0], torch.Tensor) and isinstance(chunk_counts[0], torch.Tensor)
): ):
chunk_counts = [c.item() for c in chunk_counts] chunk_counts = [count.item() for count in chunk_counts]
# Calculate output lengths
input_lengths = feature_attention_mask.sum(-1)
# Conv downsampling
conv_lengths = (input_lengths - 1) // 2 + 1
# AvgPool downsampling
audio_output_lengths = (conv_lengths - 2) // 2 + 1
batch_size, _, max_mel_seq_len = input_features.shape
# Calculate max_seq_len after convs (before pooling) for attention mask
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
# Create a sequence tensor of shape (batch_size, max_seq_len)
seq_range = (
torch.arange(
0,
max_seq_len,
dtype=conv_lengths.dtype,
device=conv_lengths.device,
)
.unsqueeze(0)
.expand(batch_size, max_seq_len)
)
lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len)
# Create mask
padding_mask = seq_range >= lengths_expand
audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( return input_features, feature_attention_mask, chunk_counts
batch_size, 1, max_seq_len, max_seq_len
) def _encode_audio_features(
audio_attention_mask = audio_attention_mask_.to( self,
input_features: torch.Tensor,
feature_attention_mask: torch.Tensor,
) -> torch.Tensor:
audio_attention_mask = _build_audio_encoder_attention_mask(
feature_attention_mask,
dtype=self.audio_tower.conv1.weight.dtype, dtype=self.audio_tower.conv1.weight.dtype,
device=self.audio_tower.conv1.weight.device, device=self.audio_tower.conv1.weight.device,
) )
audio_attention_mask[audio_attention_mask_] = float("-inf")
# Forward pass return self.audio_tower(input_features, attention_mask=audio_attention_mask)
audio_features = self.audio_tower(
input_features, attention_mask=audio_attention_mask
)
# Project def _group_audio_embeddings(
audio_features = self.multi_modal_projector(audio_features) self,
audio_features: torch.Tensor,
# Masking after pooling feature_attention_mask: torch.Tensor,
num_audios, max_audio_tokens, embed_dim = audio_features.shape chunk_counts: list[int],
audio_output_lengths = audio_output_lengths.unsqueeze(1) ) -> tuple[torch.Tensor, ...]:
audio_features_mask = ( masked_audio_features, audio_output_lengths = _flatten_valid_audio_embeddings(
torch.arange(max_audio_tokens) audio_features,
.expand(num_audios, max_audio_tokens) feature_attention_mask,
.to(audio_output_lengths.device)
< audio_output_lengths
) )
masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)
# Split to tuple of embeddings for individual audio input.
chunk_embeddings = torch.split( chunk_embeddings = torch.split(
masked_audio_features, audio_output_lengths.flatten().tolist() masked_audio_features,
audio_output_lengths.tolist(),
) )
grouped_embeddings = [] grouped_embeddings = []
...@@ -613,7 +652,7 @@ class AudioFlamingo3ForConditionalGeneration( ...@@ -613,7 +652,7 @@ class AudioFlamingo3ForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""MusicFlamingo model adapter. # Copyright 2026 The vLLM team.
# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
MusicFlamingo shares the AudioFlamingo3 architecture, so we reuse the same from collections.abc import Callable, Mapping, Sequence
implementation and multimodal processor, while accepting MusicFlamingo config from math import pi
and processor classes when available. from typing import Annotated, Any, Optional, TypeAlias
"""
from collections.abc import Mapping import torch
from torch import Tensor, broadcast_tensors, nn
from transformers.models.audioflamingo3 import ( from transformers import BatchFeature
AudioFlamingo3Config, from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
AudioFlamingo3Processor, from transformers.models.musicflamingo import (
MusicFlamingoConfig,
MusicFlamingoProcessor,
) )
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import BaseProcessingInfo from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
DictEmbeddingItems,
ModalityData,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.utils.tensor_schema import TensorShape
from .audioflamingo3 import ( from .audioflamingo3 import (
AudioFlamingo3DummyInputsBuilder, AudioFlamingo3DummyInputsBuilder,
AudioFlamingo3EmbeddingInputs,
AudioFlamingo3Encoder,
AudioFlamingo3FeatureInputs,
AudioFlamingo3ForConditionalGeneration, AudioFlamingo3ForConditionalGeneration,
AudioFlamingo3MultiModalDataParser, AudioFlamingo3MultiModalDataParser,
AudioFlamingo3MultiModalProcessor, AudioFlamingo3MultiModalProcessor,
AudioFlamingo3MultiModalProjector,
AudioFlamingo3ProcessingInfo,
_audioflamingo3_field_config,
_count_audio_tokens_from_mask,
) )
try:
# Optional dependency: use MusicFlamingo classes when transformers provides them. def rotate_half(x):
from transformers.models.musicflamingo import ( x = x.reshape(*x.shape[:-1], -1, 2)
MusicFlamingoConfig, x1, x2 = x.unbind(dim=-1)
MusicFlamingoProcessor, x = torch.stack((-x2, x1), dim=-1)
) return x.flatten(-2)
except Exception: # pragma: no cover - optional dependency
MusicFlamingoConfig = None
MusicFlamingoProcessor = None def apply_rotary_time_emb(hidden_states, cos, sin):
original_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float64)
class MusicFlamingoProcessingInfo(BaseProcessingInfo): cos = cos.to(hidden_states)
def get_hf_config(self): sin = sin.to(hidden_states)
if MusicFlamingoConfig is None: rot_dim = cos.shape[-1]
return self.ctx.get_hf_config(AudioFlamingo3Config) if rot_dim > hidden_states.shape[-1]:
return self.ctx.get_hf_config((MusicFlamingoConfig, AudioFlamingo3Config)) raise ValueError(
f"feature dimension {hidden_states.shape[-1]} is not of "
def get_hf_processor(self, **kwargs: object): f"sufficient size to rotate in all the positions {rot_dim}"
if MusicFlamingoProcessor is None:
return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)
# Tuple triggers AutoProcessor path and accepts either processor class.
return self.ctx.get_hf_processor(
(MusicFlamingoProcessor, AudioFlamingo3Processor), **kwargs
) )
def get_feature_extractor(self, **kwargs: object): rotated = hidden_states[..., :rot_dim]
hf_processor = self.get_hf_processor(**kwargs) passthrough = hidden_states[..., rot_dim:]
return hf_processor.feature_extractor rotated = (rotated * cos) + (rotate_half(rotated) * sin)
return torch.cat((rotated, passthrough), dim=-1).to(original_dtype)
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return AudioFlamingo3MultiModalDataParser( class MusicFlamingoRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor
def __init__(self, config: MusicFlamingoConfig, device=None):
    """Set up inverse frequencies and the cached position-angle table.

    Args:
        config: Model config; ``max_position_embeddings`` and
            ``rope_parameters["rope_type"]`` are read here.
        device: Forwarded to the RoPE initialisation function.
    """
    super().__init__()
    self.max_seq_len_cached = config.max_position_embeddings
    self.original_max_seq_len = config.max_position_embeddings
    self.config = config
    self.rope_type = self.config.rope_parameters["rope_type"]
    # "default" uses the static helper of this class; every other type is
    # looked up in the ROPE_INIT_FUNCTIONS registry.
    rope_init_fn: Callable = self.compute_default_rope_parameters
    if self.rope_type != "default":
        rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
    inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
    # All buffers are registered with persistent=False.
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
    # The angle table is derived from the registered inv_freq buffer, so it
    # must be computed after that buffer exists.
    position_angles = self._compute_position_angles(self.inv_freq)
    self.register_buffer("position_angles", position_angles, persistent=False)
@staticmethod
def compute_default_rope_parameters(
config: MusicFlamingoConfig | None = None,
device: Optional["torch.device"] = None,
seq_len: int | None = None,
) -> tuple["torch.Tensor", float]:
del seq_len
base = config.rope_parameters["rope_theta"]
dim = getattr(config, "head_dim", None) or (
config.hidden_size // config.num_attention_heads
)
attention_factor = 1.0
inv_freq = 1.0 / (
base
** (
torch.arange(0, dim, 2, dtype=torch.int64).to(
device=device,
dtype=torch.float,
)
/ dim
)
)
return inv_freq, attention_factor
def _compute_position_angles(self, inv_freq):
positions = torch.arange(
int(self.max_seq_len_cached),
device=inv_freq.device,
dtype=inv_freq.dtype,
)
positions = positions / self.max_seq_len_cached * (2 * pi)
position_angles = positions.unsqueeze(-1) * inv_freq
position_angles = torch.repeat_interleave(position_angles, 2, dim=-1)
return position_angles.to(dtype=inv_freq.dtype)
@torch.no_grad()
def forward(self, timestamps: Tensor, seq_len: int) -> tuple[Tensor, Tensor]:
    """Return ``(cos, sin)`` tables scaled by the given timestamps.

    Args:
        timestamps: Tensor whose first dimension indexes chunks; its values
            are multiplied by ``-2*pi`` and broadcast over the last axis.
        seq_len: Number of rows taken from the cached ``position_angles``.

    Returns:
        Cosine and sine of the concatenated chunk-index and position
        frequencies, each multiplied by the timestamp angle.
    """
    inv_freq = self.inv_freq

    chunk_index = torch.arange(
        timestamps.shape[0],
        device=inv_freq.device,
        dtype=inv_freq.dtype,
    )
    chunk_index = chunk_index / self.max_seq_len_cached
    chunk_freqs = chunk_index.unsqueeze(-1) * inv_freq
    chunk_freqs = chunk_freqs.repeat_interleave(2, dim=-1).unsqueeze(1)

    step_freqs = self.position_angles[:seq_len].unsqueeze(0)

    # Expand both tables to a common (chunk, step, feature) shape before
    # concatenating them along the feature axis.
    chunk_freqs, step_freqs = broadcast_tensors(chunk_freqs, step_freqs)
    freqs = torch.cat((chunk_freqs, step_freqs), dim=-1)

    angle = (-timestamps * 2 * pi).to(freqs)
    freqs = freqs * angle.unsqueeze(-1)
    return freqs.cos(), freqs.sin()
class MusicFlamingoFeatureInputs(AudioFlamingo3FeatureInputs):
    """AudioFlamingo3 feature inputs extended with ``rote_timestamps``."""

    # Declared as (num_chunks, num_audio_time_steps); the second dimension is
    # marked dynamic.
    rote_timestamps: Annotated[
        torch.Tensor,
        TensorShape(
            "num_chunks",
            "num_audio_time_steps",
            dynamic_dims={"num_audio_time_steps"},
        ),
    ]
# The embedding-input schema is reused unchanged from AudioFlamingo3.
MusicFlamingoEmbeddingInputs = AudioFlamingo3EmbeddingInputs
# Union of the two input forms handled by the model: raw features (with
# timestamps) or pre-computed embeddings.
MusicFlamingoInputs: TypeAlias = (
    MusicFlamingoFeatureInputs | MusicFlamingoEmbeddingInputs
)
class MusicFlamingoEncoder(AudioFlamingo3Encoder):
    """MusicFlamingo-named subclass of ``AudioFlamingo3Encoder``; adds nothing."""

    pass
class MusicFlamingoMultiModalProjector(AudioFlamingo3MultiModalProjector):
    """MusicFlamingo-named subclass of the AudioFlamingo3 projector; adds nothing."""

    pass
class MusicFlamingoProcessingInfo(AudioFlamingo3ProcessingInfo):
def get_hf_config(self) -> MusicFlamingoConfig:
    """Return the HF config, requested from the context as ``MusicFlamingoConfig``."""
    return self.ctx.get_hf_config(MusicFlamingoConfig)
def get_hf_processor(self, **kwargs: object) -> MusicFlamingoProcessor:
    """Return the HF processor, requested as ``MusicFlamingoProcessor``.

    ``kwargs`` are forwarded to the context unchanged.
    """
    return self.ctx.get_hf_processor(MusicFlamingoProcessor, **kwargs)
def get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.get_feature_extractor()
return MusicFlamingoMultiModalDataParser(
target_sr=feature_extractor.sampling_rate, target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(), expected_hidden_size=self._get_expected_hidden_size(),
) )
...@@ -67,13 +213,230 @@ class MusicFlamingoProcessingInfo(BaseProcessingInfo): ...@@ -67,13 +213,230 @@ class MusicFlamingoProcessingInfo(BaseProcessingInfo):
class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder): class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):
pass def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
hf_processor = self.info.get_hf_processor()
return hf_processor.audio_token * num_audios
def get_dummy_mm_data(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
    mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
    """Build dummy audio inputs of the maximum supported length.

    Each dummy audio has ``int(max_audio_len * sampling_rate)`` samples, with
    both factors read from the HF processor and its feature extractor.
    ``seq_len`` is not used.
    """
    info = self.info
    hf_processor = info.get_hf_processor()
    sample_rate = info.get_feature_extractor().sampling_rate

    dummy_audios = self._get_dummy_audios(
        length=int(hf_processor.max_audio_len * sample_rate),
        num_audios=mm_counts.get("audio", 0),
        overrides=mm_options.get("audio"),
    )
    return {"audio": dummy_audios}
def _musicflamingo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Extend the AudioFlamingo3 field config with ``rote_timestamps``.

    When ``chunk_counts`` is present the timestamps are declared as a flat
    field split by those sizes along dim 0; otherwise they are declared as a
    batched field.
    """
    fields = dict(_audioflamingo3_field_config(hf_inputs))

    chunk_counts = hf_inputs.get("chunk_counts")
    if chunk_counts is None:
        rote_field = MultiModalFieldConfig.batched("audio")
    else:
        rote_field = MultiModalFieldConfig.flat_from_sizes(
            "audio", chunk_counts, dim=0
        )

    fields["rote_timestamps"] = rote_field
    return fields
class MusicFlamingoMultiModalDataParser(AudioFlamingo3MultiModalDataParser):
    """Data parser that applies the MusicFlamingo field config to dict payloads."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse one audio payload.

        A ``dict`` is wrapped in ``DictEmbeddingItems`` using
        ``_musicflamingo_field_config``; every other payload is delegated to
        the parent parser.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_musicflamingo_field_config,
        )
class MusicFlamingoMultiModalProcessor(AudioFlamingo3MultiModalProcessor):
    """AudioFlamingo3 processor variant that also carries ``rote_timestamps``."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: dict[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the parent processor and attach per-audio chunk counts.

        When audio is present, ``chunk_counts`` is written into the outputs
        and the outputs are required to contain ``rote_timestamps``.

        Raises:
            KeyError: If audio was supplied but the processor output lacks
                ``rote_timestamps``.
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        audio_data = mm_data.get("audio")
        if audio_data is None:
            return outputs
        audios = audio_data if isinstance(audio_data, list) else [audio_data]
        if not audios:
            return outputs

        processor = self.info.get_hf_processor(**mm_kwargs)
        feature_extractor = processor.feature_extractor
        sampling_rate = feature_extractor.sampling_rate
        chunk_length = feature_extractor.chunk_length
        window_size = int(sampling_rate * chunk_length)
        max_windows = int(processor.max_audio_len // chunk_length)

        def _num_chunks(audio) -> int:
            # Ceil-divide the sample count into windows: at least one window,
            # at most `max_windows`.
            n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]
            n_windows = max(1, (n_samples + window_size - 1) // window_size)
            return min(n_windows, max_windows)

        outputs["chunk_counts"] = torch.tensor(
            [_num_chunks(audio) for audio in audios],
            dtype=torch.long,
        )

        if "rote_timestamps" not in outputs:
            raise KeyError(
                "MusicFlamingoProcessor output must include `rote_timestamps`."
            )
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the MusicFlamingo field config for the processor outputs."""
        return _musicflamingo_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio token by BOS + N audio tokens + EOS.

        N comes from the feature attention mask when one is available and
        from the first dimension of the supplied ``audio_embeds`` otherwise.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        # Token ids are taken from the tokenizer vocabulary, falling back to
        # the ids stored on the processor.
        audio_token = processor.audio_token
        audio_token_id = vocab.get(audio_token, processor.audio_token_id)
        bos_id = vocab.get(processor.audio_bos_token, processor.audio_bos_token_id)
        eos_id = vocab.get(processor.audio_eos_token, processor.audio_eos_token_id)

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        chunk_counts = out_mm_data.get("chunk_counts")

        def get_replacement_musicflamingo(item_idx: int):
            if feature_attention_mask is None:
                num_features = out_mm_data["audio_embeds"][item_idx].shape[0]
            else:
                num_features = _count_audio_tokens_from_mask(
                    feature_attention_mask,
                    chunk_counts,
                    item_idx,
                )

            if num_features == 0:
                raise ValueError("Audio is too short")

            placeholder_ids = [audio_token_id] * int(num_features)
            return PromptUpdateDetails.select_token_id(
                [bos_id, *placeholder_ids, eos_id],
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_musicflamingo,
            )
        ]
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
AudioFlamingo3MultiModalProcessor, MusicFlamingoMultiModalProcessor,
info=MusicFlamingoProcessingInfo, info=MusicFlamingoProcessingInfo,
dummy_inputs=MusicFlamingoDummyInputsBuilder, dummy_inputs=MusicFlamingoDummyInputsBuilder,
) )
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
"""MusicFlamingo model for conditional generation.""" """vLLM MusicFlamingo model aligned with HF modular_musicflamingo."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the parent model, then install the MusicFlamingo modules.

    After the parent constructor runs, ``audio_tower`` and
    ``multi_modal_projector`` are reassigned to the MusicFlamingo subclasses
    and a rotary time embedding is attached as ``pos_emb``.
    """
    super().__init__(vllm_config=vllm_config, prefix=prefix)
    # NOTE(review): presumably the parent constructor already created
    # `audio_tower` and `multi_modal_projector`, in which case both are built
    # twice here — confirm against the parent class.
    self.audio_tower = MusicFlamingoEncoder(self.config.audio_config)
    self.multi_modal_projector = MusicFlamingoMultiModalProjector(self.config)
    self.pos_emb = MusicFlamingoRotaryEmbedding(self.config)
def _parse_and_validate_audio_input(
    self, **kwargs: object
) -> MusicFlamingoInputs | None:
    """Validate audio kwargs and attach ``rote_timestamps`` to feature inputs.

    ``rote_timestamps`` is removed from ``kwargs`` before the parent
    validation runs. ``None`` and pre-computed embedding inputs are returned
    exactly as the parent produced them; feature inputs are re-wrapped as
    ``MusicFlamingoFeatureInputs`` together with the timestamps.
    """
    rote_timestamps = kwargs.pop("rote_timestamps", None)
    audio_input = super()._parse_and_validate_audio_input(**kwargs)
    if audio_input is None or audio_input["type"] == "audio_embeds":
        return audio_input

    return MusicFlamingoFeatureInputs(
        type="audio_features",
        input_features=audio_input["input_features"],
        feature_attention_mask=audio_input["feature_attention_mask"],
        # `chunk_counts` is read with `.get`, the same way the parent's
        # `_normalize_audio_feature_inputs` reads it, so an input without
        # chunk counts is not rejected here.
        chunk_counts=audio_input.get("chunk_counts"),
        rote_timestamps=rote_timestamps,
    )
def _process_audio_input(
    self, audio_input: MusicFlamingoInputs
) -> torch.Tensor | tuple[torch.Tensor, ...]:
    """Encode audio features, apply the rotary time embedding, and project.

    Pre-computed embeddings are delegated to the parent implementation.

    Raises:
        ValueError: If feature inputs carry no ``rote_timestamps``.
    """
    if audio_input["type"] == "audio_embeds":
        return super()._process_audio_input(audio_input)

    timestamps = audio_input["rote_timestamps"]
    if timestamps is None:
        raise ValueError(
            "MusicFlamingo audio feature inputs must include `rote_timestamps`."
        )
    if isinstance(timestamps, list):
        # A list of tensors is concatenated along dim 0 into one tensor.
        timestamps = torch.cat(timestamps, dim=0)

    features, feature_mask, chunk_counts = self._normalize_audio_feature_inputs(
        audio_input
    )
    encoded = self._encode_audio_features(features, feature_mask)

    cos, sin = self.pos_emb(
        timestamps.to(encoded.device),
        seq_len=encoded.shape[-2],
    )
    rotated = apply_rotary_time_emb(encoded, cos, sin)

    projected = self.multi_modal_projector(rotated)
    return self._group_audio_embeddings(projected, feature_mask, chunk_counts)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment