ultravox.py 26.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
5
from collections.abc import Iterable, Mapping, Sequence
6
from functools import cached_property
7
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
8
9
10
11

import torch
from torch import nn
from torch.nn import functional as F
12
from transformers import BatchFeature, ProcessorMixin
13
14
15
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

16
from vllm import envs
17
from vllm.config import VllmConfig
18
from vllm.forward_context import get_forward_context
19
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
20
from vllm.model_executor.layers.layernorm import RMSNorm
Joe Runde's avatar
Joe Runde committed
21
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
22
from vllm.model_executor.model_loader.loader import DefaultModelLoader
23
from vllm.model_executor.models.module_mapping import MultiModelKeys
24
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
27
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
28
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
29
from vllm.multimodal.processing import (BaseMultiModalProcessor,
30
31
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
32
from vllm.multimodal.profiling import BaseDummyInputsBuilder
33
from vllm.sequence import IntermediateTensors
34
35
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

36
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
37
                         SupportsMultiModal, SupportsPP)
38
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
39
                    init_vllm_registered_model, maybe_prefix,
40
                    merge_multimodal_embeddings,
41
                    merge_multimodal_embeddings_from_map)
42

43
44
# Ultravox's HF processor uses '<|eot_id|>' as its audio placeholder, which
# collides with the real end-of-turn token, so vLLM substitutes this reserved
# special token (see UltravoxProcessingInfo.get_hf_processor).
_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
# Token id written into the prompt for every audio embedding position;
# presumably the id of `_AUDIO_PLACEHOLDER_OVERRIDE` in the tokenizer — verify.
_AUDIO_PLACEHOLDER_TOKEN = 128002
# Audio tokens produced per second of input audio (not referenced in this
# chunk of the file).
_AUDIO_TOKENS_PER_SECOND = 6.25
# Maximum number of audio chunks pushed through the encoder in one call.
_MAX_ENCODER_BATCH_SIZE = 16
47
48
49
50


class UltravoxAudioFeatureInputs(TypedDict):
    """Audio passed to the model as raw (chunked) mel features."""

    type: Literal["audio_features"]

    # Shape: `(batch_size, num_chunks, 80, M)`
    data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]

    # Length of the audio frames. Used for attention mask in WhisperEncoder.
    # Shape: `(batch_size, num_chunks)`
    lens: Union[torch.Tensor, list[torch.Tensor]]

    # Length of the audio tokens. Used for flattening the audio features.
    # Shape: `(batch_size, num_chunks)`
    token_len: Union[torch.Tensor, list[torch.Tensor]]
63
64
65
66


class UltravoxAudioEmbeddingInputs(TypedDict):
    """Audio passed to the model as precomputed embeddings."""

    type: Literal["audio_embeds"]

    # Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`
    data: NestedTensors
69
70
71
72
73
74


# Either accepted audio input form; the two are told apart by their "type" key.
UltravoxAudioInputs = Union[
    UltravoxAudioFeatureInputs,
    UltravoxAudioEmbeddingInputs,
]


75
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Accessors for the Ultravox HF processor and its feature extractor."""

    def get_hf_processor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
        **kwargs: object,
    ) -> ProcessorMixin:
        """Return the HF processor with its audio placeholder overridden.

        ``sampling_rate`` is accepted but not forwarded; all other keyword
        arguments are passed to ``self.ctx.get_hf_processor``.
        """
        processor = self.ctx.get_hf_processor(**kwargs)

        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        # placeholder that will cause confusion with the actual end of turn
        # token, thus we override placeholder with a reserved special
        # token.
        processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
        return processor

    def get_feature_extractor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor owned by the HF processor."""
        processor = self.get_hf_processor(sampling_rate=sampling_rate)
        extractor = processor.audio_processor.feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Audio is the only modality; ``None`` presumably means unbounded."""
        return {"audio": None}
108

109
110
111
112

class UltravoxDummyInputsBuilder(
        BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy prompts and audio clips for profiling Ultravox."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<|audio|>`` placeholder per requested audio item."""
        return "<|audio|>" * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy audio clips long enough to fill a full encoder batch.

        Each clip has ``chunk_length * sampling_rate * _MAX_ENCODER_BATCH_SIZE``
        samples, i.e. ``_MAX_ENCODER_BATCH_SIZE`` feature-extractor chunks.
        ``seq_len`` is not used.
        """
        extractor = self.info.get_feature_extractor()
        sr = extractor.sampling_rate
        clip_samples = (extractor.chunk_length * sr *
                        _MAX_ENCODER_BATCH_SIZE)
        audio_count = mm_counts.get("audio", 0)

        dummy_audios = self._get_dummy_audios(length=clip_samples,
                                              num_audios=audio_count)
        return {"audio": dummy_audios}


136
137
class UltravoxMultiModalProcessor(
        BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that expands ``<|audio|>`` into audio tokens."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Build a parser targeting the feature extractor's sampling rate."""
        target_sr = self.info.get_feature_extractor().sampling_rate
        return MultiModalDataParser(target_sr=target_sr)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor on ``prompt`` and any audios in ``mm_data``.

        Without audios, only the tokenizer is applied. With audios, the
        processor is asked for per-audio chunk counts, and its
        ``audio_values`` output is renamed to ``audio_features``.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            token_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False)
            token_ids = self._apply_hf_processor_tokens_only(token_ids)
            return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

        remaining = dict(mm_data)
        audios = remaining.pop("audios", [])
        assert isinstance(audios, list)

        sampling_rate = self.info.get_feature_extractor().sampling_rate
        processor_kwargs = dict(
            **mm_kwargs,
            sampling_rate=sampling_rate,
            include_audio_num_chunks=True,
        )

        result = super()._call_hf_processor(
            prompt=prompt,
            mm_data=dict(**remaining, audios=audios),
            mm_kwargs=processor_kwargs,
        )
        result['audio_features'] = result.pop('audio_values')
        return result

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processor outputs are grouped per audio item.

        Audio longer than one chunk (30s) is split into several chunks, so
        the per-chunk fields have a leading dimension larger than the number
        of audios; ``audio_num_chunks`` says how many chunks each audio owns.
        """
        num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0))
        per_chunk_fields = {
            field: MultiModalFieldConfig.flat_from_sizes("audio", num_chunks)
            for field in ("audio_features", "audio_token_len", "audio_lens")
        }
        return dict(
            **per_chunk_fields,
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``<|audio|>`` with that audio's placeholder tokens.

        The token count for audio ``i`` is the sum of ``audio_token_len``
        over the chunks belonging to that audio.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder_id = processor.audio_replacement_token_id  # type: ignore

        # Audio i owns chunks [chunk_offsets[i], chunk_offsets[i + 1]).
        num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0))
        chunk_offsets: torch.Tensor = torch.cat([
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(num_chunks, dim=0, dtype=torch.int32),
        ])

        def get_replacement_ultravox(item_idx: int):
            first = chunk_offsets[item_idx]
            stop = chunk_offsets[item_idx + 1]
            n_tokens = out_mm_kwargs["audio_token_len"][first:stop].sum()
            return [placeholder_id] * int(n_tokens)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    An input of shape ``(B, T, C)`` is right-padded with zero frames along
    ``T`` up to a multiple of ``stack_factor`` and then reshaped to
    ``(B, ceil(T / stack_factor), C * stack_factor)``. Each output frame is
    therefore the concatenation of ``stack_factor`` consecutive input frames.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        B, T, C = audio_embeds.shape
        # Zero frames needed to round T up to a multiple of stack_factor
        # (Python's % is non-negative for a positive divisor).
        pad = -T % self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, pad))
        # Use `reshape` rather than `view`: when no padding is needed, F.pad
        # may return a clone that keeps the input's (possibly non-contiguous)
        # strides, and `view` would then raise.
        return audio_embeds.reshape(B, (T + pad) // self.stack_factor,
                                    C * self.stack_factor)


class UltravoxProjector(nn.Module):
    """Map audio-encoder frames into the text model's embedding space.

    Pipeline: stack frames -> RMSNorm -> linear_1 -> activation -> ln_mid
    -> linear_2 -> ln_post. Exactly one of ``ln_mid`` / ``ln_post`` is an
    RMSNorm; the other is an identity (see ``projector_ln_mid``).
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)

        stacked_dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(stacked_dim)
        self.linear_1 = nn.Linear(stacked_dim, self.hidden_dim, bias=False)

        # With swiglu the activation output is half as wide as its input.
        use_swiglu = config.projector_act == "swiglu"
        self.act = (MulAndSilu()
                    if use_swiglu else get_act_fn(config.projector_act))
        mid_dim = self.hidden_dim // 2 if use_swiglu else self.hidden_dim

        text_dim = config.text_config.hidden_size
        self.linear_2 = nn.Linear(mid_dim, text_dim, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(mid_dim)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(text_dim)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project ``audio_features`` to text-model hidden size."""
        x = self.ln_pre(self._pad_and_stack(audio_features))
        x = self.ln_mid(self.act(self.linear_1(x)))
        return self.ln_post(self.linear_2(x))


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow inputs shorter than the full 30 seconds of audio padding:
        - the ValueError check only requires the `input_features` length to
          be less than or equal to `expected_seq_length` (not strictly equal)
        - the positional embeddings are sliced to the length of the
          convolved inputs
    3. padded positions can be masked out of attention via `audio_lens`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Presumably ensures HF's get_extended_attention_mask (used below)
        # builds a non-causal mask for this encoder — verify against HF.
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        """Longest accepted mel length: max positions times both conv strides."""
        total_stride = self.conv1.stride[0] * self.conv2.stride[0]
        return self.config.max_source_positions * total_stride

    def get_attention_mask_by_audio_len(self,
                                        audio_lens: Optional[torch.Tensor],
                                        hidden_states: torch.Tensor):
        """
        Build an attention mask that hides padded frames, or ``None``.

        ``audio_lens`` (raw mel lengths, one per sample) is converted to the
        post-convolution length; positions ``>=`` that length are masked.
        The boolean mask is then expanded by HF's
        ``get_extended_attention_mask``, which (per HF) yields 0.0 for
        positions to attend to and a large negative value for masked ones.
        Masking padding keeps inference consistent with training.
        """
        if audio_lens is None:
            return None

        valid_len = self._get_feat_extract_output_lengths(audio_lens)
        positions = torch.arange(hidden_states.shape[1],
                                 device=hidden_states.device)
        keep = positions[None, :] < valid_len.view(-1, 1)
        return self.get_extended_attention_mask(
            keep,
            None,
            dtype=hidden_states.dtype,
        )

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: Optional[torch.Tensor] = None,
    ):
        """Encode mel features; returns the final layer-normed hidden states.

        Raises:
            ValueError: if the last dim of ``input_features`` exceeds
                ``max_context_length``.
        """
        expected_seq_length = self.max_context_length
        if expected_seq_length < input_features.shape[-1]:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}.")

        embeds = nn.functional.gelu(self.conv1(input_features))
        embeds = nn.functional.gelu(self.conv2(embeds)).permute(0, 2, 1)
        # Slice positional embeddings to the (possibly shorter) input length.
        hidden_states = embeds + self.embed_positions.weight[:embeds.size(-2)]
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        attention_mask = self.get_attention_mask_by_audio_len(
            audio_lens, hidden_states)

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask,
                layer_head_mask=None,
            )[0]

        return self.layer_norm(hidden_states)


391
392
393
394
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: Whisper audio encoder + projector feeding a text LLM.

    Audio chunks are encoded by ``audio_tower``, mapped into the text
    embedding space by ``multi_modal_projector`` and merged into the
    ``language_model`` input embeddings at the audio placeholder tokens.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Checkpoint names `audio_tower.model.encoder.*` are loaded as
    # `audio_tower.*`, since the encoder is stored directly on audio_tower.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Additional weight sources (separate audio / text checkpoints);
        # presumably consumed by DefaultModelLoader — verify.
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                ))
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(model_or_path=config.text_model_id,
                                          revision=None,
                                          prefix="language_model."))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Reuse the language model's sampler when it defines one."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor,
            audio_lens: torch.Tensor) -> torch.Tensor:
        """Encode and project audio chunks into text-embedding space.

        Args:
            input_features: Mel features, one row per chunk (as produced by
                ``pad_and_concat_to_dim3``, e.g. ``[num_chunks, 80, M]``).
            audio_lens: Mel length of each chunk, used for attention masking.

        Returns:
            Projected embeddings concatenated over chunks along dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(audio_features[start:end],
                                              audio_lens[start:end])
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(batch_features)
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        """Extract and type-check audio inputs from the model kwargs.

        Returns:
            ``None`` if neither ``audio_features`` nor ``audio_embeds`` is
            given; otherwise the matching typed input dict.
            ``audio_features`` takes precedence when both are present.

        Raises:
            ValueError: if a provided field is neither a tensor nor a list.
                With ``audio_features`` this also fires when ``audio_lens``
                or ``audio_token_len`` is missing.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")
            # Each message reports the type of the field that failed.
            if not isinstance(audio_lens, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio_lens. "
                                 f"Got type: {type(audio_lens)}")
            if not isinstance(audio_token_len, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio_token_len. "
                                 f"Got type: {type(audio_token_len)}")

            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features,
                                              lens=audio_lens,
                                              token_len=audio_token_len)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
        """Turn validated audio input into one embedding tensor per audio.

        Precomputed embeddings are returned unchanged. Feature inputs are
        encoded chunk-wise; each chunk's embeddings are truncated to its
        ``token_len`` and then regrouped per audio.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        # [B1, B2] -> [B1+B2]
        audio_lens = flatten_bn(audio_input['lens'], concat=True)
        audio_token_len = flatten_bn(audio_input['token_len'], concat=True)

        embeddings = self._audio_features_to_embeddings(
            audio_features, audio_lens)

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1)
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio: each audio's size is the sum of
        # its chunks' token lengths.
        embed_lens = [
            token_len_item.sum().item()
            for token_len_item in audio_input['token_len']
        ]
        return flattened_embeddings.split(embed_lens)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text model."""
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return per-audio embeddings, or ``None`` if there is no audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return None
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice in audio embeddings if provided.

        On V1, embeddings replace positions whose id equals
        ``_AUDIO_PLACEHOLDER_TOKEN``. On V0, the placeholder index map from
        the forward context's attention metadata is used instead.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:

            # TODO(ywang96): remove this block after v0 is deprecated.
            if not envs.VLLM_USE_V1:
                attn_metadata = get_forward_context().attn_metadata
                # Return value is not used; presumably writes into
                # inputs_embeds in place.
                merge_multimodal_embeddings_from_map(
                    inputs_embeds, multimodal_embeddings,
                    attn_metadata.multi_modal_placeholder_index_maps["audio"])
            else:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, multimodal_embeddings,
                    _AUDIO_PLACEHOLDER_TOKEN)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            audio_features: A batch of audio input chunks [B, N, 80, M].
            audio_lens: Length of audio frames for each audio chunk [B].
            audio_token_len: Length of audio tokens for each audio chunk [B'].
                Note: batch dim is different from batch dim in audio chunks.

        """

        if intermediate_tensors is not None:
            # Non-first pipeline stage: hidden states arrive via
            # intermediate_tensors, so no embeddings are computed here.
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)

            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights, remapping names via ``hf_to_vllm_mapper``.

        Unexpected weights under ``audio_tower.`` are ignored rather than
        treated as errors.
        """
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
644
645
646


def pad_and_concat_to_dim3(
    features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
) -> torch.Tensor:
    """
    Pad and concatenate a (possibly nested) list of tensors along dim 0.

    A tensor argument is returned unchanged, except that one with more than
    three dims is flattened over its first two dims
    (``[B, N, 80, M] -> [B * N, 80, M]``). A list argument is processed
    recursively. Each resulting element is viewed as 3-D and right-padded
    with zeros along its last dim to the longest element, and the elements
    are then concatenated.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        return flatten_bn(features) if features.ndim > 3 else features

    parts = [pad_and_concat_to_dim3(item) for item in features]
    longest = max(part.shape[-1] for part in parts)
    # View each part as 3-D, pad its last dim, then concatenate:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    return torch.cat([
        F.pad(part.view(-1, *part.shape[-2:]), (0, longest - part.shape[-1]))
        for part in parts
    ])