# ultravox.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6

7
import copy
8
from collections.abc import Iterable, Mapping, Sequence
9
from types import SimpleNamespace
10
from typing import Annotated, Any, Literal, TypeAlias
11
12
13
14

import torch
from torch import nn
from torch.nn import functional as F
15
from transformers import BatchFeature, ProcessorMixin
16
from transformers.modeling_utils import ModuleUtilsMixin
17
from transformers.models.whisper import WhisperFeatureExtractor
18
19
20
21
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
)
22

23
from vllm.config import VllmConfig
24
from vllm.config.multimodal import BaseDummyOptions
25
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
26
from vllm.model_executor.layers.layernorm import RMSNorm
27
from vllm.model_executor.model_loader import DefaultModelLoader
28
from vllm.model_executor.models.module_mapping import MultiModelKeys
29
from vllm.multimodal import MULTIMODAL_REGISTRY
30
31
32
33
34
35
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
36
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
37
38
39
40
41
42
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
43
from vllm.multimodal.profiling import BaseDummyInputsBuilder
44
from vllm.sequence import IntermediateTensors
45
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
46
from vllm.utils.tensor_schema import TensorSchema, TensorShape
47

48
49
50
51
52
53
54
55
56
57
58
59
60
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
61

62
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
63
_MAX_ENCODER_BATCH_SIZE = 16
64
65


class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Audio input given as mel features, possibly split into several chunks
    per audio.

    Dimensions:
    - bn: total number of chunks across all audios in the batch
    - n: number of audios (one `num_chunks` entry per audio)
    - t: Time frames (M)
    - nmb: Number of mel bins
    """

    type: Literal["audio_features"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
        TensorShape("bn", "nmb", "t"),
    ]
    lens: Annotated[torch.Tensor, TensorShape("bn")]
    """
    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
    """
    token_len: Annotated[torch.Tensor, TensorShape("bn")]
    """Length of the audio tokens per chunk. Used for flattening the audio features."""
    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
    """Number of chunks per audio. Used for flattening the audio features."""


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Audio input given as precomputed embeddings. These are returned as-is by
    the model and skip both the audio tower and the projector.

    Dimensions:
    - b: batch size
    - na: number of audios
    - afs: audio feature size
    - hs: hidden size
    """

    type: Literal["audio_embeds"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor], TensorShape("b", "na", "afs", "hs")
    ]


# Either accepted audio input schema; consumers branch on the `type` field.
UltravoxAudioInputs: TypeAlias = (
    UltravoxAudioFeatureInputs | UltravoxAudioEmbeddingInputs
)


class UltravoxProcessingInfo(BaseProcessingInfo):
    """Access to the Ultravox HF processor and its Whisper feature extractor."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF processor with its audio placeholder overridden.

        The processor's `audio_token_replacement` is set to a reserved token,
        and `audio_replacement_token_id` is set to `config.audio_token_index`.
        """
        config = self.ctx.model_config.hf_config
        hf_processor = self.ctx.get_hf_processor(**kwargs)

        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        # placeholder that will cause confusion with the actual end of turn
        # token, thus we override placeholder with a reserved token.
        hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        hf_processor.audio_replacement_token_id = config.audio_token_index

        return hf_processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor used by the HF processor.

        Newer processor revisions wrap the feature extractor inside
        `audio_processor`; older ones expose it directly.
        """
        hf_processor = self.get_hf_processor(**kwargs)

        # Changed in https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/commit/9a3c571b8fdaf1e66dd3ea61bbcb6db5c70a438e
        audio_processor = hf_processor.audio_processor  # type: ignore
        if isinstance(audio_processor, WhisperFeatureExtractor):
            return audio_processor

        feature_extractor = audio_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # NOTE: None presumably means no per-prompt cap on audio items.
        return {"audio": None}

138

class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy prompts and audio clips for profiling Ultravox."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one `<|audio|>` placeholder per requested audio item."""
        num_audios = mm_counts.get("audio", 0)

        return "<|audio|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return `mm_counts["audio"]` dummy audio clips.

        Each clip is `chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE`
        samples long. That is `_MAX_ENCODER_BATCH_SIZE` feature-extractor
        chunks, presumably to exercise a full encoder batch.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = (
            feature_extractor.chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE
        )
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }


class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that expands `<|audio|>` into audio tokens."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # target_sr: audio is brought to the feature extractor's sampling rate.
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF Ultravox processor on `prompt` and its audio items.

        Behavior depends on whether audio is present:
        - Text-only prompts are tokenized directly (no special tokens).
        - With audio, the feature extractor's sampling rate is forced and
          per-audio chunk counts are requested. The HF output key
          `audio_values` is renamed to `audio_features`.

        The caller's `mm_data`, `mm_kwargs` and `tok_kwargs` are not mutated.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # Merge with a dict literal. `dict(**mm_kwargs, sampling_rate=...)`
        # raises TypeError if the caller already passed either key. The audio
        # was resampled to the feature extractor's rate by the data parser,
        # so that rate must win.
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
            "include_audio_num_chunks": True,
        }

        item_processor_data = dict(**mm_data, audios=audios)

        # Some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Filter into a new dict instead of popping from the caller's mapping.
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in ("padding", "truncation")
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output["audio_features"] = output.pop("audio_values")

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each HF output field is split into per-audio items."""
        num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0))
        return dict(
            # to handle longer than 30s audio, each audio might be split
            # into multiple chunks as such, their batch dimension can be
            # higher than the number of audio samples
            audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each `<|audio|>` with that audio's total token count.

        Each audio's token count is the sum of `audio_token_len` over its
        chunks. Every inserted token is the replacement token id.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(
            num_chunks, dim=0, dtype=torch.int32
        )
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx]
        )

        def get_replacement_ultravox(item_idx: int):
            # Sum the token lengths of every chunk belonging to this audio.
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    An input of shape (B, T, C) is zero-padded along T to the next multiple of
    `stack_factor`. It is then reshaped to
    (B, ceil(T / stack_factor), C * stack_factor). Each output frame is the
    concatenation of `stack_factor` consecutive input frames.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Pad `audio_embeds` (B, T, C) in time and fold frames into channels."""
        batch, num_frames, channels = audio_embeds.shape
        # Zero frames needed to reach the next multiple of stack_factor.
        pad = -num_frames % self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, pad))
        stacked_len = (num_frames + pad) // self.stack_factor
        # reshape (not view) so a non-contiguous padded tensor cannot fail.
        return audio_embeds.reshape(batch, stacked_len, channels * self.stack_factor)


class UltravoxFeedForwardProjector(nn.Module):
    """MLP projector from stacked audio-encoder frames to the text hidden size.

    Pipeline: stack frames -> ln_pre -> linear_1 -> act -> ln_mid -> linear_2
    -> ln_post. Depending on `config.projector_ln_mid`, exactly one of
    `ln_mid` / `ln_post` is an RMSNorm and the other is an identity.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim_in = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
        dim_mid = self.hidden_dim

        if config.projector_act == "swiglu":
            # The SwiGLU activation halves the last dimension, so linear_2
            # takes dim_mid // 2 inputs.
            self.act = MulAndSilu()
            dim_mid = dim_mid // 2
        else:
            self.act = get_act_fn(config.projector_act)

        dim_out = config.text_config.hidden_size
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(dim_mid)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(dim_out)

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project audio-encoder outputs to text-model embeddings.

        `audio_token_len` is unused here. It is accepted so this projector has
        the same call signature as UltravoxTransformerProjector.
        """
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.ln_mid(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
    """Projector built from Whisper encoder layers.

    Stacks audio-encoder frames and maps them to the Whisper `d_model` width.
    It then adds learned position embeddings and runs
    `config.num_projector_layers` WhisperEncoderLayer blocks under a padding
    mask derived from `audio_token_len`. Finally it projects to the text
    model's hidden size.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        # get_extended_attention_mask (from ModuleUtilsMixin) consults
        # self.config.is_decoder; a minimal namespace is all it needs.
        self.config = SimpleNamespace(is_decoder=False)

        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        stacked_dim = config.audio_config.hidden_size * config.stack_factor

        layer_config = copy.deepcopy(config.audio_config)
        d_model = layer_config.d_model

        self.ln_pre = RMSNorm(stacked_dim)
        self.linear_in = nn.Linear(stacked_dim, d_model)
        self.embed_positions = nn.Embedding(
            layer_config.max_source_positions, d_model
        )
        self.layers = nn.ModuleList(
            WhisperEncoderLayer(layer_config)
            for _ in range(config.num_projector_layers)
        )
        self.ln_post = RMSNorm(d_model)
        self.linear_out = nn.Linear(d_model, config.text_config.hidden_size)

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project audio-encoder outputs to text-model embeddings.

        In row i, stacked positions at index >= audio_token_len[i] are masked
        out of self-attention.
        """
        stacked = self._pad_and_stack(audio_features)

        seq_len = stacked.shape[1]
        frame_idx = torch.arange(seq_len, device=stacked.device)
        valid = frame_idx.unsqueeze(0) < audio_token_len.unsqueeze(1)
        extended_mask = self.get_extended_attention_mask(
            valid, valid.shape, stacked.dtype
        )

        hidden = self.linear_in(self.ln_pre(stacked))
        position_ids = torch.arange(hidden.size(1), device=hidden.device)
        hidden = hidden + self.embed_positions(position_ids)

        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden,
                attention_mask=extended_mask,
                layer_head_mask=None,
            )[0]

        return self.linear_out(self.ln_post(hidden))


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Encoder-only: get_extended_attention_mask checks is_decoder to
        # decide whether to add a causal mask; we want a padding mask only.
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        """Longest mel input (in frames) accepted by the encoder.

        Equal to `max_source_positions` times the strides of both conv layers.
        """
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def get_attention_mask_by_audio_len(
        self, audio_lens: torch.Tensor | None, hidden_states: torch.Tensor
    ):
        """
        Create an attention mask from audio lengths to mask out padding tokens.

        For each sample in the batch:
        - Convert the mel-frame length to the encoder sequence length after
          the convolutions (`_get_feat_extract_output_lengths`).
        - Create a bool mask: True for valid positions, False for padding.
        - Convert it with `get_extended_attention_mask` to the additive mask
          format the transformer layers expect: 0.0 for positions to attend
          to, a large negative value for positions to ignore.

        This masking ensures consistent behavior between training and inference
        by preventing the model from attending to padding tokens in both cases.

        Returns None when `audio_lens` is None (no masking).
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
            None, :
        ].lt(audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor | None = None,
    ):
        """Encode mel features that may be shorter than the full context.

        Args:
            input_features: mel features whose last dimension (frames) must not
                exceed `max_context_length`.
            audio_lens: optional per-sample mel-frame lengths used to mask
                padding in self-attention.

        Returns:
            Encoder hidden states after the final layer norm.

        Raises:
            ValueError: if the input has more than `max_context_length` frames.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure the input mel "
                f"features are at most {expected_seq_length} frames long."
            )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positions so inputs shorter than the full context still work.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: a Whisper audio encoder and projector feeding a text model.

    Audio is encoded by `audio_tower` and mapped to the text hidden size by
    `multi_modal_projector`. The resulting embeddings are returned from
    `embed_multimodal` for merging into the language model's inputs.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for `modality`; only audio is valid."""
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: UltravoxConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra weight sources, loaded from separate model ids when the
        # config references them.
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                )
            )
        if config.num_projector_layers > 0:
            self.multi_modal_projector = UltravoxTransformerProjector(config)
        else:
            self.multi_modal_projector = UltravoxFeedForwardProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.wrapped_model_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.text_model_id,
                    revision=None,
                    prefix="language_model.",
                )
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor,
        audio_token_len: torch.Tensor,
    ) -> torch.Tensor:
        """Run the audio tower and projector over all chunks.

        Chunks are processed in slices of at most `_MAX_ENCODER_BATCH_SIZE`.
        The per-slice projector outputs are concatenated along dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(
                audio_features[start:end], audio_lens[start:end]
            )
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(
                batch_features, audio_token_len[start:end]
            )
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> UltravoxAudioInputs | None:
        """Pop the audio-related kwargs and wrap them in a validated schema.

        Returns None when neither `audio_features` nor `audio_embeds` is
        given. `audio_features` takes priority if both are present.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)
        audio_num_chunks = kwargs.pop("audio_num_chunks", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(
                type="audio_features",
                data=audio_features,
                lens=audio_lens,
                token_len=audio_token_len,
                num_chunks=audio_num_chunks,
            )

        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> NestedTensors | tuple[torch.Tensor, ...]:
        """Turn validated audio input into one embedding tensor per audio.

        Embedding inputs are returned unchanged. Feature inputs are handled as
        follows:
        - They are encoded chunk-wise.
        - Each chunk is truncated to its `token_len`.
        - The surviving rows are regrouped per audio using `num_chunks`.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        audio_lens = audio_input["lens"]
        audio_token_len = audio_input["token_len"]

        embeddings = self._audio_features_to_embeddings(
            audio_features, audio_lens, audio_token_len
        )

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1
        )
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio
        embed_lens = [
            chunk_lens.sum().item()
            for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist())
        ]
        return flattened_embeddings.split(embed_lens)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or [] when no audio input is given."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, `positions` stays consistent with `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When given, `inputs_embeds` is ignored.
            inputs_embeds: Optional tensor of input embeddings.
            **kwargs: Unused.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Unwrap one level if the registered model nests its own language_model.
        language_model = self.language_model
        if hasattr(language_model, "language_model"):
            language_model = language_model.language_model

        hidden_states = language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights after remapping HF audio-tower prefixes."""
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def pad_and_concat_to_dim3(
741
    features: torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
742
743
744
745
746
747
748
749
750
751
752
753
754
) -> torch.Tensor:
    """
    Pad and concatenate a list of tensors.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
755

756
757
758
759
760
761
762
        return features

    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
763
    # Pad and concatenate:
764
765
766
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)