qwen2_audio.py 17 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
25

26
from collections.abc import Iterable, Mapping, Sequence
27
from typing import Annotated, Any, Literal, TypeAlias
28
29
30

import torch
import torch.nn as nn
31
from transformers import BatchFeature
32
33
34
35
36
from transformers.models.qwen2_audio import (
    Qwen2AudioConfig,
    Qwen2AudioEncoder,
    Qwen2AudioProcessor,
)
37
from transformers.models.whisper import WhisperFeatureExtractor
38

39
from vllm.config import VllmConfig
40
from vllm.config.multimodal import BaseDummyOptions
41
from vllm.multimodal import MULTIMODAL_REGISTRY
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from vllm.multimodal.inputs import (
    AudioItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
57
    BaseDummyInputsBuilder,
58
59
60
61
62
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
63
from vllm.sequence import IntermediateTensors
64
from vllm.utils.tensor_schema import TensorSchema, TensorShape
65

66
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
67
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
68
69
70


# === Audio Inputs === #
71
72
73
74
75
76
class Qwen2AudioFeatureInputs(TensorSchema):
    """
    Log-mel audio features that still have to go through the audio tower.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins
    """

    type: Literal["audio_features"]

    # Mel features per audio; the shape annotation fixes the frame axis to
    # exactly 3000 frames.
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("na", "nmb", 3000),
    ]

    # Per-frame mask over the same 3000 frames; the model sums it over the
    # last axis to obtain each audio's valid frame count.
    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("na", 3000),
    ]
88
89


90
class Qwen2AudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings; these are returned as-is instead of being
    run through the audio tower.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    # One tensor per audio; "naf" is declared dynamic, so the number of
    # feature rows may differ from audio to audio.
    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
    ]
105
106


107
# Audio input accepted by the model: raw mel features or ready-made embeddings.
Qwen2AudioInputs: TypeAlias = (
    Qwen2AudioFeatureInputs | Qwen2AudioEmbeddingInputs
)
108

109
110
111
112
113
114
115
116
117
118
119
120
121
# === Audio Encoder === #


class Qwen2AudioMultiModalProjector(nn.Module):
    """Project audio-encoder hidden states to the text model's hidden size.

    This is a single biased linear layer. Keep the attribute name ``linear``:
    parameter names are derived from it, and weight loading presumably
    matches checkpoint keys against those names.
    """

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(
            in_features=audio_hidden_size,
            out_features=text_hidden_size,
            bias=True,
        )

    def forward(self, audio_features):
        # Applied on the last dimension; leading dims pass through unchanged.
        return self.linear(audio_features)


122
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
123
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
124
125
126
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths
127
128


129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Declare how each Qwen2-Audio processor output is split per item.

    Every field is batched along the "audio" modality. ``hf_inputs`` is part
    of the expected factory signature but is not inspected here.
    """
    batched = MultiModalFieldConfig.batched
    return {
        "audio_embeds": batched("audio"),
        "input_features": batched("audio"),
        "feature_attention_mask": batched("audio"),
    }


class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    """Audio data parser that additionally accepts dicts of precomputed
    ``audio_embeds``."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        # Non-dict input is ordinary audio: let the base parser handle it.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        # Dict input carries embeddings and must provide "audio_embeds".
        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_qwen2audio_field_config,
        )


153
154
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Qwen2-Audio.

    Provides access to the HF config, processor and feature extractor, the
    audio data parser, and the per-item audio token budget.
    """

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2AudioConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
        return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor held by the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_data_parser(self):
        """Build the data parser.

        The parser targets the feature extractor's sampling rate and mono
        audio, and also accepts precomputed embeddings.
        """
        feature_extractor = self.get_feature_extractor()

        return Qwen2AudioMultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_target_channels(self) -> int:
        """Return target audio channels for Qwen2 Audio models (mono)."""
        return 1

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # ``None``: no fixed limit on the number of audios per prompt.
        return {"audio": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
    ) -> Mapping[str, int]:
        """Return the maximum number of audio placeholder tokens per audio.

        The bound is computed for a clip of ``min(chunk_length, 30)`` seconds
        at the extractor's sampling rate and hop length. Returns ``{}`` when
        no audio is requested. ``seq_len`` is not used.
        """
        mm_counts = mm_counts or {}
        if mm_counts.get("audio", 0) <= 0:
            return {}

        feature_extractor = self.get_feature_extractor()
        # NOTE(review): the 30 s cap appears to correspond to the fixed
        # 3000-frame feature shape at 16 kHz / hop 160 — confirm.
        chunk_length = min(feature_extractor.chunk_length, 30)
        audio_len = int(chunk_length * feature_extractor.sampling_rate)
        hop_length = feature_extractor.hop_length
        max_mel_seq_len = audio_len // hop_length

        input_lengths = torch.tensor([max_mel_seq_len], dtype=torch.long)
        _, output_lengths = _get_feat_extract_output_lengths(input_lengths)

        return {"audio": int(output_lengths.item())}

202

203
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
    """Build placeholder prompt text and audio clips.

    NOTE(review): presumably used for memory profiling — confirm.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<bos><audio><eos>`` token triple per requested audio."""
        num_audios = mm_counts.get("audio", 0)

        hf_processor = self.info.get_hf_processor()
        audio_token = hf_processor.audio_token
        audio_bos_token = hf_processor.audio_bos_token
        audio_eos_token = hf_processor.audio_eos_token

        return (audio_bos_token + audio_token + audio_eos_token) * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return ``num_audios`` dummy clips.

        Each clip is ``chunk_length`` seconds long at the extractor's sampling
        rate. Any user-supplied "audio" dummy options are forwarded as
        overrides.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=audio_len,
                num_audios=num_audios,
                overrides=audio_overrides,
            )
        }

236

237
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
    """Run the HF Qwen2-Audio processor and expand audio placeholders.

    Each audio placeholder token is expanded into as many audio tokens as the
    encoder produces for that audio.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize ``prompt`` and extract audio features.

        Without audio, only the tokenizer runs.
        """
        # Work on a private copy: the key rename below must not mutate the
        # caller's (read-only typed) mapping.
        mm_data = dict(mm_data)

        # NOTE - we rename audios -> audio in mm data because transformers has
        # deprecated audios for the qwen2audio processor and will remove
        # support for it in transformers 4.54.
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        # Text-only input not supported in composite processor
        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # Always use the extractor's own sampling rate. A dict merge overrides
        # any caller-supplied value; ``dict(**mm_kwargs, sampling_rate=...)``
        # would raise TypeError on the duplicate keyword instead.
        mm_kwargs = {**mm_kwargs, "sampling_rate": feature_extractor.sampling_rate}

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _qwen2audio_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio token with ``num_features`` audio tokens.

        ``num_features`` comes from the feature attention mask when features
        were computed, and otherwise from the first dimension of the
        item's ``audio_embeds``.

        Raises:
            ValueError: if an audio maps to zero tokens.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        audio_token_id = processor.audio_token_id

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if feature_attention_mask is None:
            # Embedding inputs: lengths are read per item below.
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            if audio_output_lengths:
                num_features = audio_output_lengths[item_idx]
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)

                raise ValueError(
                    f"The audio (len={audio_len}) is too short "
                    "to be represented inside the model"
                )

            return [audio_token_id] * num_features

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_replacement_qwen2_audio,
            )
        ]
325
326


327
328
329
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2AudioMultiModalProcessor,
    info=Qwen2AudioProcessingInfo,
    dummy_inputs=Qwen2AudioDummyInputsBuilder,
)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Qwen2-Audio model.

    Combines an HF ``Qwen2AudioEncoder`` audio tower, a linear projector into
    the text hidden size, and a vLLM ``Qwen2ForCausalLM`` language model.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("audio"):
            return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = Qwen2AudioEncoder(config.audio_config)
            self.multi_modal_projector = Qwen2AudioMultiModalProjector(
                config.audio_config.d_model, config.text_config.hidden_size
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Qwen2AudioInputs | None:
        """Pop audio kwargs and wrap them in the matching input schema.

        Returns ``None`` when neither features nor embeddings are present.
        Embeddings take precedence if both are given.
        """
        input_features = kwargs.pop("input_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)

        if input_features is None and audio_embeds is None:
            return None

        if audio_embeds is not None:
            return Qwen2AudioEmbeddingInputs(
                type="audio_embeds", audio_embeds=audio_embeds
            )

        # Here input_features is guaranteed to be non-None.
        return Qwen2AudioFeatureInputs(
            type="audio_features",
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
        )

    def _process_audio_input(
        self, audio_input: Qwen2AudioInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn parsed audio input into one embedding tensor per audio.

        Embedding inputs are returned unchanged as a tuple. Feature inputs
        go through the audio tower and projector. Each audio's output is
        then trimmed to its valid token count and split per audio.
        """
        if audio_input["type"] == "audio_embeds":
            return tuple(audio_input["audio_embeds"])

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]
        # The schema also admits a list of per-audio tensors. They share the
        # fixed (nmb, 3000) shape, so they can be stacked into one batch.
        if isinstance(input_features, list):
            input_features = torch.stack(input_features)

        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
        )

        batch_size, _, max_mel_seq_len = input_features.shape
        # Padded length of the first-stage (feat) sequence; the attention
        # mask is built at this length.
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1

        # (batch_size, max_seq_len) boolean mask, True at padded positions.
        seq_range = torch.arange(
            0,
            max_seq_len,
            dtype=audio_feat_lengths.dtype,
            device=audio_feat_lengths.device,
        ).unsqueeze(0)
        padding_mask = seq_range >= audio_feat_lengths.unsqueeze(-1)

        # Additive (batch, 1, query, key) mask: 0 for valid keys and -inf for
        # padded keys, in the audio tower's dtype/device.
        tower_weight = self.audio_tower.conv1.weight
        key_padding = padding_mask.to(tower_weight.device).view(
            batch_size, 1, 1, max_seq_len
        )
        audio_attention_mask = torch.zeros(
            (batch_size, 1, max_seq_len, max_seq_len),
            dtype=tower_weight.dtype,
            device=tower_weight.device,
        ).masked_fill(key_padding, float("-inf"))

        audio_outputs = self.audio_tower(
            input_features, attention_mask=audio_attention_mask
        )
        selected_audio_feature = audio_outputs.last_hidden_state
        audio_features = self.multi_modal_projector(selected_audio_feature)
        num_audios, max_audio_tokens, embed_dim = audio_features.shape

        # Keep only the first audio_output_lengths[i] tokens of each audio.
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        token_range = torch.arange(
            max_audio_tokens, device=audio_output_lengths.device
        ).unsqueeze(0)
        audio_features_mask = token_range < audio_output_lengths
        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        return torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or ``[]`` when no audio is given."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language-model backbone.

        ``inputs_embeds`` is dropped when ``intermediate_tensors`` is
        provided. Extra ``kwargs`` are accepted and ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)