ultravox.py 21.2 KB
Newer Older
1
2
3
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
4
from functools import cached_property
5
6
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
                    Tuple, TypedDict, Union)
7
8
9
10
11

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
12
from transformers import BatchFeature, ProcessorMixin
13
14
15
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

16
from vllm import envs
17
from vllm.attention import AttentionMetadata
18
from vllm.config import VllmConfig
19
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
20
from vllm.model_executor.layers.layernorm import RMSNorm
Joe Runde's avatar
Joe Runde committed
21
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
22
from vllm.model_executor.model_loader.loader import DefaultModelLoader
23
from vllm.model_executor.sampling_metadata import SamplingMetadata
24
from vllm.multimodal import MULTIMODAL_REGISTRY
25
26
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
27
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
28
from vllm.multimodal.processing import (BaseMultiModalProcessor,
29
30
                                        BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
31
from vllm.sequence import IntermediateTensors
32
33
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

34
from .interfaces import SupportsMultiModal, SupportsPP
35
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
36
                    init_vllm_registered_model, maybe_prefix,
37
                    merge_multimodal_embeddings,
38
                    merge_multimodal_embeddings_from_map)
39

40
41
_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
_AUDIO_PLACEHOLDER_TOKEN = 128002
42
43
44
45
46
_AUDIO_TOKENS_PER_SECOND = 6.25


class UltravoxAudioFeatureInputs(TypedDict):
    """Raw audio features that still have to pass through the audio tower.

    ``data`` is either a tensor of shape ``(batch_size, num_audios, 80, M)``
    or a (nested) list of per-item tensors.
    """

    type: Literal["audio_features"]
    # Shape: (batch_size, num_audios, 80, M)
    data: NestedTensors


class UltravoxAudioEmbeddingInputs(TypedDict):
    """Precomputed audio embeddings, passed through without encoding.

    ``data`` has shape
    ``(batch_size, num_audios, audio_feature_size, hidden_size)`` or is a
    (nested) list of per-item tensors.
    """

    type: Literal["audio_embeds"]
    # Shape: (batch_size, num_audios, audio_feature_size, hidden_size)
    data: NestedTensors


# Audio arrives either as raw features (encoded on the fly) or as
# ready-made embeddings; the ``type`` key tells the two apart.
UltravoxAudioInputs = Union[
    UltravoxAudioFeatureInputs,
    UltravoxAudioEmbeddingInputs,
]


61
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Exposes the Ultravox HF processor, its Whisper feature extractor and
    the per-item audio token budget."""

    def get_hf_processor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> ProcessorMixin:
        """Return the HF processor with a non-ambiguous audio placeholder.

        ``sampling_rate`` is accepted for interface compatibility only and is
        not forwarded to the context.
        """
        processor = self.ctx.get_hf_processor()
        # The Ultravox processing definition uses '<|eot_id|>' as the audio
        # placeholder, which is confusable with a genuine end-of-turn token;
        # replace it with a reserved special token.
        processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        return processor

    def get_feature_extractor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor nested inside the HF
        processor's ``audio_processor``."""
        processor = self.get_hf_processor(sampling_rate=sampling_rate)
        extractor = processor.audio_processor.feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Only audio is supported; its item limit is left as ``None``."""
        return {"audio": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Upper bound on audio tokens produced for a single audio item.

        Derived from the extractor's ``chunk_length`` and the fixed audio
        token rate; ``seq_len`` is not used.
        """
        chunk_length = self.get_feature_extractor().chunk_length
        audio_budget = math.ceil(chunk_length * _AUDIO_TOKENS_PER_SECOND)
        return {"audio": audio_budget}

class UltravoxDummyInputsBuilder(
        BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy prompt + audio inputs for the profiling machinery
    (``vllm.multimodal.profiling``)."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return one ``<|audio|>`` marker per requested audio, each paired
        with a dummy clip of ``chunk_length * sampling_rate`` samples."""
        extractor = self.info.get_feature_extractor()
        n_audios = mm_counts.get("audio", 0)
        clip_samples = extractor.chunk_length * extractor.sampling_rate

        dummy_audios = self._get_dummy_audios(length=clip_samples,
                                              num_audios=n_audios)
        return ProcessorInputs(
            prompt_text="<|audio|>" * n_audios,
            mm_data={"audio": dummy_audios},
        )
class UltravoxMultiModalProcessor(
        BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multi-modal processor for Ultravox.

    Feeds audios to the HF Ultravox processor one at a time (it does not
    accept several at once) and expands each ``<|audio|>`` marker into the
    per-item number of placeholder tokens.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        # Target the sampling rate expected by the Whisper feature extractor.
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor on ``prompt`` and its audios.

        Text-only prompts are tokenized directly. Otherwise the HF processor
        is invoked once per audio; per-audio ``audio_values`` and
        ``audio_token_len`` are collected into lists (renamed to
        ``audio_features`` / ``audio_token_len``), and the remaining outputs
        are taken from the last invocation.
        """
        # Text-only input not supported in composite processor
        if not mm_data or not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor()
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
        )

        # Ultravox processor doesn't support multiple inputs,
        # therefore we need to input text and audio one by one
        audio_features, audio_token_len = [], []
        shared_outputs = {}
        for audio in audios:
            # NOTE: Ultravox processor accepts "audio" instead of "audios"
            item_processor_data = dict(**mm_data, audio=audio)

            item_outputs = super()._call_hf_processor(
                prompt=prompt,
                mm_data=item_processor_data,
                mm_kwargs=mm_kwargs,
            )

            audio_features.append(item_outputs.pop("audio_values")[0])
            audio_token_len.append(item_outputs.pop("audio_token_len").item())
            shared_outputs = item_outputs

        combined_outputs = dict(
            **shared_outputs,
            audio_features=audio_features,
            audio_token_len=audio_token_len,
        )
        return BatchFeature(combined_outputs)

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Make already-tokenized prompts match the HF processor's output.

        The HF processor omits ``bos_token_id`` (``add_special_tokens=False``),
        so a leading BOS is dropped here. Only a BOS that is actually present
        is removed: empty token lists and lists without a leading BOS are
        returned unchanged (as a new list), instead of failing on an
        ``assert``/``IndexError`` or — when asserts are stripped by
        ``python -O`` — silently discarding a real prompt token.
        """
        bos_token_id = self.info.get_tokenizer().bos_token_id
        has_bos = bool(prompt_tokens) and prompt_tokens[0] == bos_token_id
        return prompt_tokens[1:] if has_bos else prompt_tokens[:]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Every audio-related field carries one entry per audio item.
        return dict(
            audio_features=MultiModalFieldConfig.batched("audio"),
            audio_token_len=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace each ``<|audio|>`` with ``audio_token_len`` copies of the
        processor's audio placeholder for the corresponding item."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder = hf_processor.audio_token_replacement  # type: ignore

        def get_replacement_ultravox(item_idx: int):
            audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
            return placeholder * audio_token_len

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    An input of shape ``(B, T, C)`` is zero-padded along ``T`` up to the next
    multiple of ``stack_factor`` and returned with shape
    ``(B, ceil(T / stack_factor), C * stack_factor)``: each output frame is
    the concatenation of ``stack_factor`` consecutive input frames.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        batch, frames, channels = audio_embeds.shape
        k = self.stack_factor
        padded_frames = (frames + k - 1) // k * k
        # (0, 0) leaves C untouched; (0, padded - frames) appends zero frames
        # at the end of the time axis.
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, padded_frames - frames))
        # `reshape` instead of `view`: when no padding is needed F.pad may
        # return a copy that keeps a non-contiguous input's strides, which
        # `view` cannot merge. On contiguous tensors reshape is still a view.
        return audio_embeds.reshape(batch, padded_frames // k, channels * k)


class UltravoxProjector(nn.Module):
    """Project audio-encoder frames into the text model's embedding space.

    Pipeline: stack frames -> RMSNorm -> linear_1 -> activation -> linear_2
    -> RMSNorm.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)

        stacked_dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(stacked_dim)
        self.linear_1 = nn.Linear(stacked_dim, self.hidden_dim, bias=False)

        # With SwiGLU the activation output is half as wide, presumably
        # because MulAndSilu gates one half of its input with the other.
        if config.projector_act == "swiglu":
            self.act = MulAndSilu()
            act_out_dim = self.hidden_dim // 2
        else:
            self.act = get_act_fn(config.projector_act)
            act_out_dim = self.hidden_dim

        text_dim = config.text_config.hidden_size
        self.linear_2 = nn.Linear(act_out_dim, text_dim, bias=False)
        self.ln_post = RMSNorm(text_dim)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        x = self.ln_pre(self._pad_and_stack(audio_features))
        x = self.act(self.linear_1(x))
        return self.ln_post(self.linear_2(x))


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def forward(
        self,
        input_features,
    ):
        """Encode mel ``input_features`` into final hidden states.

        Raises:
            ValueError: if the last dimension of ``input_features`` exceeds
                ``max_source_positions`` times both conv strides.
        """
        expected_seq_length = (self.config.max_source_positions *
                               self.conv1.stride[0] * self.conv2.stride[0])
        too_long = input_features.shape[-1] > expected_seq_length
        if too_long:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}.")

        hidden = nn.functional.gelu(self.conv1(input_features))
        hidden = nn.functional.gelu(self.conv2(hidden))
        # Conv output is presumably (batch, channels, frames); move frames to
        # dim 1 so it lines up with the positional-embedding rows.
        hidden = hidden.permute(0, 2, 1)

        # Only as many positional embeddings as there are frames, so inputs
        # shorter than the full window are accepted.
        positions = self.embed_positions.weight[:hidden.size(-2)]
        hidden = nn.functional.dropout(hidden + positions,
                                       p=self.dropout,
                                       training=self.training)

        for layer in self.layers:
            hidden = layer(hidden, None, layer_head_mask=None)[0]

        return self.layer_norm(hidden)


332
333
334
335
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor,
                                        info=UltravoxProcessingInfo,
                                        dummy_inputs=UltravoxDummyInputsBuilder
                                        )
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
    """Ultravox: Whisper-style audio encoder + projector + text LLM.

    Audio features are encoded by ``audio_tower``, mapped into the language
    model's embedding space by ``multi_modal_projector`` and merged into the
    text embeddings at the audio placeholder positions before the language
    model runs.
    """

    # Checkpoints keep the encoder under ``audio_tower.model.encoder.``;
    # in this module it lives directly at ``audio_tower.``.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra weight sources for separately published audio / text
        # checkpoints (presumably consumed by the model loader).
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                ))

        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(model_or_path=config.text_model_id,
                                          revision=None,
                                          prefix="language_model."))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """The language model's sampler if it has one, else a default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor) -> torch.Tensor:
        """Run features through the audio tower and projector.

        Inputs are cast to the audio tower's dtype before encoding, and the
        encoder output is cast to it again before projection.
        """
        audio_input = input_features.to(self.audio_tower.dtype)
        audio_features = self.audio_tower(audio_input)
        audio_features = audio_features.to(self.audio_tower.dtype)
        audio_embeddings = self.multi_modal_projector(audio_features)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        """Extract audio inputs from ``kwargs``.

        ``audio_features`` takes precedence over ``audio_embeds`` when both
        are given. Returns ``None`` when neither is present.

        Raises:
            ValueError: if the chosen input is neither a tensor nor a list.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")

            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
        """Turn parsed audio input into embeddings.

        Precomputed embeddings are returned as-is. A 4-D feature tensor is
        encoded in one pass with the batch and per-batch audio dims merged;
        list inputs are encoded item by item, preserving their nesting.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]
        if isinstance(audio_features, torch.Tensor):
            # Combine the B and N dimensions for the encoder/projector
            flattened = flatten_bn(audio_features)
            flattened_embeddings = self._audio_features_to_embeddings(
                flattened)

            # Restore the original dimensions
            embeddings = flattened_embeddings.unflatten(
                0, audio_features.shape[:2])
            return embeddings

        result = []
        # TODO: Batch heterogeneous tensors through the encoder/projector
        for audio_features_item in audio_features:
            if isinstance(audio_features_item, torch.Tensor):
                result.append(
                    self._audio_features_to_embeddings(audio_features_item))
            else:
                embeddings = [
                    # Add a batch dimension to embed it, then remove it.
                    self._audio_features_to_embeddings(tensor.unsqueeze(0)
                                                       ).squeeze(0)
                    for tensor in audio_features_item
                ]
                result.append(embeddings)

        return result

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Audio embeddings for the request, or ``None`` if it has no audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return None
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Text embeddings for ``input_ids`` with audio embeddings merged in.

        On V1 the audio embeddings replace positions holding
        ``_AUDIO_PLACEHOLDER_TOKEN``; on V0 they are written at the positions
        given by ``attn_metadata``'s audio placeholder index map.

        Raises:
            ValueError: on V0 when ``multimodal_embeddings`` is given but
                ``attn_metadata`` is ``None``.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        # TODO(ywang96): remove this block after v0 is deprecated.
        if not envs.VLLM_USE_V1:
            # V0 finds the audio slots through the attention metadata, so
            # fail clearly rather than with an AttributeError on None.
            if attn_metadata is None:
                raise ValueError(
                    "attn_metadata is required to merge multimodal "
                    "embeddings when VLLM_USE_V1 is disabled.")
            merge_multimodal_embeddings_from_map(
                inputs_embeds, multimodal_embeddings,
                attn_metadata.multi_modal_placeholder_index_maps["audio"])
        else:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                _AUDIO_PLACEHOLDER_TOKEN)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            audio_features: A batch of audio inputs [B, N, 80, M].
        """

        if intermediate_tensors is not None:
            # Non-first pipeline stage: hidden states come from
            # `intermediate_tensors`, not from embeddings.
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)

            # TODO(ywang96): remove attn_metadata from get_input_embeddings
            # after v0 is deprecated
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings,
                                                      attn_metadata)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, remapping encoder prefixes.

        Unexpected ``audio_tower.`` weights are ignored rather than rejected.
        """
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)