# File: ultravox.py (25.1 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6

7
from collections.abc import Iterable, Mapping, Sequence
8
from typing import Annotated, Any, Literal, TypeAlias
9
10
11
12

import torch
from torch import nn
from torch.nn import functional as F
13
from transformers import BatchFeature, ProcessorMixin
14
15
16
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

17
from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.model_loader import DefaultModelLoader
22
from vllm.model_executor.models.module_mapping import MultiModelKeys
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
25
26
27
28
29
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
30
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
31
32
33
34
35
36
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
37
from vllm.multimodal.profiling import BaseDummyInputsBuilder
38
from vllm.sequence import IntermediateTensors
39
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
40
from vllm.utils.tensor_schema import TensorSchema, TensorShape
41

42
43
44
45
46
47
48
49
50
51
52
53
54
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
55

56
# Token string assigned to the HF processor's `audio_token_replacement`
# attribute (see `UltravoxProcessingInfo.get_hf_processor`).
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
# Maximum number of feature rows passed to the audio tower per call; also used
# to size the dummy audio generated for profiling.
_MAX_ENCODER_BATCH_SIZE = 16
58
59


60
class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Schema for audio passed to the model as mel features.

    Dimensions:
    - b: batch size
    - n: number of chunks
    - t: Time frames (M)
    - nmb: Number of mel bins
    - bn: leading dimension shared by `data`, `lens` and `token_len`
      (NOTE(review): presumably the total number of chunks across all audios
      in the batch -- confirm against the processor output)
    """

    type: Literal["audio_features"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
        TensorShape("bn", "nmb", "t"),
    ]
    lens: Annotated[torch.Tensor, TensorShape("bn")]
    """
    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
    """
    token_len: Annotated[torch.Tensor, TensorShape("bn")]
    """Length of the audio tokens per chunk. Used for flattening the audio features."""
    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
    """Number of chunks per audio. Used for flattening the audio features."""
82
83
84


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Schema for audio passed to the model as precomputed embeddings.

    Dimensions:
    - b: batch size
    - na: number of audios
    - afs: audio feature size
    - hs: hidden size
    """

    type: Literal["audio_embeds"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor], TensorShape("b", "na", "afs", "hs")
    ]
97
98


99
100
101
# Either form of audio input accepted by the model: mel features that still
# need encoding ("audio_features") or ready-made embeddings ("audio_embeds").
UltravoxAudioInputs: TypeAlias = (
    UltravoxAudioFeatureInputs | UltravoxAudioEmbeddingInputs
)
102
103


104
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Ultravox: HF processor access and modality limits."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF processor with its audio replacement token overridden.

        Args:
            **kwargs: Forwarded unchanged to `self.ctx.get_hf_processor`.

        Returns:
            The processor, with `audio_token_replacement` and
            `audio_replacement_token_id` set on it.
        """
        config = self.ctx.model_config.hf_config
        hf_processor = self.ctx.get_hf_processor(**kwargs)

        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        # placeholder that will cause confusion with the actual end of turn
        # token, thus we override placeholder with a reserved token.
        hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        hf_processor.audio_replacement_token_id = config.audio_token_index

        return hf_processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor held by the HF processor.

        The processor's `audio_processor` attribute is either the extractor
        itself or an object exposing it as `.feature_extractor`; both layouts
        are handled.

        Args:
            **kwargs: Forwarded to `get_hf_processor`.

        Raises:
            AssertionError: If the nested object is not a
                `WhisperFeatureExtractor`.
        """
        hf_processor = self.get_hf_processor(**kwargs)

        # Changed in https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/commit/9a3c571b8fdaf1e66dd3ea61bbcb6db5c70a438e
        audio_processor = hf_processor.audio_processor  # type: ignore
        if isinstance(audio_processor, WhisperFeatureExtractor):
            return audio_processor

        feature_extractor = audio_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; "audio" maps to `None`."""
        return {"audio": None}
131

132

133
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy prompts and dummy audio data for Ultravox."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one "<|audio|>" placeholder per requested audio item."""
        num_audios = mm_counts.get("audio", 0)

        return "<|audio|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create dummy audio inputs.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality; only "audio" is read.
            mm_options: Optional per-modality overrides; only "audio" is read.

        Returns:
            A dict with a single "audio" entry produced by
            `self._get_dummy_audios`.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        # Length of each dummy audio: `_MAX_ENCODER_BATCH_SIZE` times the
        # extractor's chunk length (NOTE(review): chunk_length is presumably
        # in seconds, making this a sample count -- confirm).
        audio_len = (
            feature_extractor.chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE
        )
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }


162
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that routes prompts and audio through the HF processor."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a data parser targeting the feature extractor's sampling rate."""
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer when there is no audio.

        Args:
            prompt: Prompt text.
            mm_data: Multimodal data; audio is read from the "audios" key.
            mm_kwargs: Extra processor kwargs. `sampling_rate` and
                `include_audio_num_chunks` are added before the call.
            tok_kwargs: Tokenizer kwargs. "padding" and "truncation" are
                dropped; the caller's mapping is not modified.

        Returns:
            The processor output with "audio_values" renamed to
            "audio_features", or a `BatchFeature` holding only `input_ids`
            when no audio is present.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
            include_audio_num_chunks=True,
        )

        item_processor_data = dict(**mm_data, audios=audios)

        # some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Filter into a new dict instead of popping: `tok_kwargs` is typed as
        # a read-only Mapping and belongs to the caller.
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in ("padding", "truncation")
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output["audio_features"] = output.pop("audio_values")

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto audio items."""
        num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0))
        return dict(
            # to handle longer than 30s audio, each audio might be split
            # into multiple chunks as such, their batch dimension can be
            # higher than the number of audio samples
            audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each "<|audio|>" target with that audio's replacement tokens.

        The number of replacement tokens for an audio is the sum of
        `audio_token_len` over the chunks belonging to it.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(
            num_chunks, dim=0, dtype=torch.int32
        )
        # Prepend 0 so that entries i and i + 1 bound the chunks of audio i.
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx]
        )

        def get_replacement_ultravox(item_idx: int):
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]
262
263
264
265
266
267
268
269
270
271
272
273
274
275


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Zero-pad the time axis to a multiple of `stack_factor`, then stack.

        Args:
            audio_embeds: Tensor of shape (B, T, C).

        Returns:
            Tensor of shape (B, ceil(T / stack_factor), C * stack_factor).
        """
        B, T, C = audio_embeds.shape
        # Round T up to the next multiple of stack_factor.
        T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
        # Pad spec covers the last two dims: (C_left, C_right, T_left, T_right).
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        # `reshape` instead of `view`: when no padding is needed the tensor may
        # keep the caller's non-contiguous layout, which `view` rejects.
        return audio_embeds.reshape(
            B, T_pad // self.stack_factor, C * self.stack_factor
        )


class UltravoxProjector(nn.Module):
    """Projects stacked audio-encoder frames into the text model's hidden size."""

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        # Stacking concatenates `stack_factor` frames along the channel axis.
        dim_in = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
        dim_mid = self.hidden_dim

        if config.projector_act == "swiglu":
            self.act = MulAndSilu()
            # The width fed to `linear_2` is halved for this activation.
            dim_mid = dim_mid // 2
        else:
            self.act = get_act_fn(config.projector_act)

        dim_out = config.text_config.hidden_size
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(dim_mid)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(dim_out)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Apply stack -> ln_pre -> linear_1 -> act -> ln_mid -> linear_2 -> ln_post."""
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.ln_mid(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        """Longest accepted input length along the last axis of `input_features`.

        Computed as `max_source_positions` times the strides of both
        convolution layers.
        """
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def get_attention_mask_by_audio_len(
        self, audio_lens: torch.Tensor | None, hidden_states: torch.Tensor
    ):
        """
        Create attention mask based on audio lengths to mask out padding tokens
        For each sample in batch:
        - Convert raw audio length to feature length after convolutions
        - Create bool mask: True for valid positions and False for padding
        - Convert to the attention mask format expected by transformer layers
          via `get_extended_attention_mask`
          (NOTE(review): that inherited helper presumably produces an additive
          mask -- 0.0 for positions to attend to and a large negative value
          for positions to ignore -- confirm against the installed
          Transformers version)
        This masking ensures consistent behavior between training and inference
        by preventing the model from attending to padding tokens in both cases

        Returns None when `audio_lens` is None.
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        # Row i is True at positions [0, audio_feature_len[i]).
        attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
            None, :
        ].lt(audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor | None = None,
    ):
        """Encode mel features into hidden states.

        Args:
            input_features: Mel features; the last axis must not exceed
                `max_context_length`.
            audio_lens: Optional per-row audio lengths used to build the
                attention mask; no mask is applied when None.

        Returns:
            Hidden states after all encoder layers and the final layer norm.

        Raises:
            ValueError: If the last axis of `input_features` is longer than
                `max_context_length`.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}."
            )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice the positional table so shorter-than-maximum inputs work.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


423
424
425
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: a Whisper-style audio tower and projector feeding a language model."""

    merge_by_field_config = True

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an audio item.

        Raises:
            ValueError: If `modality` does not start with "audio".
        """
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: UltravoxConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra weight sources, recorded when the config names separate audio
        # or text model ids.
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                )
            )
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.wrapped_model_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.text_model_id,
                    revision=None,
                    prefix="language_model.",
                )
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
        self, input_features: torch.Tensor, audio_lens: torch.Tensor
    ) -> torch.Tensor:
        """Encode and project audio features, at most 16 rows per tower call.

        Args:
            input_features: Audio features, batched along dim 0.
            audio_lens: Per-row audio lengths, sliced in step with the features.

        Returns:
            Projected embeddings for all rows, concatenated along dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(
                audio_features[start:end], audio_lens[start:end]
            )
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(batch_features)
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> UltravoxAudioInputs | None:
        """Build the audio input schema from keyword arguments.

        Returns None when neither `audio_features` nor `audio_embeds` is
        given; `audio_features` takes precedence when both are present.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)
        audio_num_chunks = kwargs.pop("audio_num_chunks", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(
                type="audio_features",
                data=audio_features,
                lens=audio_lens,
                token_len=audio_token_len,
                num_chunks=audio_num_chunks,
            )

        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> NestedTensors | tuple[torch.Tensor, ...]:
        """Turn parsed audio input into one embedding tensor per input audio.

        Precomputed embeddings are returned unchanged. Features are encoded,
        trimmed to each chunk's token length, and regrouped per audio.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        audio_lens = audio_input["lens"]
        audio_token_len = audio_input["token_len"]

        embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1
        )
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio
        embed_lens = [
            chunk_lens.sum().item()
            for chunk_lens in audio_token_len.split(audio_input["num_chunks"].tolist())
        ]
        return flattened_embeddings.split(embed_lens)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return audio embeddings for the given kwargs, or [] without audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Embed `input_ids`, delegating to the parent implementation.

        The multimodal arguments are forwarded only when both
        `multimodal_embeddings` and `is_multimodal` are provided.
        """
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        """

        # `inputs_embeds` is discarded whenever intermediate tensors are given.
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Unwrap one level when the language model is itself a wrapper that
        # exposes an inner `language_model` attribute.
        language_model = self.language_model
        if hasattr(language_model, "language_model"):
            language_model = language_model.language_model

        hidden_states = language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Unexpected weights under the "audio_tower." prefix are ignored.
        """
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
658
659
660


def pad_and_concat_to_dim3(
661
    features: torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
662
663
664
665
666
667
668
669
670
671
672
673
674
) -> torch.Tensor:
    """
    Pad and concatenate a list of tensors.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
675

676
677
678
679
680
681
682
        return features

    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
683
    # Pad and concatenate:
684
685
686
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)