qwen2_audio.py 17 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
25

26
from collections.abc import Iterable, Mapping, Sequence
27
from typing import Annotated, Any, Literal, TypeAlias
28
29
30

import torch
import torch.nn as nn
31
from transformers import BatchFeature
32
33
34
35
36
from transformers.models.qwen2_audio import (
    Qwen2AudioConfig,
    Qwen2AudioEncoder,
    Qwen2AudioProcessor,
)
37
from transformers.models.whisper import WhisperFeatureExtractor
38

39
from vllm.config import VllmConfig
40
from vllm.config.multimodal import BaseDummyOptions
41
from vllm.inputs import ModalityData, MultiModalDataDict
42
from vllm.multimodal import MULTIMODAL_REGISTRY
43
44
45
46
47
48
49
50
51
52
53
54
55
from vllm.multimodal.inputs import (
    AudioItem,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
56
    BaseDummyInputsBuilder,
57
58
59
60
61
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
62
from vllm.sequence import IntermediateTensors
63
from vllm.utils.tensor_schema import TensorSchema, TensorShape
64

65
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
66
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
67
68
69


# === Audio Inputs === #
70
71
72
73
74
75
class Qwen2AudioFeatureInputs(TensorSchema):
    """
    Mel-spectrogram audio features to be run through the audio encoder.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
    """

    type: Literal["audio_features"]

    # Mel features per audio; the frame axis is fixed at 3000 entries.
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("na", "nmb", 3000),
    ]

    # Per-frame mask over the frame axis; its row sums are used as the
    # number of valid input frames per audio.
    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("na", 3000),
    ]
87
88


89
class Qwen2AudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings, used as-is without running the encoder.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    # One (naf, hs) tensor per audio; naf is allowed to vary between items.
    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
    ]
104
105


106
# Audio input accepted by the model: raw features or precomputed embeddings.
Qwen2AudioInputs: TypeAlias = (
    Qwen2AudioFeatureInputs | Qwen2AudioEmbeddingInputs
)
107

108
109
110
111
112
113
114
115
116
117
118
119
120
# === Audio Encoder === #


class Qwen2AudioMultiModalProjector(nn.Module):
    """Map audio-encoder hidden states into the text model's hidden size.

    Args:
        audio_hidden_size: Width of the incoming audio features.
        text_hidden_size: Width expected by the language model.
    """

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(audio_hidden_size, text_hidden_size, bias=True)

    def forward(self, audio_features):
        # A single affine projection; no activation or normalization follows.
        return self.linear(audio_features)


121
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
122
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
123
124
125
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths
126
127


128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Return the multimodal field layout shared by all Qwen2-Audio inputs.

    Each of ``audio_embeds``, ``input_features`` and ``feature_attention_mask``
    is configured with ``MultiModalFieldConfig.batched("audio")``.
    ``hf_inputs`` is part of the factory signature but is not inspected.
    """
    batched = MultiModalFieldConfig.batched
    return {
        "audio_embeds": batched("audio"),
        "input_features": batched("audio"),
        "feature_attention_mask": batched("audio"),
    }


class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    """Audio data parser that additionally accepts embedding dictionaries.

    A ``dict`` is interpreted as precomputed embeddings and must contain
    ``audio_embeds``; every other input is handled by the base parser.
    """

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_qwen2audio_field_config,
        )


152
153
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
    """Accessors for Qwen2-Audio's HF config/processor and token budgeting."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2AudioConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
        return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the HF processor's feature extractor (must be Whisper's)."""
        extractor = self.get_hf_processor(**kwargs).feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_data_parser(self):
        """Build the data parser targeting the extractor's sampling rate."""
        return Qwen2AudioMultiModalDataParser(
            target_sr=self.get_feature_extractor().sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_target_channels(self) -> int:
        """Return target audio channels for Qwen2 Audio models (mono)."""
        return 1

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # ``None``: no explicit per-prompt limit on the number of audio items.
        return {"audio": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
    ) -> Mapping[str, int]:
        """Return the maximum number of placeholder tokens per audio item.

        An empty mapping is returned when no audio is requested. Otherwise
        the bound is computed from ``chunk_length`` seconds (capped at 30) of
        audio at the extractor's sampling rate and hop length.
        """
        if (mm_counts or {}).get("audio", 0) <= 0:
            return {}

        extractor = self.get_feature_extractor()
        max_seconds = min(extractor.chunk_length, 30)
        num_samples = int(max_seconds * extractor.sampling_rate)
        num_mel_frames = num_samples // extractor.hop_length

        _, num_tokens = _get_feat_extract_output_lengths(
            torch.tensor([num_mel_frames], dtype=torch.long)
        )
        return {"audio": int(num_tokens.item())}

201

202
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
    """Build dummy prompt text and audio data for Qwen2-Audio."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One ``<bos><audio><eos>`` placeholder group per requested audio.
        processor = self.info.get_hf_processor()
        placeholder = "".join(
            (
                processor.audio_bos_token,
                processor.audio_token,
                processor.audio_eos_token,
            )
        )
        return placeholder * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        # Each dummy clip spans ``chunk_length`` seconds at the extractor rate.
        extractor = self.info.get_feature_extractor()
        dummy_audios = self._get_dummy_audios(
            length=extractor.chunk_length * extractor.sampling_rate,
            num_audios=mm_counts.get("audio", 0),
            overrides=mm_options.get("audio"),
        )
        return {"audio": dummy_audios}

235

236
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
    """Run the HF processor and expand audio placeholders for Qwen2-Audio."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Invoke the HF processor, or only tokenize when there is no audio.

        ``mm_data`` and ``mm_kwargs`` are treated as read-only; local copies
        are modified so the caller's mappings are never mutated.
        """
        # Copy first: ``mm_data`` is typed as a read-only Mapping and may be
        # owned by the caller (or be immutable, where ``pop`` would fail).
        mm_data = dict(mm_data)

        # NOTE - we rename audios -> audio in mm data because transformers has
        # deprecated audios for the qwen2audio processor and will remove
        # support for it in transformers 4.54.
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        # Text-only input not supported in composite processor
        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # Merge rather than ``dict(**mm_kwargs, sampling_rate=...)``, which
        # raises TypeError if the caller already passed ``sampling_rate``.
        # The extractor's rate wins, matching the ``target_sr`` used by
        # ``Qwen2AudioProcessingInfo.get_data_parser``.
        mm_kwargs = {**mm_kwargs, "sampling_rate": feature_extractor.sampling_rate}

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the shared Qwen2-Audio field layout."""
        return _qwen2audio_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio token with one token per audio feature.

        The per-item count comes from ``feature_attention_mask`` when present,
        otherwise from the row count of the item's 2D ``audio_embeds``.

        Raises:
            ValueError: if an audio item would produce zero features.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        audio_token_id = processor.audio_token_id

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if feature_attention_mask is None:
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            # Row sums of the mask are the valid mel-frame counts per audio.
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int) -> list[int]:
            if audio_output_lengths:
                num_features = audio_output_lengths[item_idx]
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)

                raise ValueError(
                    f"The audio (len={audio_len}) is too short "
                    "to be represented inside the model"
                )

            return [audio_token_id] * num_features

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_replacement_qwen2_audio,
            )
        ]
324
325


326
327
328
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2AudioMultiModalProcessor,
    info=Qwen2AudioProcessingInfo,
    dummy_inputs=Qwen2AudioDummyInputsBuilder,
)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Qwen2-Audio for conditional generation.

    An audio encoder (``audio_tower``) and a linear projector
    (``multi_modal_projector``) turn audio into embeddings for a Qwen2 causal
    language model (``language_model``).
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for the ``i``-th audio item.

        Raises:
            ValueError: if ``modality`` does not start with ``"audio"``.
        """
        if modality.startswith("audio"):
            return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = Qwen2AudioEncoder(config.audio_config)
            self.multi_modal_projector = Qwen2AudioMultiModalProjector(
                config.audio_config.d_model, config.text_config.hidden_size
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        # Reuse the language model's factory for empty intermediate tensors
        # (needed by SupportsPP).
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Qwen2AudioInputs | None:
        """Pop audio inputs from ``kwargs`` and wrap them in a typed schema.

        ``audio_embeds`` takes precedence over ``input_features``. Returns
        ``None`` when neither is present.
        """
        input_features = kwargs.pop("input_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)

        if input_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            return Qwen2AudioEmbeddingInputs(
                type="audio_embeds", audio_embeds=audio_embeds
            )

        if input_features is not None:
            return Qwen2AudioFeatureInputs(
                type="audio_features",
                input_features=input_features,
                feature_attention_mask=feature_attention_mask,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self, audio_input: Qwen2AudioInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn validated audio input into one embedding tensor per audio.

        Precomputed embeddings are returned unchanged (as a tuple). Features
        are encoded with padding masked out, projected to the text hidden
        size, and trimmed to each audio's valid output length.
        """
        if audio_input["type"] == "audio_embeds":
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        # Row sums of the mask are the valid mel-frame counts per audio.
        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
        )

        batch_size, _, max_mel_seq_len = input_features.shape
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=audio_feat_lengths.dtype,
                device=audio_feat_lengths.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
            batch_size, max_seq_len
        )
        # True at positions beyond each audio's valid length (padding).
        padding_mask = seq_range >= lengths_expand

        # Additive mask of shape (batch_size, 1, max_seq_len, max_seq_len):
        # 0 at valid positions, -inf at padded positions.
        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device,
        )
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        audio_outputs = self.audio_tower(
            input_features, attention_mask=audio_attention_mask
        )
        selected_audio_feature = audio_outputs.last_hidden_state
        audio_features = self.multi_modal_projector(selected_audio_feature)
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        # Build the index range directly on the lengths' device instead of
        # allocating on CPU and copying it over on every call.
        audio_features_mask = (
            torch.arange(max_audio_tokens, device=audio_output_lengths.device)
            .expand(num_audios, max_audio_tokens)
            < audio_output_lengths
        )
        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        return torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or an empty list if no audio input."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model backbone.

        When ``intermediate_tensors`` is provided, ``inputs_embeds`` is
        discarded before calling the backbone.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)