"examples/multimodal/configs/agg-llava.yaml" did not exist on "4b1867c53ebbf98dea54623af24d2424ead56573"
qwen2_audio.py 16.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
25

26
from collections.abc import Iterable, Mapping, Sequence
27
from typing import Annotated, Any, Literal, TypeAlias
28
29
30

import torch
import torch.nn as nn
31
from transformers import BatchFeature
32
33
34
35
36
from transformers.models.qwen2_audio import (
    Qwen2AudioConfig,
    Qwen2AudioEncoder,
    Qwen2AudioProcessor,
)
37
from transformers.models.whisper import WhisperFeatureExtractor
38

39
from vllm.config import VllmConfig
40
from vllm.config.multimodal import BaseDummyOptions
41
from vllm.multimodal import MULTIMODAL_REGISTRY
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from vllm.multimodal.inputs import (
    AudioItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
57
    BaseDummyInputsBuilder,
58
59
60
61
62
63
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
64
from vllm.sequence import IntermediateTensors
65
from vllm.utils.tensor_schema import TensorSchema, TensorShape
66

67
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
68
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
69
70
71


# === Audio Inputs === #
72
73
74
75
76
77
class Qwen2AudioFeatureInputs(TensorSchema):
    """
    Raw audio features that are run through the audio encoder.

    Dimensions:
        - na: Number of audios
        - nmb: Number of mel bins

    The trailing frame dimension is fixed at 3000.
    """

    # Defaulted for consistency with ``Qwen2AudioEmbeddingInputs``; callers
    # that pass ``type="audio_features"`` explicitly are unaffected.
    type: Literal["audio_features"] = "audio_features"

    # Either one batched tensor or a list of per-audio (nmb, 3000) tensors.
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("na", "nmb", 3000),
    ]

    # Per-frame mask over the 3000 frames. It is summed along the last axis
    # to obtain each clip's valid length (presumably 1 = real, 0 = padding).
    feature_attention_mask: Annotated[
        torch.Tensor,
        TensorShape("na", 3000),
    ]
89
90


91
class Qwen2AudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings, passed straight through without encoding.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features (may differ between items)
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    # One tensor per audio item. "naf" is declared dynamic, so items are not
    # required to share the same number of feature rows.
    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
    ]
106
107


108
# Validated audio input: raw features for the encoder, or ready-made embeddings.
Qwen2AudioInputs: TypeAlias = (
    Qwen2AudioFeatureInputs | Qwen2AudioEmbeddingInputs
)
109

110
111
112
113
114
115
116
117
118
119
120
121
122
# === Audio Encoder === #


class Qwen2AudioMultiModalProjector(nn.Module):
    """Single biased linear layer mapping audio-encoder width to text width."""

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        # Keep the attribute name ``linear``: weight loading presumably keys
        # on module attribute names, so renaming it would break checkpoints.
        self.linear = nn.Linear(audio_hidden_size, text_hidden_size, bias=True)

    def forward(self, audio_features):
        """Project ``audio_features`` along the last dimension."""
        return self.linear(audio_features)


123
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
124
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
125
126
127
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths
128
129


130
131
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
    """Accessors for the HF config, processor and feature extractor."""

    def get_hf_config(self):
        """Return the model's ``Qwen2AudioConfig``."""
        return self.ctx.get_hf_config(Qwen2AudioConfig)

    def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
        """Return the HF ``Qwen2AudioProcessor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor held by the HF processor."""
        processor = self.get_hf_processor(**kwargs)
        extractor = processor.feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_target_channels(self) -> int:
        """Return target audio channels for Qwen2 Audio models (mono)."""
        return 1

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Audio is the only modality; ``None`` presumably means no cap."""
        return {"audio": None}
149

150

151
class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
    """Builds dummy prompts and full-chunk audio clips for Qwen2-Audio."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one audio placeholder token per requested audio item."""
        num_audios = mm_counts.get("audio", 0)

        hf_processor = self.info.get_hf_processor()
        # Same fallback as the multimodal processor's prompt updates, so that
        # transformers<4.48 (whose processor lacks ``audio_token``) still works.
        audio_token = getattr(hf_processor, "audio_token", "<|AUDIO|>")

        return audio_token * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["audio"]`` dummy clips, one extractor chunk each.

        Each clip is ``chunk_length * sampling_rate`` samples long. Any
        ``mm_options["audio"]`` overrides are forwarded to the base helper.
        ``seq_len`` is not used here.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }

180

181
182
183
184
185
186
187
188
189
190
191
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Return the field config shared by raw-feature and embedding inputs.

    Every field is batched along the "audio" modality. ``hf_inputs`` is part
    of the factory signature but its contents are not inspected.
    """
    per_audio = MultiModalFieldConfig.batched
    return {
        "audio_embeds": per_audio("audio"),
        "input_features": per_audio("audio"),
        "feature_attention_mask": per_audio("audio"),
    }


class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
    """Audio data parser that also accepts a dict of precomputed embeddings."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Delegate raw audio to the base parser; wrap embedding dicts.

        A ``dict`` input must provide ``audio_embeds``. Its fields are laid
        out by ``_qwen2audio_field_config``.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_qwen2audio_field_config,
        )


205
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
    """Runs the HF processor and expands audio placeholders for Qwen2-Audio."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Build a parser targeting the feature extractor's rate, mono.

        The returned parser also accepts dicts of precomputed ``audio_embeds``.
        """
        feature_extractor = self.info.get_feature_extractor()
        return Qwen2AudioMultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.info.get_target_channels(),
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or tokenize only when no audio is given.

        The caller's ``mm_data`` and ``mm_kwargs`` mappings are not modified.
        """
        # Work on a copy: the parameter is typed as a read-only Mapping.
        mm_data = dict(mm_data)

        # NOTE - we rename audios -> audio in mm data because transformers has
        # deprecated audios for the qwen2audio processor and will remove
        # support for it in transformers 4.54.
        audios = mm_data.pop("audios", [])
        if audios:
            mm_data["audio"] = audios

        # Text-only input not supported in composite processor
        if not mm_data.get("audio", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # Force the extractor's sampling rate, which is the rate the data
        # parser was configured to target. A dict merge is used because
        # ``dict(**mm_kwargs, sampling_rate=...)`` raises TypeError when
        # ``mm_kwargs`` already contains ``sampling_rate``.
        mm_kwargs = {**mm_kwargs, "sampling_rate": feature_extractor.sampling_rate}

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Every audio field is batched per audio item."""
        return _qwen2audio_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio token with ``bos + N * audio + eos`` token ids.

        ``N`` is derived from ``feature_attention_mask`` when raw features were
        processed, otherwise from the row count of each ``audio_embeds`` item.

        Raises:
            ValueError: if an audio clip maps to zero embedding positions.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        # Use getattr with default to be compatible with transformers<4.48
        audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
        audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>")
        audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>")

        audio_token_id = vocab[audio_token]
        audio_bos_id = vocab[audio_bos_token]
        audio_eos_id = vocab[audio_eos_token]

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if feature_attention_mask is None:
            # Embedding path: lengths are read from ``audio_embeds`` below.
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )

            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            if audio_output_lengths:
                num_features = audio_output_lengths[item_idx]
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)

                raise ValueError(
                    f"The audio (len={audio_len}) is too short "
                    "to be represented inside the model"
                )

            audio_tokens = [audio_token_id] * num_features

            # Only the audio-token positions receive audio embeddings.
            return PromptUpdateDetails.select_token_id(
                [audio_bos_id] + audio_tokens + [audio_eos_id],
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            )
        ]
315
316


317
318
319
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2AudioMultiModalProcessor,
    info=Qwen2AudioProcessingInfo,
    dummy_inputs=Qwen2AudioDummyInputsBuilder,
)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Qwen2-Audio: ``Qwen2AudioEncoder`` + linear projector + Qwen2 LM.

    Audio features are encoded and projected to the text hidden size, then
    returned from ``embed_multimodal``. The text backbone is a
    ``Qwen2ForCausalLM`` built through ``init_vllm_registered_model``.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for the ``i``-th audio item.

        Raises:
            ValueError: for any modality not starting with ``"audio"``.
        """
        if modality.startswith("audio"):
            return f"Audio {i}: <|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.audio_tower = Qwen2AudioEncoder(config.audio_config)
        # Maps the encoder width (d_model) to the text model's hidden size.
        self.multi_modal_projector = Qwen2AudioMultiModalProjector(
            config.audio_config.d_model, config.text_config.hidden_size
        )

        self.quant_config = quant_config

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Qwen2AudioInputs | None:
        """Build a validated input schema from keyword inputs.

        ``audio_embeds`` takes precedence over ``input_features``. Returns
        ``None`` when neither is present.
        """
        input_features = kwargs.pop("input_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)

        if audio_embeds is not None:
            return Qwen2AudioEmbeddingInputs(
                type="audio_embeds", audio_embeds=audio_embeds
            )

        if input_features is not None:
            return Qwen2AudioFeatureInputs(
                type="audio_features",
                input_features=input_features,
                feature_attention_mask=feature_attention_mask,
            )

        return None

    def _process_audio_input(
        self, audio_input: Qwen2AudioInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn validated audio input into one embedding tensor per audio.

        Precomputed embeddings are returned unchanged, as a tuple. Raw
        features are encoded by ``audio_tower`` and projected to the LM
        hidden size. Each clip is then trimmed to its valid length and the
        result is split per audio.
        """
        if audio_input["type"] == "audio_embeds":
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        # The schema also admits a list of per-audio (nmb, 3000) tensors. The
        # encoder and the ``.shape`` unpacking below need one batched tensor.
        if isinstance(input_features, list):
            input_features = torch.stack(input_features)

        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
        )

        batch_size, _, max_mel_seq_len = input_features.shape
        # Padded encoder sequence length (3000 mel frames -> 1500 positions).
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=audio_feat_lengths.dtype,
                device=audio_feat_lengths.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
            batch_size, max_seq_len
        )
        # True at padded positions (beyond each clip's valid length).
        padding_mask = seq_range >= lengths_expand

        # Broadcast the key mask so every query row shares it; presumably laid
        # out as (batch, heads=1, query, key) for the encoder.
        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        # Additive float mask: 0.0 where attended, -inf at padded positions.
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device,
        )
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        audio_outputs = self.audio_tower(
            input_features, attention_mask=audio_attention_mask
        )
        selected_audio_feature = audio_outputs.last_hidden_state
        audio_features = self.multi_modal_projector(selected_audio_feature)
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        # Keep the first ``audio_output_lengths[i]`` tokens of audio ``i``.
        # The index is created directly on the target device, avoiding a
        # CPU allocation followed by a copy.
        audio_features_mask = (
            torch.arange(max_audio_tokens, device=audio_output_lengths.device).expand(
                num_audios, max_audio_tokens
            )
            < audio_output_lengths
        )
        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        return torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text backbone."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or ``[]`` when no audio is present."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the text decoder.

        ``inputs_embeds`` is discarded when ``intermediate_tensors`` is given
        (presumably a non-first pipeline-parallel stage).
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``; return loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)