# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
25
from collections.abc import Iterable, Mapping, Sequence
26
from typing import Annotated, Any, Literal, Optional, Union
27
28
29

import torch
import torch.nn as nn
30
from transformers import BatchFeature
31
32
33
34
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
                                             Qwen2AudioEncoder,
                                             Qwen2AudioProcessor)
from transformers.models.whisper import WhisperFeatureExtractor
35

36
from vllm.config import VllmConfig
37
from vllm.model_executor.sampling_metadata import SamplingMetadata
38
from vllm.multimodal import MULTIMODAL_REGISTRY
39
40
from vllm.multimodal.inputs import (AudioItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
41
                                    MultiModalKwargsItems)
42
43
from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
                                   ModalityDataItems, MultiModalDataItems,
44
                                   MultiModalDataParser)
45
from vllm.multimodal.processing import (BaseMultiModalProcessor,
46
                                        BaseProcessingInfo, PromptReplacement,
47
                                        PromptUpdate, PromptUpdateDetails)
48
from vllm.multimodal.profiling import BaseDummyInputsBuilder
49
from vllm.sequence import IntermediateTensors
50
from vllm.utils.tensor_schema import TensorSchema, TensorShape
51

52
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
53
54
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)
55
56
57


# === Audio Inputs === #


class Qwen2AudioFeatureInputs(TensorSchema):
    """
    Schema for audio supplied to the model as input features.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
    """
    # Discriminator tag; `_process_audio_input` branches on this value.
    type: Literal["audio_features"]

    # One entry per audio; the trailing dimension is fixed at 3000.
    input_features: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("na", "nmb", 3000),
    ]

    # One row per audio; it is summed over the last axis to obtain the
    # per-audio length fed into the output-length computation.
    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("na", 3000),
    ]
74
75


76
class Qwen2AudioEmbeddingInputs(TensorSchema):
    """
    Schema for audio supplied to the model as precomputed embeddings.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """
    # Discriminator tag; `_process_audio_input` branches on this value.
    type: Literal["audio_embeds"] = "audio_embeds"

    # Embeddings that are returned as-is (one tuple entry per element)
    # instead of being run through the audio tower.
    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs"),
    ]
90
91
92
93


# Union of the two audio input schemas; the `type` field of each schema
# tells them apart.
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]

94
95
96
97
98
99
100
101
102
103
104
105
106
107
# === Audio Encoder === #


class Qwen2AudioMultiModalProjector(nn.Module):
    """Single biased linear layer from the audio hidden size to the text
    hidden size."""

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        # The attribute name determines the parameter names
        # (`linear.weight` / `linear.bias`), so it is kept as `linear`.
        self.linear = nn.Linear(
            in_features=audio_hidden_size,
            out_features=text_hidden_size,
            bias=True,
        )

    def forward(self, audio_features):
        """Return ``audio_features`` passed through the linear layer."""
        return self.linear(audio_features)


108
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
109
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    """
    Derive the two reduced lengths that correspond to ``input_lengths``.

    Args:
        input_lengths: Lengths to convert. Only ``-``, ``//`` and ``+``
            are applied, element-wise when a tensor is given.

    Returns:
        A ``(feat_lengths, output_lengths)`` pair where
        ``feat_lengths = (input_lengths - 1) // 2 + 1`` and
        ``output_lengths = (feat_lengths - 2) // 2 + 1``.
    """
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths
113
114


115
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
    """Typed accessors for the HF config, processor and feature extractor
    used by the Qwen2-Audio multimodal processor."""

    def get_hf_config(self) -> Qwen2AudioConfig:
        """Return the HF config, requested as a ``Qwen2AudioConfig``."""
        return self.ctx.get_hf_config(Qwen2AudioConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
        """Return the HF processor, forwarding ``kwargs`` to the context."""
        return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

    def get_feature_extractor(self,
                              **kwargs: object) -> WhisperFeatureExtractor:
        """
        Return the feature extractor attached to the HF processor.

        ``kwargs`` are forwarded to :meth:`get_hf_processor`. The
        extractor is asserted to be a ``WhisperFeatureExtractor``.
        """
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits; only ``"audio"`` is listed."""
        # NOTE(review): `None` presumably means "no fixed limit" - confirm
        # against the base-class contract.
        return {"audio": None}
132

133
134
135
136

class Qwen2AudioDummyInputsBuilder(
        BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
    """Builds dummy prompt text and dummy audio data for Qwen2-Audio."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Return the audio placeholder token repeated once per audio.

        Args:
            mm_counts: Number of items per modality; a missing ``"audio"``
                entry counts as zero.
        """
        num_audios = mm_counts.get("audio", 0)

        hf_processor = self.info.get_hf_processor()
        # Use getattr with default to be compatible with transformers<4.48
        # (same fallback as in the processor's `_get_prompt_updates`).
        audio_token = getattr(hf_processor, "audio_token", "<|AUDIO|>")

        return audio_token * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """
        Return dummy audio data under the ``"audio"`` key.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality; a missing ``"audio"``
                entry counts as zero.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        # Each dummy audio is `chunk_length * sampling_rate` long.
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        return {
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }

161

162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """
    Return the field config for every Qwen2-Audio multimodal field.

    Each field is declared as batched over the ``"audio"`` modality.
    ``hf_inputs`` is accepted but not inspected.
    """
    field_names = (
        "audio_embeds",
        "input_features",
        "feature_attention_mask",
    )
    return {
        field_name: MultiModalFieldConfig.batched("audio")
        for field_name in field_names
    }


class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    """Data parser that additionally accepts audio given as a dict
    containing ``audio_embeds``."""

    def _parse_audio_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        """Wrap dict input as embedding items; defer everything else."""
        # Non-dict input is handed to the base-class parser unchanged.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_qwen2audio_field_config,
        )


187
188
class Qwen2AudioMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
    """Multimodal processor for Qwen2-Audio prompts and audio inputs."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser targeting the feature extractor's sampling
        rate."""
        feature_extractor = self.info.get_feature_extractor()
        return Qwen2AudioMultiModalDataParser(
            target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor, or only the tokenizer for audio-free input.

        Args:
            prompt: Prompt text.
            mm_data: Multimodal data; a legacy ``"audios"`` entry is
                treated as ``"audio"``. The mapping is not modified.
            mm_kwargs: Extra processor kwargs; ``sampling_rate`` is added
                from the feature extractor before the HF call.
            tok_kwargs: Tokenizer kwargs forwarded to the base class.

        Returns:
            The processor output. Without audio it only contains
            ``input_ids``.
        """
        # Work on a private copy: `mm_data` is typed as a read-only
        # `Mapping`, so the caller's object must not be mutated.
        mm_data = dict(mm_data)

        # NOTE - we rename audios -> audio in mm data because transformers has
        # deprecated audios for the qwen2audio processor and will remove
        # support for it in transformers 4.54.
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        # Text-only input not supported in composite processor
        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
        )

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the shared Qwen2-Audio field config."""
        return _qwen2audio_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Build the replacement that expands each audio placeholder token.

        Each occurrence of the audio token is replaced by
        ``[bos] + [audio_token] * num_features + [eos]``, where
        ``num_features`` comes from the feature attention mask when one is
        present and from the embedding length otherwise.

        Raises:
            ValueError: From the replacement callback when an audio maps
                to zero features.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        # Use getattr with default to be compatible with transformers<4.48
        audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
        audio_bos_token = getattr(processor, "audio_bos_token",
                                  "<|audio_bos|>")
        audio_eos_token = getattr(processor, "audio_eos_token",
                                  "<|audio_eos|>")

        audio_token_id = vocab[audio_token]
        audio_bos_id = vocab[audio_bos_token]
        audio_eos_id = vocab[audio_eos_token]

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if feature_attention_mask is None:
            # An empty list signals "no mask": lengths are then read from
            # `audio_embeds` inside the callback.
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1))

            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            """Return the token replacement for the audio at
            ``item_idx``."""
            if audio_output_lengths:
                num_features = audio_output_lengths[item_idx]
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                assert len(audio_embeds.shape
                           ) == 2, "audio_embeds must be a 2D tensor"
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                # NOTE(review): this lookup requests `AudioProcessorItems`;
                # confirm it also behaves as intended when the audio was
                # supplied as embeddings.
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)

                raise ValueError(f"The audio (len={audio_len}) is too short "
                                 "to be represented inside the model")

            audio_tokens = [audio_token_id] * num_features

            return PromptUpdateDetails.select_token_id(
                [audio_bos_id] + audio_tokens + [audio_eos_id],
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            )
        ]
299
300


301
302
303
304
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2AudioMultiModalProcessor,
    info=Qwen2AudioProcessingInfo,
    dummy_inputs=Qwen2AudioDummyInputsBuilder)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
    """Qwen2-Audio: an audio tower and a linear projector whose outputs are
    merged into the input embeddings of a language model."""

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """
        Return the prompt placeholder for the ``i``-th item of ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"audio"``.
        """
        if modality.startswith("audio"):
            return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the audio tower, the projector and the language model from
        ``vllm_config``."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.audio_tower = Qwen2AudioEncoder(config.audio_config)
        self.multi_modal_projector = Qwen2AudioMultiModalProjector(
            config.audio_config.d_model, config.text_config.hidden_size)

        self.quant_config = quant_config

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """
        Concatenate ``mm_input`` along its first dimension.

        A tensor is first split along dim 0 and the pieces concatenated,
        which merges its two leading dimensions; a list is concatenated
        directly.

        Raises:
            ValueError: If ``mm_input`` is neither a tensor nor a list.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
        """
        Extract the audio inputs from ``kwargs``.

        Returns:
            ``None`` when neither ``input_features`` nor ``audio_embeds``
            is given. ``audio_embeds`` takes precedence when both are
            present.

        Raises:
            ValueError: If a provided input has an unsupported type; this
                includes a missing ``feature_attention_mask`` alongside
                ``input_features``.
        """
        input_features = kwargs.pop('input_features', None)
        audio_embeds = kwargs.pop('audio_embeds', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)

        if input_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")
            audio_embeds = self._validate_and_reshape_mm_tensor(
                audio_embeds, "audio_embeds")
            return Qwen2AudioEmbeddingInputs(type="audio_embeds",
                                             audio_embeds=audio_embeds)

        if input_features is not None:
            input_features = self._validate_and_reshape_mm_tensor(
                input_features, 'input_features')
            feature_attention_mask = self._validate_and_reshape_mm_tensor(
                feature_attention_mask, 'feature_attention_mask')
            return Qwen2AudioFeatureInputs(
                type="audio_features",
                input_features=input_features,
                feature_attention_mask=feature_attention_mask)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self, audio_input: Qwen2AudioInputs
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """
        Turn parsed audio inputs into per-audio embeddings.

        Precomputed embeddings are returned as a tuple without further
        processing. Input features are run through the audio tower and the
        projector, and the padded positions are dropped.

        Returns:
            A tuple with one embedding tensor per audio.
        """
        if audio_input["type"] == "audio_embeds":
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)))

        batch_size, _, max_mel_seq_len = input_features.shape
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (torch.arange(
            0,
            max_seq_len,
            dtype=audio_feat_lengths.dtype,
            device=audio_feat_lengths.device).unsqueeze(0).expand(
                batch_size, max_seq_len))
        lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
            batch_size, max_seq_len)
        # Create mask: True at positions at or beyond each audio's length
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(
            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
                                                  max_seq_len)
        # Cast the boolean mask to the tower's dtype/device, then overwrite
        # the padded positions with -inf (the other positions stay 0).
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device)
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        audio_outputs = self.audio_tower(input_features,
                                         attention_mask=audio_attention_mask)
        selected_audio_feature = audio_outputs.last_hidden_state
        audio_features = self.multi_modal_projector(selected_audio_feature)
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        # Build the index range directly on the target device instead of
        # creating it on the default device and copying it over.
        audio_features_mask = torch.arange(
            max_audio_tokens, device=audio_output_lengths.device).expand(
                num_audios, max_audio_tokens) < audio_output_lengths
        masked_audio_features = audio_features[audio_features_mask].view(
            -1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        return torch.split(masked_audio_features,
                           audio_output_lengths.flatten().tolist())

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return the audio embeddings for ``kwargs``, or an empty list when
        no audio input is present."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in ``multimodal_embeddings`` at the
        positions of the configured audio token, if any are given."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.audio_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        Run the language model on text and (optionally) audio inputs.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is
        ignored. Otherwise, if ``inputs_embeds`` is missing, it is computed
        from ``input_ids`` and the audio inputs found in ``kwargs``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` via ``AutoWeightsLoader`` and return its
        result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)