ultravox.py 19.8 KB
Newer Older
1
2
3
4
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""

import math
5
from functools import cached_property, lru_cache
6
7
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
                    Tuple, TypedDict, Union)
8
9
10
11
12
13

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
14
from transformers import BatchFeature
15
16
17
18
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

from vllm.attention import AttentionMetadata
19
from vllm.config import VllmConfig
20
from vllm.inputs import InputContext
21
22
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
Joe Runde's avatar
Joe Runde committed
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
from vllm.model_executor.model_loader.loader import DefaultModelLoader
25
from vllm.model_executor.sampling_metadata import SamplingMetadata
26
27
28
29
30
31
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataDict,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
from vllm.sequence import IntermediateTensors
32
33
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

34
from .interfaces import SupportsMultiModal, SupportsPP
35
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
36
                    init_vllm_registered_model, maybe_prefix,
37
                    merge_multimodal_embeddings_from_map)
38

39
40
41
42
43
_AUDIO_TOKENS_PER_SECOND = 6.25


class UltravoxAudioFeatureInputs(TypedDict):
    """Raw audio features that still need the audio tower and projector."""

    type: Literal["audio_features"]
    data: NestedTensors
    """Shape: `(batch_size, num_audios, 80, M)`"""


class UltravoxAudioEmbeddingInputs(TypedDict):
    """Precomputed audio embeddings that are passed through unchanged."""

    type: Literal["audio_embeds"]
    data: NestedTensors
    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""


# Audio may arrive either as raw features or as ready-made embeddings.
UltravoxAudioInputs = Union[
    UltravoxAudioFeatureInputs,
    UltravoxAudioEmbeddingInputs,
]


@lru_cache
def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
    """Load the Whisper feature extractor for ``model_id``.

    Results are memoized per ``model_id`` via ``lru_cache``, so repeated
    lookups for the same model reuse the already-loaded extractor.
    """
    extractor = WhisperFeatureExtractor.from_pretrained(model_id)
    return extractor


def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
    """Return the (cached) feature extractor named by the config's
    ``audio_model_id``."""
    hf_config = ctx.get_hf_config(UltravoxConfig)
    model_id = hf_config.audio_model_id
    return cached_feature_extractor(model_id)


def get_ultravox_max_audio_tokens(ctx: InputContext):
    """Upper bound on the number of audio tokens for one audio input.

    Multiplies the feature extractor's ``chunk_length`` (seconds) by
    ``_AUDIO_TOKENS_PER_SECOND`` and rounds up to a whole token.
    """
    chunk_seconds = whisper_feature_extractor(ctx).chunk_length
    token_estimate = chunk_seconds * _AUDIO_TOKENS_PER_SECOND
    return math.ceil(token_estimate)


class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
    """Multimodal processor that prepares audio inputs for Ultravox.

    Audio clips are resampled to the Whisper feature extractor's sampling
    rate and run through the HF processor one clip at a time, because the
    Ultravox HF processor does not accept several inputs at once. Each
    ``<|audio|>`` placeholder in the prompt is then expanded into the number
    of audio tokens that the HF processor reported for that clip.
    """

    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor wrapped by the HF processor."""
        return self._get_hf_processor().audio_processor.feature_extractor

    def _resample_audio(
        self,
        audio: np.ndarray,
        sr: int,
    ) -> Dict[str, Union[np.ndarray, int]]:
        """Resample ``audio`` to the feature extractor's sampling rate.

        Args:
            audio: Raw audio samples.
            sr: Sampling rate of ``audio``.

        Returns:
            A dict with ``"audio"`` (the possibly resampled samples) and
            ``"sampling_rate"`` (the rate those samples are now at).

        Raises:
            ImportError: If resampling is required but ``librosa`` is not
                installed.
        """
        feature_extractor = self._get_feature_extractor()
        if sr != feature_extractor.sampling_rate:
            # librosa is an optional dependency (vllm[audio]); only import it
            # when resampling is actually needed.
            try:
                import librosa
            except ImportError as exc:
                raise ImportError(
                    "Please install vllm[audio] for audio support.") from exc
            audio = librosa.resample(audio,
                                     orig_sr=sr,
                                     target_sr=feature_extractor.sampling_rate)
            sr = feature_extractor.sampling_rate
        return {"audio": audio, "sampling_rate": sr}

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor over the prompt and each audio clip.

        Without audio, this defers entirely to the base implementation.
        Otherwise each ``(audio, sampling_rate)`` item is resampled and
        processed on its own. After each call, the resulting ``input_ids``
        are decoded back into text and used as the prompt for the next call.
        (Presumably this lets each call expand the next remaining audio
        placeholder. Confirm against the HF Ultravox processor.)

        Returns:
            The remaining outputs of the last HF call, plus
            ``audio_features`` (one tensor per clip with the leading batch
            dimension squeezed) and ``audio_token_len`` (one int per clip).
        """
        if not mm_data or not mm_data.get("audio", None):
            return super()._apply_hf_processor(prompt, mm_data,
                                               mm_processor_kwargs)

        audio_data = mm_data["audio"]
        if not isinstance(audio_data, list):
            audio_data = [audio_data]

        # Ultravox processor doesn't support multiple inputs,
        # therefore we need to input text and audio one by one
        tokenizer = self._get_tokenizer()
        audio_features, audio_token_len = [], []
        processed_inputs = {}
        for audio, sr in audio_data:
            data = self._resample_audio(audio, sr)
            processed_inputs = super()._apply_hf_processor(
                prompt, data, mm_processor_kwargs)
            prompt = tokenizer.decode(processed_inputs["input_ids"][0],
                                      skip_special_tokens=False)
            audio_features.append(
                processed_inputs.pop("audio_values").squeeze(0))
            audio_token_len.append(
                processed_inputs.pop("audio_token_len").item())

        return dict(
            **processed_inputs,
            audio_features=audio_features,
            audio_token_len=audio_token_len,
        )

    def _get_processor_data(
        self,
        mm_data: MultiModalDataDict,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Split ``mm_data`` as the base class does, renaming ``"audios"``.

        The Ultravox HF processor expects the keyword ``audio`` (singular)
        rather than ``audios``.
        """
        # Ultravox uses "audio" instead of "audios" as calling keyword
        processor_data, passthrough_data = super()._get_processor_data(mm_data)
        if "audios" in processor_data:
            processor_data["audio"] = processor_data.pop("audios")
        return processor_data, passthrough_data

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        """Expand every ``<|audio|>`` placeholder for its audio clip.

        The i-th placeholder is replaced with the HF processor's
        ``audio_token_replacement`` string repeated
        ``hf_inputs["audio_token_len"][i]`` times.
        """
        hf_processor = self._get_hf_processor()
        placeholder = hf_processor.audio_token_replacement

        def get_replacement_ultravox(item_idx: int):
            audio_token_len = hf_inputs["audio_token_len"][item_idx]
            return placeholder * audio_token_len

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build dummy inputs with ``mm_counts["audio"]`` silent clips.

        Each clip is all zeros, lasts ``chunk_length`` seconds at the feature
        extractor's sampling rate, and gets its own ``<|audio|>`` placeholder.
        """
        feature_extractor = self._get_feature_extractor()
        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate

        audio_count = mm_counts["audio"]
        audio = np.zeros(audio_len)
        data = {"audio": [(audio, sampling_rate)] * audio_count}

        return ProcessorInputs(
            prompt_text="<|audio|>" * audio_count,
            mm_data=data,
            mm_processor_kwargs={},
        )


class StackAudioFrames(nn.Module):
    """
    Shorten an audio embedding sequence by concatenating each group of
    `stack_factor` consecutive frames along the channel axis.

    The time axis is first zero-padded up to the next multiple of
    `stack_factor`, so an input of shape `(B, T, C)` becomes
    `(B, ceil(T / stack_factor), C * stack_factor)`.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        factor = self.stack_factor
        batch, frames, channels = audio_embeds.shape
        # Zero frames needed to reach the next multiple of `factor`.
        missing = -frames % factor
        padded = F.pad(audio_embeds, (0, 0, 0, missing))
        return padded.view(batch, (frames + missing) // factor,
                           channels * factor)


class FlippedSiluAndMul(SiluAndMul):
    """SwiGLU with the two halves of the input swapped.

    Ultravox is trained with SwiGLU with flipped halves, so the last
    dimension is split in two and the halves are exchanged before the
    parent `SiluAndMul` is applied.
    """

    def forward(self, x: torch.Tensor):
        first_half, second_half = x.chunk(2, dim=-1)
        swapped = torch.cat([second_half, first_half], dim=-1)
        return super().forward(swapped)


class UltravoxProjector(nn.Module):
    """Map audio-encoder frames into the text model's hidden size.

    Pipeline: stack frames -> RMSNorm -> linear -> activation -> linear ->
    RMSNorm.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        stack_factor = config.stack_factor
        self._pad_and_stack = StackAudioFrames(stack_factor)

        in_dim = config.audio_config.hidden_size * stack_factor
        self.ln_pre = RMSNorm(in_dim)
        self.linear_1 = nn.Linear(in_dim, self.hidden_dim, bias=False)

        # The gated (swiglu) activation halves the feature width.
        use_swiglu = config.projector_act == "swiglu"
        self.act = (FlippedSiluAndMul()
                    if use_swiglu else get_act_fn(config.projector_act))
        mid_dim = self.hidden_dim // 2 if use_swiglu else self.hidden_dim

        out_dim = config.text_config.hidden_size
        self.linear_2 = nn.Linear(mid_dim, out_dim, bias=False)
        self.ln_post = RMSNorm(out_dim)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        stacked = self._pad_and_stack(audio_features)
        projected = self.linear_1(self.ln_pre(stacked))
        activated = self.act(projected)
        return self.ln_post(self.linear_2(activated))


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 seconds of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def forward(
        self,
        input_features,
    ):
        """Encode mel features into hidden states.

        Args:
            input_features: Mel features whose last dimension is the frame
                axis (assumed `(batch, num_mel_bins, frames)` to match the
                `conv1` input -- TODO confirm).

        Returns:
            The layer-normed output of the last encoder layer, with the time
            axis on dimension -2.

        Raises:
            ValueError: If there are more frames than the positional
                embeddings can cover.
        """
        # Longest frame count whose post-convolution length still fits in
        # `max_source_positions` positional embeddings.
        expected_seq_length = (self.config.max_source_positions *
                               self.conv1.stride[0] * self.conv2.stride[0])
        if input_features.shape[-1] > expected_seq_length:
            # Shorter inputs are accepted, so the fix is truncation, not
            # padding.
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure the input mel "
                f"features are at most {expected_seq_length} frames long.")

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positional embeddings to the (possibly shorter) sequence.
        embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                None,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_ultravox_max_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
    """Ultravox: a Whisper-based audio tower plus projector feeding a text LM.

    Audio features are encoded by `audio_tower`, mapped into the language
    model's embedding space by `multi_modal_projector`, and merged into the
    token embeddings at the audio placeholder positions before the wrapped
    `language_model` runs.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra checkpoints whose weights are loaded under the given prefixes.
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                ))
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(model_or_path=config.text_model_id,
                                          revision=None,
                                          prefix="language_model."))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """The language model's own sampler if it has one, else the default
        from `get_sampler()`."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor) -> torch.Tensor:
        """Run features through the audio tower and projector.

        Inputs are cast to the audio tower's dtype first.
        """
        audio_input = input_features.to(self.audio_tower.dtype)
        audio_features = self.audio_tower(audio_input)
        audio_features = audio_features.to(self.audio_tower.dtype)
        audio_embeddings = self.multi_modal_projector(audio_features)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        """Pick out audio features or embeddings from `kwargs`.

        Returns:
            None when neither `audio_features` nor `audio_embeds` is given.
            Otherwise returns a typed input dict; features take precedence
            over embeddings.

        Raises:
            ValueError: If the selected input is not a tensor or a list.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")

            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
        """Turn parsed audio input into embeddings.

        Precomputed embeddings are returned unchanged. A single feature
        tensor has its first two dimensions flattened for encoding and then
        restored. Lists are encoded item by item, preserving nesting.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]
        if isinstance(audio_features, torch.Tensor):
            # Combine the B and N dimensions for the encoder/projector
            flattened = flatten_bn(audio_features)
            flattened_embeddings = self._audio_features_to_embeddings(
                flattened)

            # Restore the original dimensions
            embeddings = flattened_embeddings.unflatten(
                0, audio_features.shape[:2])
            return embeddings

        result = []
        # TODO: Batch heterogeneous tensors through the encoder/projector
        for audio_features_item in audio_features:
            if isinstance(audio_features_item, torch.Tensor):
                result.append(
                    self._audio_features_to_embeddings(audio_features_item))
            else:
                embeddings = [
                    # Add a batch dimension to embed it, then remove it.
                    self._audio_features_to_embeddings(tensor.unsqueeze(0)
                                                       ).squeeze(0)
                    for tensor in audio_features_item
                ]
                result.append(embeddings)

        return result

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return audio embeddings for `kwargs`, or None if there is no
        audio input."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return None
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, splicing in audio embeddings if provided.

        Raises:
            ValueError: If `multimodal_embeddings` is given without
                `attn_metadata`, which supplies the audio placeholder
                index map.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            if attn_metadata is None:
                raise ValueError(
                    "attn_metadata is required to merge audio embeddings "
                    "into the input embeddings.")

            # TODO(ywang96): use merge_multimodal_embeddings after
            # v0 is deprecated
            merge_multimodal_embeddings_from_map(
                inputs_embeds, multimodal_embeddings,
                attn_metadata.multi_modal_placeholder_index_maps["audio"])
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            intermediate_tensors: Hidden states from the previous pipeline
                stage; when set, embeddings are not computed here.
            inputs_embeds: Precomputed input embeddings; when None (and
                `intermediate_tensors` is None), they are built from
                `input_ids` and the audio inputs in `kwargs`.
            **kwargs: May contain `audio_features` (a batch of audio inputs
                [B, N, 80, M]) or `audio_embeds`.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)

            # TODO(ywang96): remove attn_metadata from get_input_embeddings
            # after v0 is deprecated
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings,
                                                      attn_metadata)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights and return the names that were loaded.

        HF names under `audio_tower.model.encoder.` are remapped to
        `audio_tower.`. Unexpected names under `audio_tower.` are ignored.
        """
        hf_to_vllm_mapper = WeightsMapper(
            orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)