qwen2_audio.py 17.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
25
from collections.abc import Iterable, Mapping, Sequence
26
from typing import Annotated, Any, Literal, Optional, Union
27
28
29

import torch
import torch.nn as nn
30
from transformers import BatchFeature
31
32
33
34
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
                                             Qwen2AudioEncoder,
                                             Qwen2AudioProcessor)
from transformers.models.whisper import WhisperFeatureExtractor
35

36
from vllm.config import VllmConfig
37
from vllm.multimodal import MULTIMODAL_REGISTRY
38
39
from vllm.multimodal.inputs import (AudioItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
40
                                    MultiModalKwargsItems)
41
42
from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
                                   ModalityDataItems, MultiModalDataItems,
43
                                   MultiModalDataParser)
44
from vllm.multimodal.processing import (BaseMultiModalProcessor,
45
                                        BaseProcessingInfo, PromptReplacement,
46
                                        PromptUpdate, PromptUpdateDetails)
47
from vllm.multimodal.profiling import BaseDummyInputsBuilder
48
from vllm.sequence import IntermediateTensors
49
from vllm.utils.tensor_schema import TensorSchema, TensorShape
50

51
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
52
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
53
54
55


# === Audio Inputs === #
56
57
58
59
60
61
class Qwen2AudioFeatureInputs(TensorSchema):
    """
    Mel-spectrogram audio features that are run through the audio tower.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
    """
    # Defaulted for consistency with ``Qwen2AudioEmbeddingInputs``; callers
    # that pass ``type="audio_features"`` explicitly are unaffected.
    type: Literal["audio_features"] = "audio_features"

    # Per-audio feature matrix; the schema fixes the trailing (frame)
    # dimension at 3000.
    input_features: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("na", "nmb", 3000),
    ]

    # Per-frame mask for ``input_features``; the model sums it over the last
    # dimension to recover each audio's unpadded frame count.
    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("na", 3000),
    ]
72
73


74
class Qwen2AudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings that bypass the audio tower and projector.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """
    type: Literal["audio_embeds"] = "audio_embeds"

    # One embedding matrix per audio, already in the language model's
    # hidden size.
    audio_embeds: Annotated[list[torch.Tensor],
                            TensorShape("bn", "naf", "hs")]
88
89
90
91


# Parsed audio input: either raw features (encoded by the audio tower) or
# precomputed embeddings; consumers branch on the ``type`` field.
Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]

92
93
94
95
96
97
98
99
100
101
102
103
104
105
# === Audio Encoder === #


class Qwen2AudioMultiModalProjector(nn.Module):
    """Map audio-encoder hidden states into the language model's hidden size.

    Implemented as a single biased linear layer from ``audio_hidden_size``
    to ``text_hidden_size`` features.
    """

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        # Parameter names (``linear.weight`` / ``linear.bias``) derive from
        # this attribute name; weight loading presumably depends on it.
        self.linear = nn.Linear(in_features=audio_hidden_size,
                                out_features=text_hidden_size,
                                bias=True)

    def forward(self, audio_features):
        return self.linear(audio_features)


106
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
107
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
108
109
110
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths
111
112


113
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
    """Accessors for the Qwen2-Audio HF config, processor and extractor."""

    def get_hf_config(self):
        """Return the model's ``Qwen2AudioConfig``."""
        return self.ctx.get_hf_config(Qwen2AudioConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
        """Return the HF ``Qwen2AudioProcessor`` built with ``kwargs``."""
        return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

    def get_feature_extractor(self,
                              **kwargs: object) -> WhisperFeatureExtractor:
        """Return the ``WhisperFeatureExtractor`` held by the HF processor."""
        processor = self.get_hf_processor(**kwargs)
        extractor = processor.feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Audio is the only modality; ``None`` presumably means no cap."""
        return {"audio": None}
130

131
132
133
134

class Qwen2AudioDummyInputsBuilder(
        BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
    """Build dummy prompt text and audio clips for Qwen2-Audio."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the HF audio placeholder token repeated once per audio."""
        placeholder = self.info.get_hf_processor().audio_token
        return placeholder * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy audios, each ``chunk_length * sampling_rate`` long."""
        extractor = self.info.get_feature_extractor()
        clip_len = extractor.chunk_length * extractor.sampling_rate
        clips = self._get_dummy_audios(length=clip_len,
                                       num_audios=mm_counts.get("audio", 0))
        return {"audio": clips}

159

160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Declare that every Qwen2-Audio tensor field is batched per audio item.

    ``hf_inputs`` is accepted to satisfy the field-factory interface but is
    not inspected.
    """
    return {
        "audio_embeds": MultiModalFieldConfig.batched("audio"),
        "input_features": MultiModalFieldConfig.batched("audio"),
        "feature_attention_mask": MultiModalFieldConfig.batched("audio"),
    }


class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    """Audio data parser that also accepts a dict of precomputed embeddings."""

    def _parse_audio_data(
        self,
        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
    ) -> Optional[ModalityDataItems[Any, Any]]:
        # Non-dict input is ordinary audio data: use the base parser.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        # A dict is treated as embedding input and must carry
        # ``audio_embeds``.
        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_qwen2audio_field_config,
        )


185
186
class Qwen2AudioMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
    """Run the HF processor and expand audio placeholders for Qwen2-Audio."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser targeting the feature extractor's sampling rate."""
        feature_extractor = self.info.get_feature_extractor()

        return Qwen2AudioMultiModalDataParser(
            target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Process ``prompt`` plus audio; text-only prompts are tokenized only.

        ``mm_data`` is copied before modification, so the caller's mapping is
        left untouched. ``sampling_rate`` is always taken from the feature
        extractor. A caller-supplied ``sampling_rate`` is overridden rather
        than triggering a duplicate-keyword ``TypeError``.
        """
        # Copy: the argument is a read-only Mapping and belongs to the caller.
        mm_data = dict(mm_data)

        # NOTE - we rename audios -> audio in mm data because transformers has
        # deprecated audios for the qwen2audio processor and will remove
        # support for it in transformers 4.54.
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        # Text-only input not supported in composite processor
        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # Dict merge instead of ``dict(**mm_kwargs, sampling_rate=...)``: the
        # latter raises TypeError if ``mm_kwargs`` already has the key.
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
        }

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Every output field is batched per audio item."""
        return _qwen2audio_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio token with ``bos + N audio tokens + eos``.

        Source of ``N``:
          * With raw features, ``N`` is derived from
            ``feature_attention_mask`` via ``_get_feat_extract_output_lengths``.
          * Otherwise, ``N`` is the row count of the item's 2D
            ``audio_embeds`` tensor.

        ``embed_token_id`` marks the audio tokens as the embedding positions.

        Raises:
            ValueError: if an audio maps to zero tokens.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        # Use getattr with default to be compatible with transformers<4.48
        audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
        audio_bos_token = getattr(processor, "audio_bos_token",
                                  "<|audio_bos|>")
        audio_eos_token = getattr(processor, "audio_eos_token",
                                  "<|audio_eos|>")

        audio_token_id = vocab[audio_token]
        audio_bos_id = vocab[audio_bos_token]
        audio_eos_id = vocab[audio_eos_token]

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if feature_attention_mask is None:
            # Embedding inputs: lengths come from ``audio_embeds`` below.
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1))

            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            if audio_output_lengths:
                num_features = audio_output_lengths[item_idx]
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                assert len(audio_embeds.shape
                           ) == 2, "audio_embeds must be a 2D tensor"
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)

                raise ValueError(f"The audio (len={audio_len}) is too short "
                                 "to be represented inside the model")

            audio_tokens = [audio_token_id] * num_features

            return PromptUpdateDetails.select_token_id(
                [audio_bos_id] + audio_tokens + [audio_eos_id],
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            )
        ]
297
298


299
300
301
302
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2AudioMultiModalProcessor,
    info=Qwen2AudioProcessingInfo,
    dummy_inputs=Qwen2AudioDummyInputsBuilder)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
    """Qwen2-Audio model with three parts.

    * Audio tower: an HF ``Qwen2AudioEncoder``.
    * Projector: a linear layer into the text hidden size.
    * Language model: a vLLM ``Qwen2ForCausalLM`` backbone.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for the ``i``-th audio item.

        Raises:
            ValueError: for any modality not starting with ``"audio"``.
        """
        if modality.startswith("audio"):
            return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.audio_tower = Qwen2AudioEncoder(config.audio_config)
        self.multi_modal_projector = Qwen2AudioMultiModalProjector(
            config.audio_config.d_model, config.text_config.hidden_size)

        self.quant_config = quant_config

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """Flatten a batched multimodal input into one tensor.

        Handling by input type:
          * Tensor: the first two dimensions are merged.
          * List: the elements are concatenated along dim 0.

        The leading dimension is presumably an outer batch dimension added
        by the runner (TODO confirm).

        Raises:
            ValueError: if ``mm_input`` is neither a tensor nor a list.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return mm_input.reshape(-1, *mm_input.shape[2:])
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
        """Pop audio kwargs and wrap them in the matching input schema.

        Returns ``None`` when neither ``input_features`` nor
        ``audio_embeds`` is present. ``audio_embeds`` takes precedence.
        """
        input_features = kwargs.pop('input_features', None)
        audio_embeds = kwargs.pop('audio_embeds', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)

        if input_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")
            audio_embeds = self._validate_and_reshape_mm_tensor(
                audio_embeds, "audio_embeds")
            return Qwen2AudioEmbeddingInputs(type="audio_embeds",
                                             audio_embeds=audio_embeds)

        if input_features is not None:
            input_features = self._validate_and_reshape_mm_tensor(
                input_features, 'input_features')
            feature_attention_mask = self._validate_and_reshape_mm_tensor(
                feature_attention_mask, 'feature_attention_mask')
            return Qwen2AudioFeatureInputs(
                type="audio_features",
                input_features=input_features,
                feature_attention_mask=feature_attention_mask)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self, audio_input: Qwen2AudioInputs
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Turn parsed audio input into one embedding tensor per audio.

        Precomputed embeddings are returned as-is, split per audio. Raw
        features are processed in four steps:
          1. Encode with the audio tower under a padding attention mask.
          2. Project into the text hidden size.
          3. Drop positions beyond each audio's output length.
          4. Split per audio.
        """
        if audio_input["type"] == "audio_embeds":
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)))

        batch_size, _, max_mel_seq_len = input_features.shape
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (torch.arange(
            0,
            max_seq_len,
            dtype=audio_feat_lengths.dtype,
            device=audio_feat_lengths.device).unsqueeze(0).expand(
                batch_size, max_seq_len))
        lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
            batch_size, max_seq_len)
        # Create mask
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(
            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
                                                  max_seq_len)
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device)
        # Padded key positions get -inf so attention ignores them.
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        audio_outputs = self.audio_tower(input_features,
                                         attention_mask=audio_attention_mask)
        selected_audio_feature = audio_outputs.last_hidden_state
        audio_features = self.multi_modal_projector(selected_audio_feature)
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        # Build positions directly on the target device; broadcasting
        # (1, max_audio_tokens) against (num_audios, 1) yields the same
        # (num_audios, max_audio_tokens) mask as an explicit expand, without
        # a CPU tensor and host-to-device copy.
        token_positions = torch.arange(max_audio_tokens,
                                       device=audio_output_lengths.device)
        audio_features_mask = (token_positions.unsqueeze(0) <
                               audio_output_lengths)
        masked_audio_features = audio_features[audio_features_mask].view(
            -1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        return torch.split(masked_audio_features,
                           audio_output_lengths.flatten().tolist())

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text backbone."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or ``[]`` if no audio was passed."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language-model backbone.

        ``inputs_embeds`` is discarded when ``intermediate_tensors`` is given.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)