# ultravox.py 24.9 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6

7
from collections.abc import Iterable, Mapping, Sequence
8
from typing import Annotated, Any, Literal, Optional, Union
9
10
11
12

import torch
from torch import nn
from torch.nn import functional as F
13
from transformers import BatchFeature, ProcessorMixin
14
15
16
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

17
from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.model_loader import DefaultModelLoader
22
from vllm.model_executor.models.module_mapping import MultiModelKeys
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
25
26
27
28
29
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
30
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
31
32
33
34
35
36
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
37
from vllm.multimodal.profiling import BaseDummyInputsBuilder
38
from vllm.sequence import IntermediateTensors
39
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
40
from vllm.utils.tensor_schema import TensorSchema, TensorShape
41

42
43
44
45
46
47
48
49
50
51
52
53
54
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
55

56
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
57
_MAX_ENCODER_BATCH_SIZE = 16
58
59


60
class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Schema for audio supplied as features that still have to be run through
    the audio tower and projector.

    Dimensions:
    - b: batch size
    - n: number of chunks
    - t: Time frames (M)
    - nmb: Number of mel bins
    """

    # Tag used to tell this schema apart from UltravoxAudioEmbeddingInputs.
    type: Literal["audio_features"]

    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]],
        TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"}),
    ]

    lens: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("b", "n", dynamic_dims={"n"}),
    ]
    """Length of the audio frames. Used for attention mask in WhisperEncoder."""

    token_len: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("b", "n", dynamic_dims={"n"}),
    ]
    """Length of the audio tokens. Used for flattening the audio features."""


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Schema for audio supplied as ready-made embeddings.

    Dimensions:
    - b: batch size
    - na: number of audios
    - afs: audio feature size
    - hs: hidden size
    """

    # Tag used to tell this schema apart from UltravoxAudioFeatureInputs.
    type: Literal["audio_embeds"]

    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]], TensorShape("b", "na", "afs", "hs")
    ]
99
100


101
# Audio input accepted by the model: either features that still need to be
# encoded, or embeddings that are used as-is.
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, UltravoxAudioEmbeddingInputs]
102
103


104
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Processing info for Ultravox: HF processor access and modality limits."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF processor with the audio replacement token overridden.

        Args:
            **kwargs: Forwarded to the context's `get_hf_processor`.
        """
        config = self.ctx.model_config.hf_config
        hf_processor = self.ctx.get_hf_processor(**kwargs)

        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        # placeholder that will cause confusion with the actual end of turn
        # token, thus we override placeholder with a reserved token.
        hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        hf_processor.audio_replacement_token_id = config.audio_token_index

        return hf_processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor nested in the HF processor.

        Args:
            **kwargs: Forwarded to `get_hf_processor`.

        Raises:
            AssertionError: If the nested feature extractor is not a
                `WhisperFeatureExtractor`.
        """
        hf_processor = self.get_hf_processor(**kwargs)
        audio_processor = hf_processor.audio_processor  # type: ignore
        feature_extractor = audio_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare audio as the only supported modality, with a limit of None."""
        return {"audio": None}
126

127

128
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy prompts and dummy audio for Ultravox."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one `<|audio|>` placeholder per requested audio item.

        Args:
            mm_counts: Number of items per modality; only "audio" is read
                (defaults to 0 when absent).
        """
        num_audios = mm_counts.get("audio", 0)

        return "<|audio|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Build dummy audio data.

        Args:
            seq_len: Not used by this builder.
            mm_counts: Number of items per modality; only "audio" is read.
            mm_options: Optional per-modality overrides; the "audio" entry,
                if any, is passed to `_get_dummy_audios`.

        Returns:
            A dict with a single "audio" entry holding the dummy audios.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        # Each dummy audio is sized as chunk_length * sampling_rate scaled by
        # the maximum encoder batch size.
        audio_len = (
            feature_extractor.chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE
        )
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio": self._get_dummy_audios(
                length=audio_len, num_audios=num_audios, overrides=audio_overrides
            )
        }


157
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that prepares prompts and audio for Ultravox."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a data parser targeting the feature extractor's sampling rate."""
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer for text-only input.

        Args:
            prompt: Prompt text.
            mm_data: Multimodal data; only the "audios" entry is inspected.
            mm_kwargs: Extra processor kwargs; `sampling_rate` and
                `include_audio_num_chunks` are added before the call.
            tok_kwargs: Tokenizer kwargs. "padding" and "truncation" are
                dropped; the caller's mapping is left untouched.

        Returns:
            The processor output, with "audio_values" renamed to
            "audio_features" when audio is present.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
            include_audio_num_chunks=True,
        )

        item_processor_data = dict(**mm_data, audios=audios)

        # some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Build a filtered copy rather than popping in place: `tok_kwargs` is
        # only guaranteed to be a read-only Mapping and belongs to the caller.
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in ("padding", "truncation")
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output["audio_features"] = output.pop("audio_values")

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto audio items.

        Args:
            hf_inputs: Processor output; "audio_num_chunks" is read (an empty
                tensor is used when it is absent).
            hf_processor_mm_kwargs: Not used here.
        """
        num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0))
        return dict(
            # to handle longer than 30s audio, each audio might be split
            # into multiple chunks as such, their batch dimension can be
            # higher than the number of audio samples
            audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each `<|audio|>` target with that audio's replacement tokens.

        Args:
            mm_items: Not used here.
            hf_processor_mm_kwargs: Forwarded to `get_hf_processor`.
            out_mm_kwargs: Processed multimodal kwargs; "audio_num_chunks" and
                "audio_token_len" are read from its data.

        Returns:
            A single `PromptReplacement` for the audio modality.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(
            num_chunks, dim=0, dtype=torch.int32
        )
        # Prepend 0 so that [i] and [i + 1] bracket the chunks of audio i.
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx]
        )

        def get_replacement_ultravox(item_idx: int):
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            # Total token count of the audio = sum over its chunks.
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]
257
258
259
260
261
262
263
264
265
266
267
268
269
270


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Zero-pad the time axis to a multiple of `stack_factor`, then stack.

        Args:
            audio_embeds: Tensor of shape (B, T, C).

        Returns:
            Tensor of shape (B, ceil(T / stack_factor), C * stack_factor).
        """
        B, T, C = audio_embeds.shape
        # Round T up to the next multiple of the stack factor.
        T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
        # Pad only at the end of the time axis (second-to-last dimension).
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        # reshape (rather than view) also handles non-contiguous inputs, which
        # can reach this point unchanged in layout when no padding is needed.
        return audio_embeds.reshape(
            B, T_pad // self.stack_factor, C * self.stack_factor
        )


class UltravoxProjector(nn.Module):
    """Projects audio tower outputs to the text model's hidden size.

    The pipeline is: frame stacking -> RMSNorm -> linear -> activation ->
    (mid norm) -> linear -> (post norm).
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        # Stacking multiplies the channel dimension by the stack factor.
        dim_in = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
        dim_mid = self.hidden_dim

        if config.projector_act == "swiglu":
            self.act = MulAndSilu()
            # The width fed to the next layer is halved for swiglu.
            dim_mid = dim_mid // 2
        else:
            self.act = get_act_fn(config.projector_act)

        dim_out = config.text_config.hidden_size
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(dim_mid)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(dim_out)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project audio tower outputs.

        Args:
            audio_features: Output of the audio tower; it is passed to
                `StackAudioFrames`, which expects a 3-D tensor.

        Returns:
            The projected hidden states.
        """
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.ln_mid(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        """Maximum accepted length of the last axis of `input_features`.

        Computed as `max_source_positions` multiplied by the strides of the
        two convolution layers.
        """
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def get_attention_mask_by_audio_len(
        self, audio_lens: Optional[torch.Tensor], hidden_states: torch.Tensor
    ):
        """
        Create attention mask based on audio lengths to mask out padding tokens
        For each sample in batch:
        - Convert raw audio length to feature length after convolutions
        - Create bool mask: True for valid positions and False for padding
        - Convert to attention mask format expected by transformer layers
        (additive mask: 0 for positions to attend to, large negative for
        positions to ignore)
        This masking ensures consistent behavior between training and inference
        by preventing the model from attending to padding tokens in both cases

        Args:
            audio_lens: Per-sample audio lengths, or None to disable masking.
            hidden_states: Encoder hidden states; supplies the sequence
                length (dim 1), device and dtype of the mask.

        Returns:
            The extended attention mask, or None when `audio_lens` is None.
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        # (1, max_seq_len) positions compared against (batch, 1) lengths.
        attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
            None, :
        ].lt(audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: Optional[torch.Tensor] = None,
    ):
        """Encode input features into hidden states.

        Args:
            input_features: Input features whose last axis must not exceed
                `max_context_length`.
            audio_lens: Optional per-sample audio lengths used to build the
                attention mask; no mask is applied when None.

        Returns:
            The layer-normalized hidden states of the last encoder layer.

        Raises:
            ValueError: If the last axis of `input_features` is longer than
                `max_context_length`.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}."
            )

        inputs_embeds = F.gelu(self.conv1(input_features))
        inputs_embeds = F.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice the positional embeddings so shorter inputs are accepted.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = F.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


418
419
420
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: an audio tower and projector feeding a wrapped language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Renames checkpoint keys so the audio tower weights match this module tree.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for the given modality.

        Raises:
            ValueError: If `modality` does not start with "audio".
        """
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: UltravoxConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra weight sources, filled in when the config names separate
        # audio/text model ids.
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                )
            )
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.wrapped_model_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.text_model_id,
                    revision=None,
                    prefix="language_model.",
                )
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
        self, input_features: torch.Tensor, audio_lens: torch.Tensor
    ) -> torch.Tensor:
        """Run audio features through the audio tower and the projector.

        Args:
            input_features: Audio features, batched along dim 0.
            audio_lens: Per-item audio lengths, sliced in step with
                `input_features`.

        Returns:
            The projected embeddings of all batches concatenated along dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(
                audio_features[start:end], audio_lens[start:end]
            )
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(batch_features)
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Optional[UltravoxAudioInputs]:
        """Build the audio input schema from keyword arguments.

        Reads "audio_features", "audio_embeds", "audio_lens" and
        "audio_token_len" from `kwargs`. Features take precedence over
        embeddings when both are given.

        Returns:
            None when neither features nor embeddings are present, otherwise
            the matching input schema.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(
                type="audio_features",
                data=audio_features,
                lens=audio_lens,
                token_len=audio_token_len,
            )

        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
        """Turn a validated audio input into per-audio embeddings.

        Args:
            audio_input: Either embeddings (returned unchanged) or features
                (encoded, projected and trimmed to their token lengths).

        Returns:
            The given embeddings, or a tuple with one embedding tensor per
            input audio.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        # [B1, B2] -> [B1+B2]
        audio_lens = flatten_bn(audio_input["lens"], concat=True)
        audio_token_len = flatten_bn(audio_input["token_len"], concat=True)

        embeddings = self._audio_features_to_embeddings(audio_features, audio_lens)

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1
        )
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio
        embed_lens = [
            token_len_item.sum().item() for token_len_item in audio_input["token_len"]
        ]
        return flattened_embeddings.split(embed_lens)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return audio embeddings for the given inputs, or [] without audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Delegate to the parent implementation with the matching arguments.

        The text-only form is used whenever `multimodal_embeddings` or
        `is_multimodal` is None.
        """
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Returns:
            The output of the inner language model's `model` module.
        """

        # Input embeddings are dropped whenever intermediate tensors are given.
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Unwrap one level when the language model is itself a wrapper.
        language_model = self.language_model
        if hasattr(language_model, "language_model"):
            language_model = language_model.language_model

        hidden_states = language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute logits with the wrapped language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Unexpected weights under the "audio_tower." prefix are ignored.
        """
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
649
650
651


def pad_and_concat_to_dim3(
    features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]],
) -> torch.Tensor:
    """
    Pad and concatenate a list of tensors.

    Args:
        features: A tensor, or an arbitrarily nested list of tensors. A
            tensor with more than 3 dims has its leading dims flattened via
            `flatten_bn`; other tensors are returned unchanged.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
        return features

    # Resolve nested lists first so every element is a single tensor.
    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
    # Pad and concatenate:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)