"tests/multimodal/test_hasher.py" did not exist on "785d75a03b73a903ff86cd9aa23a3addcdbbd8ab"
ultravox.py 22.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
6
from functools import cached_property
7
8
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                    Tuple, TypedDict, Union)
9
10
11
12
13

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
14
from transformers import BatchFeature, ProcessorMixin
15
16
17
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

18
from vllm import envs
19
from vllm.attention import AttentionMetadata
20
from vllm.config import VllmConfig
21
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
22
from vllm.model_executor.layers.layernorm import RMSNorm
Joe Runde's avatar
Joe Runde committed
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
from vllm.model_executor.model_loader.loader import DefaultModelLoader
25
from vllm.model_executor.models.module_mapping import MultiModelKeys
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
29
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
30
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
31
from vllm.multimodal.processing import (BaseMultiModalProcessor,
32
33
                                        BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
34
from vllm.sequence import IntermediateTensors
35
36
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

37
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
38
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
39
                    init_vllm_registered_model, maybe_prefix,
40
                    merge_multimodal_embeddings,
41
                    merge_multimodal_embeddings_from_map)
42

43
44
# Reserved special token used as the audio placeholder in place of the HF
# processor's default '<|eot_id|>', which would be confused with the real
# end-of-turn token.
_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
# Token id passed to `merge_multimodal_embeddings` as the audio placeholder.
# NOTE(review): presumably the vocabulary id of _AUDIO_PLACEHOLDER_OVERRIDE --
# confirm against the tokenizer in use.
_AUDIO_PLACEHOLDER_TOKEN = 128002
45
46
47
48
49
# Number of placeholder tokens per second of audio; multiplied by the feature
# extractor's `chunk_length` to bound the tokens of a single audio item.
_AUDIO_TOKENS_PER_SECOND = 6.25


class UltravoxAudioFeatureInputs(TypedDict):
    """Audio input given as features, tagged with `"audio_features"`."""

    # Discriminator used to tell this variant apart from audio embeddings.
    type: Literal["audio_features"]

    # Shape: `(batch_size, num_audios, 80, M)`
    data: NestedTensors
52
53
54
55


class UltravoxAudioEmbeddingInputs(TypedDict):
    """Audio input given as embeddings, tagged with `"audio_embeds"`."""

    # Discriminator used to tell this variant apart from audio features.
    type: Literal["audio_embeds"]

    # Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`
    data: NestedTensors
58
59
60
61
62
63


# Either accepted form of audio input: features that still have to go through
# the audio tower and projector, or embeddings that are used as given.
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                            UltravoxAudioEmbeddingInputs]


64
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Processor lookups and multimodal limits for Ultravox audio inputs."""

    def get_hf_processor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> ProcessorMixin:
        """Return the HF processor with the audio placeholder overridden.

        Args:
            sampling_rate: Accepted for signature compatibility; it is not
                used when obtaining the processor.
        """
        processor = self.ctx.get_hf_processor()

        # NOTE: The Ultravox processing definition uses '<|eot_id|>' as the
        # audio placeholder, which is easily confused with the real end of
        # turn token, so a reserved special token is substituted instead.
        processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        return processor

    def get_feature_extractor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor nested in the HF processor.

        Args:
            sampling_rate: Forwarded to `get_hf_processor`.
        """
        processor = self.get_hf_processor(sampling_rate=sampling_rate)
        extractor = (
            processor.audio_processor.feature_extractor)  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-prompt item limit by modality (`None` for audio)."""
        return dict(audio=None)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of tokens one audio item can occupy.

        The bound is the feature extractor's `chunk_length` multiplied by
        `_AUDIO_TOKENS_PER_SECOND`, rounded up. `seq_len` and `mm_counts` are
        not used in the computation.
        """
        chunk_length = self.get_feature_extractor().chunk_length
        return {
            "audio": math.ceil(chunk_length * _AUDIO_TOKENS_PER_SECOND),
        }
106

107
108
109
110

class UltravoxDummyInputsBuilder(
        BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy audio inputs sized to one feature-extractor chunk."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Create a dummy prompt with one `<|audio|>` tag per dummy audio.

        Args:
            seq_len: Not used by this builder.
            mm_counts: Item count per modality; only `"audio"` is read and it
                defaults to 0 when absent.
        """
        extractor = self.info.get_feature_extractor()
        audio_count = mm_counts.get("audio", 0)

        # Each dummy audio has chunk_length * sampling_rate samples.
        samples_per_audio = extractor.chunk_length * extractor.sampling_rate
        dummy_audios = self._get_dummy_audios(length=samples_per_audio,
                                              num_audios=audio_count)

        return ProcessorInputs(
            prompt_text="<|audio|>" * audio_count,
            mm_data={"audio": dummy_audios},
        )
131

132

133
134
class UltravoxMultiModalProcessor(
        BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that feeds audios to the HF processor one by one."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a data parser targeting the extractor's sampling rate."""
        target_sr = self.info.get_feature_extractor().sampling_rate
        return MultiModalDataParser(target_sr=target_sr)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor once per audio and merge the outputs.

        Without audio, the prompt is only tokenized. Otherwise each audio is
        processed together with the prompt; the per-item `audio_values` and
        `audio_token_len` are gathered into lists, and the remaining outputs
        of the last processed item are reused as the shared outputs.
        """
        # Text-only input not supported in composite processor
        has_audio = bool(mm_data) and bool(mm_data.get("audios", []))
        if not has_audio:
            token_ids = self._apply_hf_processor_tokens_only(
                self.info.get_tokenizer().encode(prompt))
            return BatchFeature({"input_ids": [token_ids]},
                                tensor_type="pt")

        remaining_data = dict(mm_data)
        audios = remaining_data.pop("audios", [])
        assert isinstance(audios, list)

        sampling_rate = self.info.get_feature_extractor().sampling_rate
        item_kwargs = dict(**mm_kwargs, sampling_rate=sampling_rate)

        # Ultravox processor doesn't support multiple inputs,
        # therefore text and audio are submitted one audio at a time.
        features_per_audio = []
        token_len_per_audio = []
        shared_outputs = {}
        for audio in audios:
            item_outputs = super()._call_hf_processor(
                prompt=prompt,
                # NOTE: Ultravox processor accepts "audio" instead of "audios"
                mm_data=dict(**remaining_data, audio=audio),
                mm_kwargs=item_kwargs,
            )

            features_per_audio.append(item_outputs.pop("audio_values")[0])
            token_len_per_audio.append(
                item_outputs.pop("audio_token_len").item())
            shared_outputs = item_outputs

        return BatchFeature(
            dict(
                **shared_outputs,
                audio_features=features_per_audio,
                audio_token_len=token_len_per_audio,
            ))

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Drop the leading BOS token from an already tokenized prompt.

        The first token is asserted to be the tokenizer's `bos_token_id`.
        """
        # HF processor omits bos_token_id by setting add_special_tokens=False
        bos_token_id = self.info.get_tokenizer().bos_token_id
        assert prompt_tokens[0] == bos_token_id

        return prompt_tokens[1:]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every audio-related field as batched over `"audio"`."""
        field_names = ("audio_features", "audio_token_len", "audio_embeds")
        return {
            name: MultiModalFieldConfig.batched("audio")
            for name in field_names
        }

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace each `<|audio|>` tag with repeated placeholder token ids.

        The repeat count for item `i` is `out_mm_kwargs["audio_token_len"][i]`.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        placeholder_id = vocab[
            hf_processor.audio_token_replacement]  # type: ignore

        def expand_placeholder(item_idx: int):
            repeat = int(out_mm_kwargs["audio_token_len"][item_idx])
            return [placeholder_id] * repeat  # type: ignore

        audio_replacement = PromptReplacement(
            modality="audio",
            target="<|audio|>",
            replacement=expand_placeholder,
        )
        return [audio_replacement]
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    The time dimension is zero-padded at the end up to the next multiple of
    `stack_factor`; every `stack_factor` consecutive frames are then
    concatenated along the channel dimension.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Stack frames of a `(B, T, C)` tensor.

        Args:
            audio_embeds: Tensor that unpacks into three dimensions
                `(B, T, C)`.

        Returns:
            Tensor of shape `(B, ceil(T / stack_factor), C * stack_factor)`.
        """
        B, T, C = audio_embeds.shape
        # Round T up to the next multiple of stack_factor.
        T_pad = (T + self.stack_factor -
                 1) // self.stack_factor * self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        # Use reshape rather than view: when no padding is needed the padded
        # tensor can keep the (possibly non-contiguous) strides of the input,
        # in which case view() raises while reshape() copies as required.
        return audio_embeds.reshape(B, T_pad // self.stack_factor,
                                    C * self.stack_factor)


class UltravoxProjector(nn.Module):
    """Projects stacked audio-tower features to the text hidden size.

    Pipeline: frame stacking -> RMSNorm -> linear -> activation ->
    (optional mid RMSNorm) -> linear -> (optional post RMSNorm).
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size

        use_swiglu = config.projector_act == "swiglu"
        dim_in = config.audio_config.hidden_size * config.stack_factor
        # The swiglu branch feeds half of hidden_dim into the second linear.
        dim_mid = self.hidden_dim // 2 if use_swiglu else self.hidden_dim
        dim_out = config.text_config.hidden_size

        # Sub-modules are created in the same order as they are applied.
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        self.ln_pre = RMSNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
        self.act = (MulAndSilu()
                    if use_swiglu else get_act_fn(config.projector_act))
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear
        # layer while v0.5.0 and above uses layer_norm after the first linear
        # layer. Exactly one of ln_mid / ln_post is a real norm.
        norm_in_middle = config.projector_ln_mid
        self.ln_mid: nn.Module = (RMSNorm(dim_mid)
                                  if norm_in_middle else nn.Identity())
        self.ln_post: nn.Module = (nn.Identity()
                                   if norm_in_middle else RMSNorm(dim_out))

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Apply the projector pipeline to `audio_features`."""
        stacked = self.ln_pre(self._pad_and_stack(audio_features))
        hidden = self.act(self.linear_1(stacked))
        hidden = self.ln_mid(hidden)
        return self.ln_post(self.linear_2(hidden))


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    A lightly modified copy of the HF Transformers Whisper encoder:
    1. `base_model_prefix` is changed so that `.from_pretrained` can be
       called directly on the encoder.
    2. Inputs shorter than 30 seconds of padded audio are accepted:
        - the `input_features` length check only rejects inputs that are
          longer than `expected_seq_length` instead of requiring equality
        - the positional embeddings are sliced to the length of
          `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def forward(
        self,
        input_features,
    ):
        """Encode `input_features` and return the final hidden states.

        Raises:
            ValueError: If the last dimension of `input_features` exceeds
                `max_source_positions * conv1.stride[0] * conv2.stride[0]`.
        """
        max_length = (self.config.max_source_positions *
                      self.conv1.stride[0] * self.conv2.stride[0])
        actual_length = input_features.shape[-1]
        if actual_length > max_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{max_length} or less, but found "
                f"{actual_length}. Make sure to pad the input mel "
                f"features to {max_length}.")

        conv_out = F.gelu(self.conv1(input_features))
        conv_out = F.gelu(self.conv2(conv_out))
        conv_out = conv_out.permute(0, 2, 1)

        # Only as many positional embeddings as there are frames.
        positions = self.embed_positions.weight[:conv_out.size(-2)]
        hidden_states = F.dropout(conv_out + positions,
                                  p=self.dropout,
                                  training=self.training)

        for layer in self.layers:
            # Positional None is the second argument of the encoder layer,
            # passed exactly as in the original call.
            hidden_states = layer(
                hidden_states,
                None,
                layer_head_mask=None,
            )[0]

        return self.layer_norm(hidden_states)


351
352
353
354
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: Whisper-based audio tower, projector and a language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # LoRA specific attributes
    # TODO : Add LoRA to the audio tower and projector.
    supported_lora_modules = [
        "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    # Checkpoint names under "audio_tower.model.encoder." are loaded into
    # "audio_tower.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config

        self.config = config
        self.multi_modal_config = model_config.multimodal_config
        assert self.multi_modal_config

        # Extra weight sources, recorded when the config names separate
        # audio / text model ids.
        self.secondary_weights = []

        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            audio_source = DefaultModelLoader.Source(
                model_or_path=config.audio_model_id,
                revision=None,
                prefix="audio_tower.",
            )
            self.secondary_weights.append(audio_source)

        self.multi_modal_projector = UltravoxProjector(config)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            text_source = DefaultModelLoader.Source(
                model_or_path=config.text_model_id,
                revision=None,
                prefix="language_model.",
            )
            self.secondary_weights.append(text_source)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else a new sampler."""
        language_model = self.language_model
        try:
            return language_model.sampler
        except AttributeError:
            return get_sampler()

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor) -> torch.Tensor:
        """Run features through the audio tower and then the projector.

        Both the tower input and the tower output are cast to the audio
        tower's dtype.
        """
        tower_dtype = self.audio_tower.dtype
        encoded = self.audio_tower(input_features.to(tower_dtype))
        return self.multi_modal_projector(encoded.to(tower_dtype))

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        """Extract the audio input from `kwargs`, preferring features.

        Returns:
            `None` when neither `audio_features` nor `audio_embeds` is given.

        Raises:
            ValueError: If the selected input is neither a tensor nor a list.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")
            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")
            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        # Neither kind of audio input was supplied.
        return None

    def _process_audio_input(
            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
        """Turn a parsed audio input into audio embeddings.

        Embedding inputs are returned unchanged; feature inputs go through
        the audio tower and the projector.
        """
        data = audio_input["data"]
        if audio_input["type"] == "audio_embeds":
            return data

        if isinstance(data, torch.Tensor):
            # Combine the B and N dimensions for the encoder/projector,
            # then restore the original leading dimensions.
            flat_embeddings = self._audio_features_to_embeddings(
                flatten_bn(data))
            return flat_embeddings.unflatten(0, data.shape[:2])

        def embed_group(group):
            if isinstance(group, torch.Tensor):
                return self._audio_features_to_embeddings(group)
            # Add a batch dimension to embed each tensor, then remove it.
            return [
                self._audio_features_to_embeddings(
                    clip.unsqueeze(0)).squeeze(0) for clip in group
            ]

        # TODO: Batch heterogeneous tensors through the encoder/projector
        return [embed_group(group) for group in data]

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return audio embeddings for `kwargs`, or `None` without audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        return (None if audio_input is None else
                self._process_audio_input(audio_input))

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in the audio embeddings, if any."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        if envs.VLLM_USE_V1:
            return merge_multimodal_embeddings(input_ids, inputs_embeds,
                                               multimodal_embeddings,
                                               _AUDIO_PLACEHOLDER_TOKEN)

        # TODO(ywang96): remove this block after v0 is deprecated.
        # The return value is discarded, so this call is expected to update
        # `inputs_embeds` in place.
        merge_multimodal_embeddings_from_map(
            inputs_embeds, multimodal_embeddings,
            attn_metadata.multi_modal_placeholder_index_maps["audio"])
        return inputs_embeds

    # NOTE(review): `intermediate_tensors` is annotated Optional[torch.Tensor]
    # but is forwarded as the language model's intermediate tensors;
    # IntermediateTensors may be the intended type -- confirm.
    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        `input_ids` already accounts for the positions of the audio
        embeddings that are going to be inserted; the inserted audio takes
        essentially 6.25 tokens per second of audio. As a result,
        `positions` and `attn_metadata` are consistent with `input_ids`.

        Args:
            audio_features: A batch of audio inputs [B, N, 80, M].
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this branch exists for v0 compatibility.
            # TODO(ywang96): remove attn_metadata from get_input_embeddings
            # after v0 is deprecated
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                self.get_multimodal_embeddings(**kwargs),
                attn_metadata,
            )
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         kv_caches,
                                         attn_metadata,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states,
                                             sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights via `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Unexpected weights under the "audio_tower." prefix are ignored.
        """
        return AutoWeightsLoader(
            self,
            ignore_unexpected_prefixes=["audio_tower."],
        ).load_weights(weights, mapper=self.hf_to_vllm_mapper)