# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import numpy as np
import pytest
import torch
from transformers import PretrainedConfig

from tests.models.registry import HF_EXAMPLE_MODELS


class MockAudioFlamingo3Config(PretrainedConfig):
    """Bare-bones config object standing in for the real AudioFlamingo3 config.

    It reports ``model_type = "audioflamingo3"`` and carries two empty
    ``PretrainedConfig`` instances as its ``audio_config`` and ``text_config``
    sub-configs.
    """

    model_type = "audioflamingo3"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Each sub-config gets its own default-constructed PretrainedConfig.
        for sub_config_name in ("audio_config", "text_config"):
            setattr(self, sub_config_name, PretrainedConfig())


class MockAudioFlamingo3Processor:
    """Lightweight stand-in for the HF AudioFlamingo3 processor.

    Exposes only the attributes these tests set up (audio placeholder token,
    its id, a maximum audio length and a feature extractor) and returns a
    fixed payload when called, so no real tokenizer or feature extractor is
    needed.
    """

    def __init__(self):
        self.audio_token = "<sound>"
        self.audio_token_id = 12345
        # NOTE(review): presumably expressed in seconds -- confirm against the
        # real processor.
        self.max_audio_len = 60
        self.feature_extractor = MockFeatureExtractor()

    def __call__(self, text=None, audios=None, **kwargs):
        """Return a canned processor output; every argument is ignored.

        Returns:
            A dict with ``"input_ids"`` set to ``[1, 2, 3]`` and
            ``"input_features"`` set to a one-element list holding a zero
            array of shape ``(3000, 80)``.
        """
        return {"input_ids": [1, 2, 3], "input_features": [np.zeros((3000, 80))]}


class MockFeatureExtractor:
    """Stub feature extractor exposing just two numeric attributes.

    NOTE(review): ``sampling_rate`` is presumably in Hz and ``chunk_length``
    in seconds -- confirm against the real feature extractor.
    """

    def __init__(self):
        self.sampling_rate, self.chunk_length = 16000, 30


@pytest.fixture
def mock_ctx():
    """Build a ``MagicMock`` processing context wired to the mock config/processor.

    The same ``MockAudioFlamingo3Config`` instance is returned by
    ``get_hf_config()`` and exposed as ``model_config.hf_config``;
    ``get_hf_processor()`` returns a fresh ``MockAudioFlamingo3Processor``.
    """
    hf_config = MockAudioFlamingo3Config()

    context = MagicMock()
    context.model_config.hf_config = hf_config
    context.get_hf_config.return_value = hf_config
    context.get_hf_processor.return_value = MockAudioFlamingo3Processor()
    return context


@pytest.fixture(autouse=True)
def check_transformers_version():
    """Run the registry's transformers-version check before every test.

    Looks up the ``AudioFlamingo3ForConditionalGeneration`` entry in
    ``HF_EXAMPLE_MODELS`` and calls its ``check_transformers_version`` with
    ``on_fail="skip"``.
    NOTE(review): presumably this skips the test when the installed
    transformers release is unsupported -- confirm in the registry helper.
    """
    HF_EXAMPLE_MODELS.get_hf_info(
        "AudioFlamingo3ForConditionalGeneration"
    ).check_transformers_version(on_fail="skip")


def test_audio_chunk_counting(mock_ctx):
    """Check the per-audio ``chunk_counts`` produced by the vLLM processor.

    Two silent clips (30 s and 75 s worth of samples at 16 kHz) are passed to
    ``AudioFlamingo3MultiModalProcessor._call_hf_processor`` while the base
    class implementation is patched out with a canned result. The test
    expects one chunk for the first clip and two for the second.
    NOTE(review): two chunks for a 75 s clip presumably follows from the mock
    processor's ``max_audio_len = 60`` and the 30-unit ``chunk_length`` --
    confirm against the implementation.
    """
    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3DummyInputsBuilder,
        AudioFlamingo3MultiModalProcessor,
        AudioFlamingo3ProcessingInfo,
    )
    from vllm.multimodal.processing import BaseMultiModalProcessor

    info = AudioFlamingo3ProcessingInfo(mock_ctx)
    processor = AudioFlamingo3MultiModalProcessor(
        info, AudioFlamingo3DummyInputsBuilder(info)
    )

    sr = 16000
    audio_1 = np.zeros(30 * sr)
    audio_2 = np.zeros(75 * sr)

    mm_data = {"audio": [audio_1, audio_2]}
    prompt = "<|user|>Listen.<|end|>"

    def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs):
        # Canned replacement for the base-class HF call, so no real HF
        # processor is invoked.
        return {"input_ids": [1, 2, 3], "input_features": torch.randn(1, 80, 3000)}

    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call)

        processed = processor._call_hf_processor(prompt, mm_data, {}, {})

        chunk_counts = processed["chunk_counts"]

        # Validate the size first so a wrong-sized result fails with a clear
        # assertion rather than an indexing error.
        assert len(chunk_counts) == 2
        assert chunk_counts[0].item() == 1
        assert chunk_counts[1].item() == 2


def test_dummy_data_generation(mock_ctx):
    """Check the dummy audio produced by ``AudioFlamingo3DummyInputsBuilder``.

    Requests two dummy audio items and verifies that an ``"audio"`` entry
    with two items is returned and that the first item has ``60 * 16000``
    samples.
    NOTE(review): the expected length matches the mock processor's
    ``max_audio_len`` (60) times the mock feature extractor's
    ``sampling_rate`` (16000) -- confirm that is how the builder derives it.
    """
    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3DummyInputsBuilder,
        AudioFlamingo3ProcessingInfo,
    )

    info = AudioFlamingo3ProcessingInfo(mock_ctx)
    builder = AudioFlamingo3DummyInputsBuilder(info)

    mm_counts = {"audio": 2}
    dummy_data = builder.get_dummy_mm_data(100, mm_counts, {})

    assert "audio" in dummy_data
    assert len(dummy_data["audio"]) == 2

    expected_len = 60 * 16000
    assert len(dummy_data["audio"][0]) == expected_len


def test_audio_token_count_matches_hf_processor_math():
    """Pin the token counts returned by ``_count_audio_tokens_from_mask``.

    Three mask rows of width 3000 are built with 2999, 2999 and 1500 valid
    positions, grouped as two audio items via ``chunk_counts = [2, 1]``.
    The helper is expected to return 1499 for item 0 and 375 for item 1.
    """
    from vllm.model_executor.models.audioflamingo3 import (
        _count_audio_tokens_from_mask,
    )

    valid_frames_per_row = (2999, 2999, 1500)
    feature_attention_mask = torch.zeros(
        (len(valid_frames_per_row), 3000), dtype=torch.long
    )
    for row, n_valid in enumerate(valid_frames_per_row):
        feature_attention_mask[row, :n_valid] = 1
    chunk_counts = torch.tensor([2, 1], dtype=torch.long)

    # Audio item index -> expected number of audio tokens.
    expected_tokens = {0: 1499, 1: 375}
    for item_idx, n_tokens in expected_tokens.items():
        counted = _count_audio_tokens_from_mask(
            feature_attention_mask, chunk_counts, item_idx
        )
        assert counted == n_tokens


def test_audio_feature_pipeline_matches_hf_small_config():
    """Compare vLLM's audio encoder + projector against the HF reference.

    A small randomly initialised HF ``AudioFlamingo3ForConditionalGeneration``
    is built, its audio-tower and projector weights are copied into the vLLM
    modules, and both pipelines are run on the same random features with a
    three-row attention mask (3000, 2500 and 1500 valid frames). The flattened
    vLLM embeddings must be close to HF's ``pooler_output``.
    """
    from transformers.models.audioflamingo3 import (
        modeling_audioflamingo3 as hf_audioflamingo3_modeling,
    )
    from transformers.models.audioflamingo3.configuration_audioflamingo3 import (
        AudioFlamingo3Config,
    )

    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3Encoder,
        AudioFlamingo3MultiModalProjector,
        _build_audio_encoder_attention_mask,
        _flatten_valid_audio_embeddings,
    )

    # Deliberately small hyper-parameters for both sub-models.
    text_config = dict(
        model_type="qwen2",
        intermediate_size=64,
        initializer_range=0.02,
        hidden_size=32,
        max_position_embeddings=1024,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        vocab_size=128,
        pad_token_id=1,
        use_mrope=False,
    )
    audio_config = dict(
        hidden_size=16,
        num_attention_heads=4,
        intermediate_size=32,
        num_hidden_layers=2,
        num_mel_bins=80,
        max_source_positions=1500,
        dropout=0.0,
        attention_dropout=0.0,
        activation_dropout=0.0,
        encoder_layerdrop=0.0,
    )

    # Seed before any module is constructed so initialisation is repeatable.
    torch.manual_seed(0)
    config = AudioFlamingo3Config(
        text_config=text_config,
        audio_config=audio_config,
        audio_token_id=0,
    )
    hf_model_cls = hf_audioflamingo3_modeling.AudioFlamingo3ForConditionalGeneration
    hf_model = hf_model_cls(config).eval()

    # Mirror the HF weights into the vLLM modules.
    vllm_encoder = AudioFlamingo3Encoder(config.audio_config).eval()
    vllm_encoder.load_state_dict(hf_model.audio_tower.state_dict())

    vllm_projector = AudioFlamingo3MultiModalProjector(config).eval()
    vllm_projector.load_state_dict(hf_model.multi_modal_projector.state_dict())

    total_frames = 3000
    valid_frames_per_row = (3000, 2500, 1500)
    batch = len(valid_frames_per_row)

    input_features = torch.randn(batch, 80, total_frames)
    feature_attention_mask = torch.zeros(batch, total_frames, dtype=torch.bool)
    for row, n_valid in enumerate(valid_frames_per_row):
        feature_attention_mask[row, :n_valid] = True

    # HF reference path.
    hf_result = hf_model.get_audio_features(
        input_features,
        feature_attention_mask,
        return_dict=True,
    )
    hf_output = hf_result.pooler_output

    # vLLM path: build the mask, encode, project, then flatten.
    conv1_weight = vllm_encoder.conv1.weight
    vllm_attention_mask = _build_audio_encoder_attention_mask(
        feature_attention_mask,
        dtype=conv1_weight.dtype,
        device=conv1_weight.device,
    )
    vllm_hidden_states = vllm_encoder(
        input_features,
        attention_mask=vllm_attention_mask,
    )
    projected = vllm_projector(vllm_hidden_states)
    vllm_output, _ = _flatten_valid_audio_embeddings(
        projected,
        feature_attention_mask,
    )

    torch.testing.assert_close(vllm_output, hf_output)