# SPDX-License-Identifier: Apache-2.0

# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
24
from collections.abc import Iterable, Mapping, Sequence
25
from functools import cached_property
26
from typing import Any, Optional, Set, Tuple, TypedDict, Union
27
28
29

import torch
import torch.nn as nn
30
from transformers import BatchFeature
31
32
33
34
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
                                             Qwen2AudioEncoder,
                                             Qwen2AudioProcessor)
from transformers.models.whisper import WhisperFeatureExtractor
35

36
from vllm.config import VllmConfig
Joe Runde's avatar
Joe Runde committed
37
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
38
from vllm.model_executor.sampling_metadata import SamplingMetadata
39
from vllm.multimodal import MULTIMODAL_REGISTRY
40
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
41
42
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                   MultiModalDataParser)
43
from vllm.multimodal.processing import (BaseMultiModalProcessor,
44
                                        BaseProcessingInfo, PromptReplacement,
45
                                        PromptUpdate, PromptUpdateDetails)
46
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
47
from vllm.sequence import IntermediateTensors
48

49
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
50
51
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)
52
53
54
55
56


# === Audio Inputs === #
class Qwen2AudioInputs(TypedDict):
    # Audio features produced by the feature extractor, one entry per clip.
    # Shape: `(num_audios, num_mel_bins, 3000)`
    input_features: torch.Tensor

    # Per-frame mask over `input_features`; its sum over the last axis is
    # used as each clip's valid frame count.
    # Shape: `(num_audios, 3000)`
    feature_attention_mask: torch.Tensor


# === Audio Encoder === #


class Qwen2AudioMultiModalProjector(nn.Module):
    """Project audio-encoder hidden states into the text model's width.

    A single affine layer (with bias) mapping ``audio_hidden_size`` to
    ``text_hidden_size`` on the last dimension.
    """

    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
        super().__init__()
        # Keep the attribute name `linear`: weights are loaded by parameter
        # name (AutoWeightsLoader), so renaming it would likely break loading.
        self.linear = nn.Linear(in_features=audio_hidden_size,
                                out_features=text_hidden_size,
                                bias=True)

    def forward(self, audio_features):
        return self.linear(audio_features)


77
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
78
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
79
80
81
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths
82
83


class Qwen2AudioProcessingInfo(BaseProcessingInfo):
    """Accessors for Qwen2-Audio's HF config, processor and feature extractor."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen2AudioConfig)

    def get_hf_processor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
        **kwargs: object,
    ) -> Qwen2AudioProcessor:
        """Return the HF ``Qwen2AudioProcessor`` for this model.

        ``sampling_rate`` is accepted so callers can pass their processor
        kwargs through unchanged, but it is deliberately not forwarded; all
        other keyword arguments go to ``self.ctx.get_hf_processor``.
        """
        return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

    def get_feature_extractor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
        **kwargs: object,
    ) -> WhisperFeatureExtractor:
        """Return the ``WhisperFeatureExtractor`` owned by the HF processor.

        Accepts the same keyword arguments as :meth:`get_hf_processor` and
        forwards them, so processor kwargs other than ``sampling_rate`` no
        longer raise ``TypeError``.
        """
        hf_processor = self.get_hf_processor(sampling_rate=sampling_rate,
                                             **kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only audio is supported; `None` presumably means no fixed
        # per-prompt item limit.
        return {"audio": None}


class Qwen2AudioDummyInputsBuilder(
        BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
    """Builds placeholder audio prompts for profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt with one audio placeholder per requested clip.

        Each dummy clip is ``chunk_length * sampling_rate`` samples long,
        i.e. one full feature-extractor chunk. ``seq_len`` is not used.
        """
        feature_extractor = self.info.get_feature_extractor()
        hf_processor = self.info.get_hf_processor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        # Resolve the placeholder exactly as the multimodal processor's
        # _get_prompt_updates does, so the dummy prompt always contains the
        # replacement target (getattr default keeps transformers<4.48 working).
        audio_token = getattr(hf_processor, "audio_token", "<|AUDIO|>")

        mm_data = {
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }

        return ProcessorInputs(
            prompt_text=audio_token * num_audios,
            mm_data=mm_data,
        )


class Qwen2AudioMultiModalProcessor(
        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
    """Runs the HF Qwen2-Audio processor and expands audio placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # Parse incoming audio with the feature extractor's sampling rate as
        # the target rate (`target_sr`).
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, Any],
    ) -> BatchFeature:
        """Run the HF processor on ``prompt`` and the audio in ``mm_data``.

        Text-only prompts are tokenized directly because the composite HF
        processor does not support them. When audio is present,
        ``sampling_rate`` is always set to the feature extractor's rate,
        overriding any caller-supplied value.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Same no-argument lookup as _get_data_parser, so the declared rate
        # matches the rate used when parsing the audio.
        feature_extractor = self.info.get_feature_extractor()
        # Dict merge instead of dict(**mm_kwargs, sampling_rate=...): the
        # latter raises TypeError ("got multiple values for keyword
        # argument") when the caller already passed sampling_rate.
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
        }

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both tensors carry one entry per audio item.
        return dict(
            input_features=MultiModalFieldConfig.batched("audio"),
            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio placeholder with bos + N audio tokens + eos.

        ``N`` is the encoder output length of the matching clip, computed
        from its ``feature_attention_mask``. Only the audio-token positions
        are selected to receive audio embeddings.

        The returned replacement callback raises ``ValueError`` when a clip
        is too short to produce any encoder output.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        # Use getattr with default to be compatible with transformers<4.48
        audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
        audio_bos_token = getattr(processor, "audio_bos_token",
                                  "<|audio_bos|>")
        audio_eos_token = getattr(processor, "audio_eos_token",
                                  "<|audio_eos|>")

        audio_token_id = vocab[audio_token]
        audio_bos_id = vocab[audio_bos_token]
        audio_eos_id = vocab[audio_eos_token]

        feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
        if feature_attention_mask is None:
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1))

            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            num_features = audio_output_lengths[item_idx]
            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)

                raise ValueError(f"The audio (len={audio_len}) is too short "
                                 "to be represented inside the model")

            audio_tokens = [audio_token_id] * num_features

            return PromptUpdateDetails.select_token_id(
                [audio_bos_id] + audio_tokens + [audio_eos_id],
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2AudioMultiModalProcessor,
    info=Qwen2AudioProcessingInfo,
    dummy_inputs=Qwen2AudioDummyInputsBuilder)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
    """Qwen2-Audio for vLLM.

    Composed of the HF ``Qwen2AudioEncoder`` audio tower, a linear projector
    into the text hidden size, and a vLLM-registered ``Qwen2ForCausalLM``
    language model.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        # Audio tower and projector are plain HF / torch modules; only the
        # language model is built through vLLM's model registry.
        self.audio_tower = Qwen2AudioEncoder(config.audio_config)
        self.multi_modal_projector = Qwen2AudioMultiModalProjector(
            config.audio_config.d_model, config.text_config.hidden_size)

        self.quant_config = quant_config

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                        name: str) -> torch.Tensor:
        """Concatenate a batched multimodal input into a single tensor.

        A tensor input has its first two dimensions merged; a list input
        has its elements concatenated along dim 0.

        Raises:
            ValueError: if ``mm_input`` is neither a tensor nor a list.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
        """Extract audio tensors from ``kwargs``; ``None`` if none given.

        Raises:
            ValueError: if either tensor has an unsupported type (including
                a missing ``feature_attention_mask`` when features exist).
        """
        input_features = kwargs.pop('input_features', None)
        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
        if input_features is None:
            return None
        # _validate_and_reshape_mm_tensor always returns a tensor or raises,
        # so no further type check is needed afterwards.
        input_features = self._validate_and_reshape_mm_tensor(
            input_features, 'input_features')
        feature_attention_mask = self._validate_and_reshape_mm_tensor(
            feature_attention_mask, 'feature_attention_mask')
        return Qwen2AudioInputs(input_features=input_features,
                                feature_attention_mask=feature_attention_mask)

    def _process_audio_input(
            self, audio_input: Qwen2AudioInputs) -> Tuple[torch.Tensor, ...]:
        """Encode audio clips and return one embedding tensor per clip.

        Each returned tensor has shape ``(num_tokens_i, embed_dim)``, where
        ``num_tokens_i`` is the clip's encoder output length derived from
        its ``feature_attention_mask``; padded positions are dropped.
        """
        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        audio_feat_lengths, audio_output_lengths = (
            self.audio_tower._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)))

        batch_size, _, max_mel_seq_len = input_features.shape
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (torch.arange(
            0,
            max_seq_len,
            dtype=audio_feat_lengths.dtype,
            device=audio_feat_lengths.device).unsqueeze(0).expand(
                batch_size, max_seq_len))
        lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
            batch_size, max_seq_len)
        # True at positions at or beyond each clip's valid length.
        padding_mask = seq_range >= lengths_expand

        # Expand to (batch, 1, max_seq_len, max_seq_len): the padding pattern
        # varies along the last dim and is repeated along dim 2. Converted to
        # an additive mask: 0.0 where valid, -inf where padded.
        audio_attention_mask_ = padding_mask.view(
            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
                                                  max_seq_len)
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device)
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        audio_outputs = self.audio_tower(input_features,
                                         attention_mask=audio_attention_mask)
        selected_audio_feature = audio_outputs.last_hidden_state
        audio_features = self.multi_modal_projector(selected_audio_feature)
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        # Build the position index directly on the target device rather than
        # allocating it on the CPU and copying it over on every call.
        audio_features_mask = torch.arange(
            max_audio_tokens, device=audio_output_lengths.device).expand(
                num_audios, max_audio_tokens) < audio_output_lengths
        masked_audio_features = audio_features[audio_features_mask].view(
            -1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        return torch.split(masked_audio_features,
                           audio_output_lengths.flatten().tolist())

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return per-clip audio embeddings, or ``None`` without audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return None
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice audio embeddings into the
        positions holding ``config.audio_token_index``."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.audio_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model.

        When ``intermediate_tensors`` is given (presumably a non-first
        pipeline stage), no embeddings are computed here.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)