test_qwen3_omni.py 10.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for Qwen3 Omni audio processing and sample rate handling."""

from typing import Any

import numpy as np
import pytest

from vllm.multimodal import MULTIMODAL_REGISTRY

from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["Qwen/Qwen3-Omni-30B-A3B-Instruct"])
@pytest.mark.parametrize(
    ("audio_sample_rate", "audio_duration_sec"),
    [
        (16000, 1.0),  # 16 kHz input, one second long
        (16000, 2.0),  # 16 kHz input, two seconds long
    ],
)
def test_processor_with_audio_sample_rate(
    model_id: str,
    audio_sample_rate: int,
    audio_duration_sec: float,
) -> None:
    """
    Run the vLLM multimodal processor on synthetic audio and check its output.

    A random waveform of the requested duration is passed to the processor
    together with ``audio_sample_rate`` as an HF processor kwarg. The test
    passes when the resulting prompt token ids contain at least one
    occurrence of the HF processor's audio token.
    """
    # Processor and tokenizer for an audio-only configuration.
    model_ctx = build_model_context(
        model_id,
        limit_mm_per_prompt={"audio": 1, "image": 0, "video": 0},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config)
    tokenizer = processor.info.get_tokenizer()

    # Deterministic pseudo-random waveform (fixed seed) of the right length.
    num_samples = int(audio_sample_rate * audio_duration_sec)
    waveform = np.random.RandomState(42).rand(num_samples).astype(np.float32)

    # The sample rate travels both with the audio tuple and as a kwarg.
    mm_kwargs: dict[str, Any] = {"audio_sample_rate": audio_sample_rate}
    outputs = processor.apply(
        "<|audio_start|><|audio_pad|><|audio_end|>",
        {"audio": [(waveform, audio_sample_rate)]},
        mm_kwargs,
    )

    # Count how many times the audio token appears in the final prompt.
    audio_token = processor.info.get_hf_processor(**mm_kwargs).audio_token
    audio_token_id = tokenizer.convert_tokens_to_ids(audio_token)
    n_audio_tokens = outputs["prompt_token_ids"].count(audio_token_id)

    assert n_audio_tokens >= 1, (
        f"Expected at least 1 audio token but got {n_audio_tokens}. "
        f"sample_rate: {audio_sample_rate}Hz, duration: {audio_duration_sec}s"
    )


@pytest.mark.parametrize("model_id", ["Qwen/Qwen3-Omni-30B-A3B-Instruct"])
def test_longer_audio_generates_more_tokens(model_id: str) -> None:
    """
    Check that the audio token count grows with the audio duration.

    Two synthetic clips (1 s and 2 s) at the same sample rate are run through
    the processor; the 2 s clip must yield strictly more audio tokens.
    """
    model_ctx = build_model_context(
        model_id,
        limit_mm_per_prompt={"audio": 1, "image": 0, "video": 0},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(model_ctx.model_config)
    tokenizer = processor.info.get_tokenizer()

    audio_sample_rate = 16000
    # One seeded generator shared by both clips keeps the run deterministic.
    rng = np.random.RandomState(42)

    def count_audio_tokens(seconds: float) -> int:
        """Process a random clip of ``seconds`` length; count audio tokens."""
        waveform = rng.rand(int(audio_sample_rate * seconds)).astype(np.float32)
        mm_kwargs: dict[str, Any] = {"audio_sample_rate": audio_sample_rate}
        outputs = processor.apply(
            "<|audio_start|><|audio_pad|><|audio_end|>",
            {"audio": [(waveform, audio_sample_rate)]},
            mm_kwargs,
        )
        audio_token = processor.info.get_hf_processor(**mm_kwargs).audio_token
        token_id = tokenizer.convert_tokens_to_ids(audio_token)
        return outputs["prompt_token_ids"].count(token_id)

    # Order matters for reproducibility: the short clip draws from rng first.
    short_tokens = count_audio_tokens(1.0)
    long_tokens = count_audio_tokens(2.0)

    assert long_tokens > short_tokens, (
        f"Expected longer audio (2s) to have more tokens than shorter (1s). "
        f"Got short={short_tokens}, long={long_tokens}"
    )


class TestQwen3OmniAudioSampleRatePreservation:
    """Test that audio_sample_rate is preserved during kwargs restructuring.

    These tests validate the fix for the audio_sample_rate bug in Qwen3 Omni
    where the parameter was lost during kwargs restructuring.

    All tests go through ``_process_kwargs`` so the restructuring logic under
    test lives in exactly one place in this file.
    """

    @staticmethod
    def _process_kwargs(
        mm_kwargs: dict[str, Any],
        tok_kwargs: dict[str, Any],
        transformers_version: str = "4.57.0",
    ) -> dict[str, Any]:
        """
        Helper method to simulate kwargs processing logic from production code.

        This method simulates the kwargs restructuring that happens in the
        Qwen3 Omni model when transformers < 4.58.0. By centralizing this
        logic, we make tests easier to maintain if the production logic changes.
        NOTE(review): this is a local re-implementation, not a call into the
        production code; keep it in sync with the model's processor.

        Neither input dict is mutated; shallow copies are used internally.

        Args:
            mm_kwargs: Multimodal kwargs (e.g., audio_sample_rate, truncation)
            tok_kwargs: Tokenizer kwargs (e.g., truncation)
            transformers_version: Version string to test against (default: "4.57.0")

        Returns:
            A copy of ``mm_kwargs``. For versions below 4.58.0 the top-level
            ``truncation`` and ``audio_sample_rate`` keys are replaced by
            ``audio_kwargs`` and ``text_kwargs`` sub-dicts; for 4.58.0 and
            above the copy is returned unchanged.
        """
        from packaging.version import Version

        # Shallow copies: the pops below must not touch the caller's dicts.
        mm_kwargs_copy = dict(mm_kwargs)
        tok_kwargs_copy = dict(tok_kwargs)

        if Version(transformers_version) < Version("4.58.0"):
            # Extract audio_sample_rate before restructuring (THE FIX)
            audio_sample_rate = mm_kwargs_copy.pop("audio_sample_rate", None)

            # Restructure kwargs
            mm_kwargs_copy["audio_kwargs"] = {
                "truncation": mm_kwargs_copy.pop("truncation", False)
            }
            mm_kwargs_copy["text_kwargs"] = {
                "truncation": tok_kwargs_copy.pop("truncation", False)
            }

            # Put audio_sample_rate into audio_kwargs (THE FIX)
            if audio_sample_rate is not None:
                mm_kwargs_copy["audio_kwargs"]["audio_sample_rate"] = audio_sample_rate

        return mm_kwargs_copy

    def test_audio_sample_rate_preserved_in_audio_kwargs(self) -> None:
        """
        Test that audio_sample_rate is moved from top-level mm_kwargs
        into audio_kwargs during kwargs restructuring.

        This is the core fix: when transformers < 4.58.0, the code
        restructures kwargs into audio_kwargs and text_kwargs, and
        audio_sample_rate must be preserved in audio_kwargs.
        """
        # Setup: Create mm_kwargs with audio_sample_rate at top level
        mm_kwargs: dict[str, Any] = {
            "audio_sample_rate": 16000,
            "truncation": True,
        }
        tok_kwargs: dict[str, Any] = {
            "truncation": False,
        }

        # Execute: Process kwargs using helper method
        result = self._process_kwargs(mm_kwargs, tok_kwargs)

        # Assert: Verify audio_sample_rate is in audio_kwargs
        assert "audio_kwargs" in result
        assert "audio_sample_rate" in result["audio_kwargs"]
        assert result["audio_kwargs"]["audio_sample_rate"] == 16000

        # Assert: Verify truncation is also in audio_kwargs
        assert result["audio_kwargs"]["truncation"] is True

        # Assert: Verify text_kwargs is created correctly
        assert "text_kwargs" in result
        assert result["text_kwargs"]["truncation"] is False

    def test_audio_sample_rate_absent_when_not_provided(self) -> None:
        """
        Test that when audio_sample_rate is not provided in mm_kwargs,
        the restructured audio_kwargs doesn't contain it.
        """
        # Setup: Create mm_kwargs WITHOUT audio_sample_rate
        mm_kwargs: dict[str, Any] = {
            "truncation": True,
        }
        tok_kwargs: dict[str, Any] = {
            "truncation": False,
        }

        # Execute: Process kwargs using helper method
        result = self._process_kwargs(mm_kwargs, tok_kwargs)

        # Assert: Verify audio_sample_rate is NOT in audio_kwargs
        assert "audio_kwargs" in result
        assert "audio_sample_rate" not in result["audio_kwargs"]

        # Assert: Verify truncation is still in audio_kwargs
        assert result["audio_kwargs"]["truncation"] is True

    @pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 24000, 44100, 48000])
    def test_various_audio_sample_rates_preserved(self, sample_rate: int) -> None:
        """
        Test that various common audio sample rates are preserved.

        Common sample rates:
        - 8000: Telephone quality
        - 16000: Wideband speech (Qwen3 Omni default)
        - 22050: Low-quality audio
        - 24000: High-quality speech
        - 44100: CD quality
        - 48000: Professional audio
        """
        # Setup: Create mm_kwargs with specific sample rate
        mm_kwargs: dict[str, Any] = {
            "audio_sample_rate": sample_rate,
            "truncation": True,
        }
        tok_kwargs: dict[str, Any] = {"truncation": False}

        # Execute: Process kwargs using helper method
        result = self._process_kwargs(mm_kwargs, tok_kwargs)

        # Assert: Verify the specific sample rate is preserved
        assert result["audio_kwargs"]["audio_sample_rate"] == sample_rate

    def test_kwargs_unchanged_for_newer_transformers_version(self) -> None:
        """
        Test that kwargs structure remains unchanged for transformers >= 4.58.0.

        This test ensures that when transformers version is 4.58.0 or higher,
        the kwargs restructuring is bypassed and audio_sample_rate remains
        at the top level as originally passed.
        """
        # Setup: Create mm_kwargs with audio_sample_rate at top level
        mm_kwargs: dict[str, Any] = {
            "audio_sample_rate": 16000,
            "truncation": True,
        }
        tok_kwargs: dict[str, Any] = {
            "truncation": False,
        }

        # Execute: Run the shared helper at the first version that bypasses
        # restructuring, instead of duplicating its logic inline.
        result = self._process_kwargs(
            mm_kwargs, tok_kwargs, transformers_version="4.58.0"
        )

        # Assert: Verify kwargs structure is unchanged
        assert "audio_kwargs" not in result
        assert "text_kwargs" not in result
        assert result["audio_sample_rate"] == 16000
        assert result["truncation"] is True

        # Assert: Verify the caller's tokenizer kwargs were left untouched
        assert tok_kwargs["truncation"] is False