# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6

7
import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias

import torch
from torch import nn
from torch.nn import functional as F
from transformers import BatchFeature, ProcessorMixin
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
62

63
# Placeholder string that marks where an audio clip goes in the prompt. It is
# installed on the HF processor in place of the processor's own placeholder.
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
# Maximum number of audio chunks sent through the audio tower in one call;
# also used to size the dummy audio input.
_MAX_ENCODER_BATCH_SIZE = 16
65
66


67
class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Schema for audio passed as mel features (to be run through the audio tower).

    Dimensions:
    - b: batch size
    - n: number of chunks
    - bn: flattened chunk dimension (one entry per audio chunk)
    - t: Time frames (M)
    - nmb: Number of mel bins
    """

    type: Literal["audio_features"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
        TensorShape("bn", "nmb", "t"),
    ]
    lens: Annotated[torch.Tensor, TensorShape("bn")]
    """
    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
    """
    token_len: Annotated[torch.Tensor, TensorShape("bn")]
    """Length of the audio tokens per chunk. Used for flattening the audio features."""
    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
    """Number of chunks per audio. Used for flattening the audio features."""
89
90
91


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Schema for audio passed as precomputed embeddings (audio tower is skipped).

    Dimensions:
    - b: batch size
    - na: number of audios
    - afs: audio feature size
    - hs: hidden size
    """

    type: Literal["audio_embeds"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor], TensorShape("b", "na", "afs", "hs")
    ]
104
105


106
107
108
# Either form of audio input accepted by the model: raw features or embeddings.
UltravoxAudioInputs: TypeAlias = UltravoxAudioFeatureInputs | UltravoxAudioEmbeddingInputs
109
110


111
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Accessors for the Ultravox HF processor, feature extractor and parser."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF processor with the audio placeholder overridden."""
        config = self.ctx.model_config.hf_config
        hf_processor = self.ctx.get_hf_processor(**kwargs)

        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        # placeholder that will cause confusion with the actual end of turn
        # token, thus we override placeholder with a reserved token.
        hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        hf_processor.audio_replacement_token_id = config.audio_token_index

        return hf_processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor held by the HF processor.

        The processor's ``audio_processor`` is either the feature extractor
        itself or a wrapper that exposes it as ``.feature_extractor``.
        """
        hf_processor = self.get_hf_processor(**kwargs)

        # Changed in https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/commit/9a3c571b8fdaf1e66dd3ea61bbcb6db5c70a438e
        audio_processor = hf_processor.audio_processor  # type: ignore
        if isinstance(audio_processor, WhisperFeatureExtractor):
            return audio_processor

        feature_extractor = audio_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_data_parser(self) -> MultiModalDataParser:
        """Build a data parser targeting the extractor's sampling rate."""
        feature_extractor = self.get_feature_extractor()

        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_target_channels(self) -> int:
        """Return target audio channels for Ultravox models (mono)."""
        return 1

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return per-modality item limits; ``None`` is used for audio."""
        return {"audio": None}
151

152

153
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy prompt text and dummy audio data for Ultravox."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one audio placeholder per requested audio item."""
        num_audios = mm_counts.get("audio", 0)

        # Use the shared constant so the dummy prompt always matches the
        # placeholder installed on the HF processor.
        return _AUDIO_PLACEHOLDER_OVERRIDE * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audio clips for the requested number of audio items.

        Each clip is ``chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE``
        samples long. ``seq_len`` is not used here.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = (
            feature_extractor.chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE
        )
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }


182
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that expands audio placeholders for Ultravox."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, handling the text-only case separately.

        The processor output key ``audio_values`` is renamed to
        ``audio_features`` before returning.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
            include_audio_num_chunks=True,
        )

        item_processor_data = dict(**mm_data, audios=audios)

        # Some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Build a filtered copy instead of popping in place: `tok_kwargs` is
        # typed as a read-only Mapping and belongs to the caller.
        incompatible_keys = ("add_special_tokens", "padding", "truncation")
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in incompatible_keys
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output["audio_features"] = output.pop("audio_values")

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto audio items."""
        num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0))
        return dict(
            # to handle longer than 30s audio, each audio might be split
            # into multiple chunks as such, their batch dimension can be
            # higher than the number of audio samples
            audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio placeholder with one token id per audio token."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(
            num_chunks, dim=0, dtype=torch.int32
        )
        # Prepend 0 so that [idx[i], idx[i + 1]) spans the i-th audio's chunks.
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx]
        )

        def get_replacement_ultravox(item_idx: int):
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            # Total token count of the audio = sum over its chunks.
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target=_AUDIO_PLACEHOLDER_OVERRIDE,
                replacement=get_replacement_ultravox,
            )
        ]
279
280
281
282
283
284
285
286
287
288
289
290
291
292


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Zero-pad the time axis to a multiple of `stack_factor` and fold it.

        Args:
            audio_embeds: Tensor of shape ``[B, T, C]``.

        Returns:
            Tensor of shape ``[B, ceil(T / stack_factor), C * stack_factor]``.
        """
        B, T, C = audio_embeds.shape
        # Round T up to the next multiple of stack_factor.
        T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
        # Pad only the end of the time axis (second-to-last dim).
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        # reshape (not view): when no padding is needed the padded tensor may
        # keep the input's non-contiguous strides, which view() rejects.
        return audio_embeds.reshape(
            B, T_pad // self.stack_factor, C * self.stack_factor
        )


302
class UltravoxFeedForwardProjector(nn.Module):
    """Feed-forward projector from stacked audio features to the text hidden size.

    Pipeline: stack frames -> RMSNorm -> linear_1 -> activation -> ln_mid ->
    linear_2 -> ln_post.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim_in = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
        dim_mid = self.hidden_dim

        if config.projector_act == "swiglu":
            self.act = MulAndSilu()
            # The swiglu branch halves the width fed to linear_2.
            dim_mid = dim_mid // 2
        else:
            self.act = get_act_fn(config.projector_act)

        dim_out = config.text_config.hidden_size
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(dim_mid)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(dim_out)

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project audio features to text-model embeddings.

        `audio_token_len` is not used by this projector; it is accepted so the
        call signature matches `UltravoxTransformerProjector.forward`.
        """
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.ln_mid(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
    """Transformer projector built from `WhisperEncoderLayer` blocks.

    Pipeline: stack frames -> RMSNorm -> linear_in -> add learned position
    embeddings -> `num_projector_layers` Whisper encoder layers -> RMSNorm ->
    linear_out.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        # Minimal stand-in config exposing only `is_decoder`.
        # NOTE(review): presumably read by ModuleUtilsMixin's
        # get_extended_attention_mask -- confirm.
        self.config = SimpleNamespace(is_decoder=False)

        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim_in = config.audio_config.hidden_size * config.stack_factor

        projector_audio_config = copy.deepcopy(config.audio_config)

        self.ln_pre = RMSNorm(dim_in)
        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)

        self.embed_positions = nn.Embedding(
            projector_audio_config.max_source_positions,
            projector_audio_config.d_model,
        )

        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(projector_audio_config)
                for _ in range(config.num_projector_layers)
            ]
        )

        self.ln_post = RMSNorm(projector_audio_config.d_model)
        self.linear_out = nn.Linear(
            projector_audio_config.d_model, config.text_config.hidden_size
        )

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project audio features, masking positions at or past `audio_token_len`.

        Args:
            audio_features: Audio tower output; stacked along the time axis
                before projection.
            audio_token_len: Per-row number of valid stacked positions.

        Returns:
            Output of `linear_out` over all stacked positions.
        """
        audio_features = self._pad_and_stack(audio_features)

        max_len_stacked = audio_features.shape[1]
        # True where the stacked position index is below the row's token length.
        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
            None, :
        ].lt(audio_token_len[:, None])
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, attention_mask.shape, audio_features.dtype
        )

        hidden_states = self.ln_pre(audio_features)
        hidden_states = self.linear_in(hidden_states)

        positions = self.embed_positions(
            torch.arange(hidden_states.size(1), device=hidden_states.device)
        )
        hidden_states = hidden_states + positions

        # Backward compatibility for Transformers v4 where layer_head_mask
        # was a required argument for WhisperEncoderLayer.forward
        kwargs = {}
        if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
            kwargs["layer_head_mask"] = None

        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                **kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        """Maximum accepted input length along the last (time) axis."""
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def get_attention_mask_by_audio_len(
        self, audio_lens: torch.Tensor | None, hidden_states: torch.Tensor
    ):
        """
        Create attention mask based on audio lengths to mask out padding tokens
        For each sample in batch:
        - Convert raw audio length to feature length after convolutions
        - Create bool mask: True for valid positions and False for padding
        - Convert to attention mask format expected by transformer layers
        (0.0 for positions to attend to, large negative for positions to ignore)
        This masking ensures consistent behavior between training and inference
        by preventing the model from attending to padding tokens in both cases

        Returns None when `audio_lens` is None (no masking).
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        # True where the position index is below the sample's feature length.
        attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
            None, :
        ].lt(audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor | None = None,
    ):
        """Encode mel features into hidden states.

        Args:
            input_features: Mel features whose last dimension is at most
                `max_context_length`.
            audio_lens: Optional per-sample audio lengths used to build the
                padding attention mask; no mask is applied when None.

        Returns:
            Hidden states after the final layer norm.

        Raises:
            ValueError: If the last dimension of `input_features` exceeds
                `max_context_length`.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}."
            )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice the position table to the actual (possibly shorter) length.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

        # Backward compatibility for Transformers v4 where layer_head_mask
        # was a required argument for WhisperEncoderLayer.forward
        kwargs = {}
        if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
            kwargs["layer_head_mask"] = None

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


518
519
520
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: Whisper-based audio tower + projector + wrapped language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for `modality`.

        Raises:
            ValueError: If `modality` does not start with ``"audio"``.
        """
        if modality.startswith("audio"):
            return _AUDIO_PLACEHOLDER_OVERRIDE

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: UltravoxConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Additional weight sources for the audio tower and the language
        # model when the config names them separately.
        self.secondary_weights = []
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                )
            )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.text_model_id,
                    revision=None,
                    prefix="language_model.",
                )
            )

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
            # A positive layer count selects the transformer projector;
            # otherwise the feed-forward projector is used.
            if config.num_projector_layers > 0:
                self.multi_modal_projector = UltravoxTransformerProjector(config)
            else:
                self.multi_modal_projector = UltravoxFeedForwardProjector(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.wrapped_model_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor,
        audio_token_len: torch.Tensor,
    ) -> torch.Tensor:
        """Run audio tower + projector over the chunks in bounded batches.

        All three arguments are sliced along dim 0 with the same indices, so
        they must be aligned per chunk.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(
                audio_features[start:end], audio_lens[start:end]
            )
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(
                batch_features, audio_token_len[start:end]
            )
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> UltravoxAudioInputs | None:
        """Build the audio input schema from kwargs, or None if no audio.

        `audio_features` takes precedence over `audio_embeds` when both are
        given.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)
        audio_num_chunks = kwargs.pop("audio_num_chunks", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(
                type="audio_features",
                data=audio_features,
                lens=audio_lens,
                token_len=audio_token_len,
                num_chunks=audio_num_chunks,
            )

        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> NestedTensors | tuple[torch.Tensor, ...]:
        """Turn parsed audio input into one embedding tensor per audio.

        Precomputed embeddings are returned unchanged. Features are encoded,
        stripped of padding positions, and regrouped per input audio.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        audio_lens = audio_input["lens"]
        audio_token_len = audio_input["token_len"]

        embeddings = self._audio_features_to_embeddings(
            audio_features, audio_lens, audio_token_len
        )

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1
        )
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio
        embed_lens = [
            chunk_lens.sum().item()
            for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist())
        ]
        return flattened_embeddings.split(embed_lens)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return audio embeddings for the given kwargs, or [] if no audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Embed `input_ids`, delegating to the parent implementation."""
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        """

        # When intermediate tensors are supplied, input embeddings are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Unwrap one level if the wrapped model nests its own language model.
        language_model = self.language_model
        if hasattr(language_model, "language_model"):
            language_model = language_model.language_model

        hidden_states = language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate logit computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` with the HF->vLLM mapper."""
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
762
763
764


def pad_and_concat_to_dim3(
765
    features: torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
766
767
768
769
770
771
772
773
774
775
776
777
778
) -> torch.Tensor:
    """
    Pad and concatenate a list of tensors.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
779

780
781
782
783
784
785
786
        return features

    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
787
    # Pad and concatenate:
788
789
790
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)