ultravox.py 25.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6
from collections.abc import Iterable, Mapping, Sequence
7
from typing import Annotated, Any, Literal, Optional, Union
8
9
10
11

import torch
from torch import nn
from torch.nn import functional as F
12
from transformers import BatchFeature, ProcessorMixin
13
14
15
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

16
from vllm.config import VllmConfig
17
from vllm.config.multimodal import BaseDummyOptions
18
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
19
from vllm.model_executor.layers.layernorm import RMSNorm
20
from vllm.model_executor.model_loader import DefaultModelLoader
21
from vllm.model_executor.models.module_mapping import MultiModelKeys
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
24
                                    MultiModalKwargsItems, NestedTensors)
25
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
26
from vllm.multimodal.processing import (BaseMultiModalProcessor,
27
28
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
29
from vllm.multimodal.profiling import BaseDummyInputsBuilder
30
from vllm.sequence import IntermediateTensors
31
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
32
from vllm.utils.tensor_schema import TensorSchema, TensorShape
33

34
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
35
                         SupportsMultiModal, SupportsPP)
36
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
37
                    init_vllm_registered_model, maybe_prefix)
38

39
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
40
_MAX_ENCODER_BATCH_SIZE = 16
41
42


43
class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Mel-spectrogram audio features to be run through the audio tower.

    Dimensions:
    - b: batch size
    - n: number of chunks
    - t: Time frames (M)
    - nmb: Number of mel bins
    """
    type: Literal["audio_features"]
    data: Annotated[Union[torch.Tensor, list[torch.Tensor],
                          list[list[torch.Tensor]]],
                    TensorShape("b", "n", "nmb", "t", dynamic_dims={"n"})]
    lens: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                    TensorShape("b", "n", dynamic_dims={"n"})]
    """Length of the audio frames. Used for attention mask in WhisperEncoder."""
    token_len: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                         TensorShape("b", "n", dynamic_dims={"n"})]
    """Length of the audio tokens. Used for flattening the audio features."""


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings; these are returned as-is and skip the
    audio tower and projector.

    Dimensions:
    - b: batch size
    - na: number of audios
    - afs: audio feature size
    - hs: hidden size
    """
    type: Literal["audio_embeds"]
    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
                    TensorShape("b", "na", "afs", "hs")]


# Audio input accepted by the model: either raw features (encoded by the
# audio tower + projector) or precomputed embeddings (used directly).
UltravoxAudioInputs = Union[
    UltravoxAudioFeatureInputs,
    UltravoxAudioEmbeddingInputs,
]
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Access to the HF processor, feature extractor and item limits."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF Ultravox processor with a patched audio placeholder.

        The processor's ``audio_token_replacement`` is set to
        ``_AUDIO_PLACEHOLDER_OVERRIDE``. Its ``audio_replacement_token_id``
        is set to the model config's ``audio_token_index``.
        """
        config = self.ctx.model_config.hf_config
        hf_processor = self.ctx.get_hf_processor(**kwargs)

        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        # placeholder that will cause confusion with the actual end of turn
        # token, thus we override placeholder with a reserved token.
        hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        hf_processor.audio_replacement_token_id = config.audio_token_index

        return hf_processor

    def get_feature_extractor(self,
                              **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor wrapped by the HF processor.

        Raises:
            AssertionError: if the wrapped extractor is not a
                ``WhisperFeatureExtractor``.
        """
        hf_processor = self.get_hf_processor(**kwargs)
        audio_processor = hf_processor.audio_processor  # type: ignore
        feature_extractor = audio_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None`: no explicit per-prompt cap on audio items (presumably
        # interpreted as "unlimited" by vLLM -- confirm).
        return {"audio": None}

class UltravoxDummyInputsBuilder(
        BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy prompts and audio used by vLLM's multimodal profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one audio placeholder per requested dummy audio."""
        num_audios = mm_counts.get("audio", 0)

        return "<|audio|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy audio clips for profiling.

        Each clip spans ``_MAX_ENCODER_BATCH_SIZE`` feature-extractor chunks
        (``chunk_length * sampling_rate`` samples per chunk). This exercises
        a full encoder batch per audio.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = (feature_extractor.chunk_length * sampling_rate *
                     _MAX_ENCODER_BATCH_SIZE)
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio") if mm_options else None

        return {
            "audio":
            self._get_dummy_audios(length=audio_len,
                                   num_audios=num_audios,
                                   overrides=audio_overrides)
        }
class UltravoxMultiModalProcessor(
        BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Runs the HF Ultravox processor and expands audio placeholders."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # Incoming audio is resampled to the feature extractor's rate.
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor on ``prompt`` and any ``audios`` in mm_data.

        Without audio, only the tokenizer is used. With audio, the HF
        processor output's ``audio_values`` key is renamed to
        ``audio_features``.

        ``mm_data``, ``mm_kwargs`` and ``tok_kwargs`` are not mutated.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # Dict unpacking (not `dict(**mm_kwargs, key=...)`) so that
        # caller-supplied values for these keys are overridden instead of
        # raising "got multiple values for keyword argument".
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
            "include_audio_num_chunks": True,
        }

        item_processor_data = dict(**mm_data, audios=audios)

        # Some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Build a filtered copy rather than popping from the caller's mapping
        # (which may be shared across calls or read-only).
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in ("padding", "truncation")
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output['audio_features'] = output.pop('audio_values')

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0))
        return dict(
            # To handle audio longer than 30s, each audio may be split into
            # multiple chunks, so the leading (chunk) dimension of these
            # fields can exceed the number of audio items.
            audio_features=MultiModalFieldConfig.flat_from_sizes(
                "audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes(
                "audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes(
                "audio", num_chunks),
            # num_chunks maps the chunk dimension back to the audio dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each '<|audio|>' with one replacement token per audio
        token, summed over all chunks of that audio."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks,
                                                      dim=0,
                                                      dtype=torch.int32)
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx])

        def get_replacement_ultravox(item_idx: int):
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]


class StackAudioFrames(nn.Module):
    """
    Shorten an audio embedding sequence by concatenating every
    `stack_factor` consecutive frames along the channel dimension.

    The time axis is first zero-padded up to the next multiple of
    `stack_factor`. An input of shape (B, T, C) therefore becomes
    (B, ceil(T / stack_factor), C * stack_factor).
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        factor = self.stack_factor
        batch, frames, channels = audio_embeds.shape
        # Ceiling division: a trailing partial group is kept (zero-padded).
        groups = -(-frames // factor)
        missing = groups * factor - frames
        padded = F.pad(audio_embeds, (0, 0, 0, missing))
        return padded.view(batch, groups, channels * factor)


class UltravoxProjector(nn.Module):
    """
    Project audio-tower outputs to the language model's hidden size.

    Pipeline: stack frames -> ln_pre (RMSNorm) -> linear_1 -> act -> ln_mid
    -> linear_2 -> ln_post. Exactly one of ln_mid / ln_post is an RMSNorm;
    the other is an identity, selected by ``config.projector_ln_mid``.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        # Stacking multiplies the per-frame width by stack_factor.
        dim_in = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
        dim_mid = self.hidden_dim

        if config.projector_act == "swiglu":
            # The gated activation combines pairs of features, so its output
            # (and linear_2's input) is half as wide.
            self.act = MulAndSilu()
            dim_mid = dim_mid // 2
        else:
            self.act = get_act_fn(config.projector_act)

        dim_out = config.text_config.hidden_size
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(dim_mid)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(dim_out)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Map (B, T, audio_hidden) features to (B, ceil(T/stack), text_hidden).

        Shapes assume a 3-D input, as required by StackAudioFrames.
        """
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.ln_mid(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        """Maximum mel-frame length: positions times both conv strides."""
        return (self.config.max_source_positions * self.conv1.stride[0] *
                self.conv2.stride[0])

    def get_attention_mask_by_audio_len(self,
                                        audio_lens: Optional[torch.Tensor],
                                        hidden_states: torch.Tensor):
        """
        Create attention mask based on audio lengths to mask out padding tokens
        For each sample in batch:
        - Convert raw audio length to feature length after convolutions
        - Create bool mask: True for valid positions and False for padding
        - Convert to the additive mask format expected by transformer layers
          via `get_extended_attention_mask` (0 for positions to attend to,
          a large negative value for positions to ignore)
        This masking ensures consistent behavior between training and inference
        by preventing the model from attending to padding tokens in both cases

        Returns None when `audio_lens` is None (no masking).
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        attention_mask = torch.arange(max_seq_len,
                                      device=hidden_states.device)[None, :].lt(
                                          audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: Optional[torch.Tensor] = None,
    ):
        """Encode mel features and return the final-layer-normed states.

        Raises:
            ValueError: if the last dimension of `input_features` exceeds
                `max_context_length`.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}.")

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positional embeddings to the (possibly shorter) input length.
        embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        attention_mask = self.get_attention_mask_by_audio_len(
            audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: Whisper-based audio tower + projector feeding a text LLM."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Checkpoints store the encoder under `audio_tower.model.encoder.`;
    # remap it onto this module's `audio_tower.` attribute.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: UltravoxConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra weight sources (separate audio / text checkpoints) with the
        # prefix their weights are loaded under; presumably consumed by the
        # model loader -- confirm.
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                ))
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.wrapped_model_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(model_or_path=config.text_model_id,
                                          revision=None,
                                          prefix="language_model."))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor,
            audio_lens: torch.Tensor) -> torch.Tensor:
        """Encode chunks with the audio tower and projector.

        Chunks are processed in slices of at most `_MAX_ENCODER_BATCH_SIZE`
        along dim 0, and the projected outputs are concatenated on dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(audio_features[start:end],
                                              audio_lens[start:end])
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(batch_features)
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        """Build typed audio inputs from kwargs; None if no audio given.

        `audio_features` takes precedence over `audio_embeds`.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features,
                                              lens=audio_lens,
                                              token_len=audio_token_len)

        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
        """Return precomputed embeddings as-is, or encode features into one
        tensor per input audio (all of its chunks concatenated)."""
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        # [B1, B2] -> [B1+B2]
        audio_lens = flatten_bn(audio_input['lens'], concat=True)
        audio_token_len = flatten_bn(audio_input['token_len'], concat=True)

        embeddings = self._audio_features_to_embeddings(
            audio_features, audio_lens)

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1)
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio
        embed_lens = [
            token_len_item.sum().item()
            for token_len_item in audio_input['token_len']
        ]
        return flattened_embeddings.split(embed_lens)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return audio embeddings, or [] if no audio inputs are present."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` are consistent with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When given, `inputs_embeds` is ignored.
            inputs_embeds: Optional tensor of input embeddings.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The output of the language model's backbone `.model(...)` (hidden
            states or intermediate tensors, not logits).
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Unwrap nested wrappers that expose an inner `language_model`.
        language_model = self.language_model
        if hasattr(language_model, "language_model"):
            language_model = language_model.language_model

        hidden_states = language_model.model(input_ids,
                                             positions,
                                             intermediate_tensors,
                                             inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        Unexpected `audio_tower.` weights are ignored rather than reported.
        """
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def pad_and_concat_to_dim3(
    features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
) -> torch.Tensor:
    """
    Pad (along the last dimension) and concatenate a list of tensors.

    Args:
        features: A tensor, or a possibly nested list of tensors whose last
            dimensions may differ. A tensor with more than 3 dimensions has
            its two leading dimensions merged.

    Returns:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors. A bare tensor input is
        returned unchanged unless it has more than 3 dimensions.

    Raises:
        ValueError: if `features` (or a nested list in it) is empty.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
        return features

    if len(features) == 0:
        raise ValueError(
            "pad_and_concat_to_dim3 received an empty list of features")

    # Reduce nested lists to tensors, then ensure every tensor is 3-D
    # ([C, M] -> [1, C, M]). `reshape` also accepts non-contiguous tensors.
    tensors = [pad_and_concat_to_dim3(f) for f in features]
    tensors = [t.reshape(-1, *t.shape[-2:]) for t in tensors]

    max_len = max(t.shape[-1] for t in tensors)
    # Pad and concatenate:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    return torch.cat([F.pad(t, (0, max_len - t.shape[-1])) for t in tensors])