"vllm/model_executor/models/qwen3_moe.py" did not exist on "82aee74526fe43a4eedf88327de18ae0b6fba8a5"
qwen2_audio.py 16.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
25

26
from collections.abc import Iterable, Mapping, Sequence
27
from typing import Annotated, Any, Literal, TypeAlias
28
29
30

import torch
import torch.nn as nn
31
from transformers import BatchFeature
32
33
34
35
36
from transformers.models.qwen2_audio import (
    Qwen2AudioConfig,
    Qwen2AudioEncoder,
    Qwen2AudioProcessor,
)
37
from transformers.models.whisper import WhisperFeatureExtractor
38

39
from vllm.config import VllmConfig
40
from vllm.config.multimodal import BaseDummyOptions
41
from vllm.multimodal import MULTIMODAL_REGISTRY
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from vllm.multimodal.inputs import (
    AudioItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
57
    BaseDummyInputsBuilder,
58
59
60
61
62
63
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
64
from vllm.sequence import IntermediateTensors
65
from vllm.utils.tensor_schema import TensorSchema, TensorShape
66

67
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
68
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
69
70
71


# === Audio Inputs === #
72
73
74
75
76
77
class Qwen2AudioFeatureInputs(TensorSchema):
    """
    Schema for audio supplied as feature tensors plus an attention mask.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
    """

    type: Literal["audio_features"]

    # Leading dimension is the number of audios; the last dimension is
    # fixed at 3000.
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor], TensorShape("na", "nmb", 3000)
    ]

    # One row per audio, sharing the fixed last dimension of 3000 with
    # `input_features`.
    feature_attention_mask: Annotated[torch.Tensor, TensorShape("na", 3000)]
89
90


91
class Qwen2AudioEmbeddingInputs(TensorSchema):
    """
    Schema for audio supplied directly as embedding tensors.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    # `naf` is declared as a dynamic dimension (presumably it may vary from
    # one audio to the next -- confirm against TensorShape).
    audio_embeds: Annotated[
        list[torch.Tensor], TensorShape("bn", "naf", "hs", dynamic_dims={"naf"})
    ]
106
107


108
# Either form of audio input accepted by the model: feature tensors that are
# run through the audio tower, or embeddings that are used as given.
Qwen2AudioInputs: TypeAlias = Qwen2AudioFeatureInputs | Qwen2AudioEmbeddingInputs
109

110
111
112
113
114
115
116
117
118
119
120
121
122
# === Audio Encoder === #


class Qwen2AudioMultiModalProjector(nn.Module):
    """Single biased linear layer from the audio width to the text width."""

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(
            in_features=audio_hidden_size,
            out_features=text_hidden_size,
            bias=True,
        )

    def forward(self, audio_features):
        """Return `audio_features` passed through the linear projection."""
        return self.linear(audio_features)


123
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
124
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
125
126
127
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths
128
129


130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Return the field configuration shared by all Qwen2-Audio inputs.

    Every known key is configured as batched over the "audio" modality.
    `hf_inputs` is accepted but not inspected.
    """
    audio_keys = ("audio_embeds", "input_features", "feature_attention_mask")
    return {key: MultiModalFieldConfig.batched("audio") for key in audio_keys}


class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts audio given as a dict of tensors."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio data, wrapping dict input as embedding items.

        A dict must provide the "audio_embeds" field; anything that is not a
        dict is handled by the parent parser.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_qwen2audio_field_config,
        )


154
155
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
    """Accessors for the Qwen2-Audio config, HF processor and data parser."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as Qwen2AudioConfig."""
        return self.ctx.get_hf_config(Qwen2AudioConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
        """Return the HF processor; `kwargs` are forwarded to the context."""
        return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the feature extractor attached to the HF processor."""
        extractor = self.get_hf_processor(**kwargs).feature_extractor  # type: ignore
        # Narrow the type: the processor is expected to carry a Whisper extractor.
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_data_parser(self):
        """Build the data parser using the extractor's sampling rate."""
        sampling_rate = self.get_feature_extractor().sampling_rate
        return Qwen2AudioMultiModalDataParser(
            target_sr=sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_target_channels(self) -> int:
        """Return target audio channels for Qwen2 Audio models (mono)."""
        return 1

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits: "audio" is mapped to None."""
        return {"audio": None}
182

183

184
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
    """Builds placeholder prompt text and audio data for Qwen2-Audio."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the audio placeholder token repeated once per audio.

        Args:
            mm_counts: Number of items per modality; a missing "audio" entry
                counts as zero.
        """
        num_audios = mm_counts.get("audio", 0)

        hf_processor = self.info.get_hf_processor()
        # Use getattr with default to be compatible with transformers<4.48,
        # where the processor does not expose `audio_token`.
        audio_token = getattr(hf_processor, "audio_token", "<|AUDIO|>")

        return audio_token * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audio data for the requested number of audios.

        Args:
            seq_len: Not used by this builder.
            mm_counts: Number of items per modality; a missing "audio" entry
                counts as zero.
            mm_options: Optional per-modality overrides; only the "audio"
                entry is read.
            mm_processor_kwargs: Optional kwargs forwarded when fetching the
                feature extractor.
        """
        feature_extractor = self.info.get_feature_extractor(
            **(mm_processor_kwargs or {})
        )

        # Each dummy audio spans one full extractor chunk, in samples.
        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }

216

217
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
    """Multimodal processor that expands audio placeholders for Qwen2-Audio."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer when there is no audio.

        Args:
            prompt: Prompt text.
            mm_data: Multimodal data; a legacy "audios" entry is renamed to
                "audio". The caller's mapping is left untouched.
            mm_kwargs: Processor kwargs; `sampling_rate` is added from the
                feature extractor before calling the parent implementation.
            tok_kwargs: Tokenizer kwargs, forwarded unchanged.
        """
        # Work on a private copy: `mm_data` is typed as a read-only Mapping
        # and belongs to the caller, so it must not be mutated in place.
        mm_data = dict(mm_data)

        # NOTE - we rename audios -> audio in mm data because transformers has
        # deprecated audios for the qwen2audio processor and will remove
        # support for it in transformers 4.54.
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        # Text-only input not supported in composite processor
        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
        )

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the shared Qwen2-Audio field configuration."""
        return _qwen2audio_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement that expands each audio placeholder token.

        Each placeholder becomes the audio BOS id, one audio token id per
        audio feature, and the audio EOS id.

        Raises:
            ValueError: From the replacement callback, when an audio yields
                zero features.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        # Use getattr with default to be compatible with transformers<4.48
        audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
        audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>")
        audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>")

        audio_token_id = vocab[audio_token]
        audio_bos_id = vocab[audio_bos_token]
        audio_eos_id = vocab[audio_eos_token]

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if feature_attention_mask is None:
            # No mask available: lengths are read from `audio_embeds` in the
            # replacement callback instead.
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            # Summing the mask over its last dimension gives one input length
            # per audio.
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )

            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            """Build the replacement token ids for the audio at `item_idx`."""
            if audio_output_lengths:
                num_features = audio_output_lengths[item_idx]
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)

                raise ValueError(
                    f"The audio (len={audio_len}) is too short "
                    "to be represented inside the model"
                )

            audio_tokens = [audio_token_id] * num_features

            # `embed_token_id` names the audio token id (presumably so only
            # those positions receive audio embeddings -- confirm against
            # PromptUpdateDetails.select_token_id).
            return PromptUpdateDetails.select_token_id(
                [audio_bos_id] + audio_tokens + [audio_eos_id],
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            )
        ]
320
321


322
323
324
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2AudioMultiModalProcessor,
    info=Qwen2AudioProcessingInfo,
    dummy_inputs=Qwen2AudioDummyInputsBuilder,
)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Qwen2-Audio model: audio tower and projector in front of a language model."""

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for the `i`-th item of `modality`.

        Raises:
            ValueError: If `modality` does not start with "audio".
        """
        if modality.startswith("audio"):
            return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the audio tower, projector and language model.

        Args:
            vllm_config: Supplies the HF config, quantization config and
                multimodal config.
            prefix: Prefix under which the language model is registered.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = Qwen2AudioEncoder(config.audio_config)
            self.multi_modal_projector = Qwen2AudioMultiModalProjector(
                config.audio_config.d_model, config.text_config.hidden_size
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        # Expose the language model's factory directly on this module.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Qwen2AudioInputs | None:
        """Wrap audio kwargs in the matching schema, or return None.

        `audio_embeds` takes precedence over `input_features` when both are
        given. None is returned when neither is present.
        """
        input_features = kwargs.pop("input_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)

        if input_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            return Qwen2AudioEmbeddingInputs(
                type="audio_embeds", audio_embeds=audio_embeds
            )

        if input_features is not None:
            return Qwen2AudioFeatureInputs(
                type="audio_features",
                input_features=input_features,
                feature_attention_mask=feature_attention_mask,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self, audio_input: Qwen2AudioInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn parsed audio input into one embedding tensor per audio.

        Embedding inputs are returned as a tuple without further processing.
        Feature inputs go through the audio tower and the projector, and the
        result is trimmed to each audio's output length.
        """
        if audio_input["type"] == "audio_embeds":
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        # Summing the mask over its last dimension gives one input length
        # per audio.
        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
        )

        batch_size, _, max_mel_seq_len = input_features.shape
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=audio_feat_lengths.dtype,
                device=audio_feat_lengths.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
            batch_size, max_seq_len
        )
        # Create mask: True at positions at or beyond each audio's length.
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device,
        )
        # Padded positions become -inf; the remaining positions stay 0.
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        audio_outputs = self.audio_tower(
            input_features, attention_mask=audio_attention_mask
        )
        selected_audio_feature = audio_outputs.last_hidden_state
        audio_features = self.multi_modal_projector(selected_audio_feature)
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        # Build the index range directly on the target device instead of
        # creating it on the default device and copying it over.
        audio_features_mask = (
            torch.arange(max_audio_tokens, device=audio_output_lengths.device).expand(
                num_audios, max_audio_tokens
            )
            < audio_output_lengths
        )
        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        return torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return audio embeddings for the given kwargs, or [] without audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model backbone and return its output.

        `inputs_embeds` is discarded whenever `intermediate_tensors` is
        provided. Extra `kwargs` are accepted but not used here.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load `weights` through AutoWeightsLoader and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)