ultravox.py 28.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6

7
import copy
8
import inspect
9
from collections.abc import Iterable, Mapping, Sequence
10
from types import SimpleNamespace
11
from typing import Annotated, Any, Literal, TypeAlias
12
13
14
15

import torch
from torch import nn
from torch.nn import functional as F
16
from transformers import BatchFeature, ProcessorMixin
17
from transformers.modeling_utils import ModuleUtilsMixin
18
from transformers.models.whisper import WhisperFeatureExtractor
19
20
21
22
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
)
23

24
from vllm.config import VllmConfig
25
from vllm.config.multimodal import BaseDummyOptions
26
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
27
from vllm.model_executor.layers.layernorm import RMSNorm
28
from vllm.model_executor.model_loader import DefaultModelLoader
29
from vllm.model_executor.models.module_mapping import MultiModelKeys
30
from vllm.multimodal import MULTIMODAL_REGISTRY
31
32
33
34
35
36
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
37
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
38
from vllm.multimodal.processing import (
39
    BaseDummyInputsBuilder,
40
41
42
43
44
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
45
from vllm.sequence import IntermediateTensors
46
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
47
from vllm.utils.tensor_schema import TensorSchema, TensorShape
48

49
50
51
52
53
54
55
56
57
58
59
60
61
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
62

63
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
64
_MAX_ENCODER_BATCH_SIZE = 16
65
66


67
class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Mel-feature audio inputs, already split into chunks.

    Audio that does not fit in one feature-extractor window is split into
    several chunks, so the leading dimension of ``data``, ``lens`` and
    ``token_len`` counts chunks across all audios, not audios.

    Dimensions:
        - bn: total number of chunks across all audios in the batch
        - n: number of audios (``num_chunks`` has one entry per audio)
        - t: Time frames (M)
        - nmb: Number of mel bins
    """

    type: Literal["audio_features"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
        TensorShape("bn", "nmb", "t"),
    ]
    lens: Annotated[torch.Tensor, TensorShape("bn")]
    """
    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
    """
    token_len: Annotated[torch.Tensor, TensorShape("bn")]
    """Length of the audio tokens per chunk. Used for flattening the audio features."""
    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
    """Number of chunks per audio (sums to ``bn``). Used to regroup chunks per audio."""
89
90
91


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings passed in directly by the caller; they are
    used as-is instead of being run through the audio tower and projector.

    Dimensions:
        - b: batch size
        - na: number of audios
        - afs: audio feature size
        - hs: hidden size
    """

    type: Literal["audio_embeds"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("b", "na", "afs", "hs"),
    ]
104
105


106
107
108
# Validated audio input: either mel features to encode or ready embeddings.
UltravoxAudioInputs: TypeAlias = UltravoxAudioFeatureInputs | UltravoxAudioEmbeddingInputs
109
110


111
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Gives access to the Ultravox HF processor and its audio settings."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF processor with the audio placeholder overridden.

        The Ultravox processing definition uses '<|eot_id|>' as its audio
        placeholder, which is easy to confuse with the genuine end-of-turn
        token, so the placeholder is replaced by a reserved token and the
        replacement id is taken from ``hf_config.audio_token_index``.
        """
        hf_config = self.ctx.model_config.hf_config
        processor = self.ctx.get_hf_processor(**kwargs)

        processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        processor.audio_replacement_token_id = hf_config.audio_token_index
        return processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor used by the HF processor.

        Depending on the model revision (see
        https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/commit/9a3c571b8fdaf1e66dd3ea61bbcb6db5c70a438e),
        ``audio_processor`` is either the extractor itself or an object
        exposing it as ``feature_extractor``.
        """
        processor = self.get_hf_processor(**kwargs)
        extractor = processor.audio_processor  # type: ignore
        if not isinstance(extractor, WhisperFeatureExtractor):
            extractor = extractor.feature_extractor  # type: ignore
            assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_data_parser(self):
        """Build a data parser targeting the extractor's rate and mono audio."""
        return MultiModalDataParser(
            target_sr=self.get_feature_extractor().sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_target_channels(self) -> int:
        """Return target audio channels for Ultravox models (mono)."""
        return 1

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed cap on the number of audios per prompt.
        return dict(audio=None)
151

152

153
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds placeholder prompt text and synthetic audio (dummy inputs)."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<|audio|>`` placeholder per requested audio."""
        return "".join("<|audio|>" for _ in range(mm_counts.get("audio", 0)))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create the dummy audios.

        Each dummy audio is ``_MAX_ENCODER_BATCH_SIZE`` feature-extractor
        windows long (``chunk_length * sampling_rate`` samples per window),
        presumably so that one audio yields a full encoder batch of chunks.
        """
        extractor = self.info.get_feature_extractor()
        samples_per_window = extractor.chunk_length * extractor.sampling_rate
        total_samples = samples_per_window * _MAX_ENCODER_BATCH_SIZE
        audio_count = mm_counts.get("audio", 0)
        overrides = None if not mm_options else mm_options.get("audio")

        dummy_audios = self._get_dummy_audios(
            length=total_samples, num_audios=audio_count, overrides=overrides
        )
        return {"audio": dummy_audios}


182
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that expands ``<|audio|>`` placeholders for Ultravox."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF Ultravox processor on ``prompt`` and ``mm_data``.

        Text-only prompts bypass the composite HF processor and are tokenized
        directly. Otherwise the audios are processed at the feature
        extractor's sampling rate (with per-audio chunk counts requested) and
        the processor's ``audio_values`` output is renamed to
        ``audio_features``. None of the input mappings are mutated.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # Build a fresh dict so these keys override any caller-provided values.
        # `dict(**mm_kwargs, sampling_rate=...)` raises TypeError ("got multiple
        # values for keyword argument") if mm_kwargs already has the key.
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
            "include_audio_num_chunks": True,
        }

        item_processor_data = dict(**mm_data, audios=audios)

        # Some tokenizer kwargs are incompatible with UltravoxProcessor. Filter
        # them into a new dict instead of popping from the caller's mapping.
        tok_kwargs = {
            k: v for k, v in tok_kwargs.items() if k not in ("padding", "truncation")
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output["audio_features"] = output.pop("audio_values")

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps back to audio items."""
        num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0))
        return dict(
            # to handle longer than 30s audio, each audio might be split
            # into multiple chunks as such, their batch dimension can be
            # higher than the number of audio samples
            audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``<|audio|>`` with that audio's replacement tokens.

        The number of replacement tokens for an audio is the sum of
        ``audio_token_len`` over all of that audio's chunks.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(
            num_chunks, dim=0, dtype=torch.int32
        )
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx]
        )

        def get_replacement_ultravox(item_idx: int):
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]
278
279
280
281
282
283
284
285
286
287
288
289
290
291


class StackAudioFrames(nn.Module):
    """
    Shorten the time axis of audio embeddings by concatenating every
    ``stack_factor`` consecutive frames along the channel axis.

    The time axis is zero-padded at the end up to the next multiple of
    ``stack_factor``, so an input of shape ``(B, T, C)`` becomes
    ``(B, ceil(T / stack_factor), C * stack_factor)``.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        factor = self.stack_factor
        batch, frames, channels = audio_embeds.shape
        # Ceil-divide: number of stacked frames after padding.
        groups = (frames + factor - 1) // factor
        missing = groups * factor - frames
        # Pad only the time dimension (dim 1), on the right.
        padded = F.pad(audio_embeds, (0, 0, 0, missing))
        return padded.view(batch, groups, channels * factor)


301
class UltravoxFeedForwardProjector(nn.Module):
    """
    MLP projector mapping stacked audio-encoder frames to the text model's
    hidden size: stack -> RMSNorm -> linear -> activation -> (mid norm) ->
    linear -> (post norm).
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)

        stacked_dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(stacked_dim)
        self.linear_1 = nn.Linear(stacked_dim, self.hidden_dim, bias=False)

        # With "swiglu" the activation output is half as wide as its input,
        # so the second linear layer takes hidden_dim // 2 features.
        use_swiglu = config.projector_act == "swiglu"
        self.act = MulAndSilu() if use_swiglu else get_act_fn(config.projector_act)
        mid_dim = self.hidden_dim // 2 if use_swiglu else self.hidden_dim

        out_dim = config.text_config.hidden_size
        self.linear_2 = nn.Linear(mid_dim, out_dim, bias=False)

        # Ultravox v0.4.1 and below normalize after the second linear layer,
        # whereas v0.5.0 and above normalize after the first one.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(mid_dim)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(out_dim)

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        # audio_token_len is accepted for interface parity with the
        # transformer projector; this projector does not use it.
        x = self.ln_pre(self._pad_and_stack(audio_features))
        x = self.ln_mid(self.act(self.linear_1(x)))
        return self.ln_post(self.linear_2(x))


342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
    """
    Transformer-based projector: stacks audio-encoder frames, runs them
    through ``config.num_projector_layers`` Whisper encoder layers with
    learned positions, and maps the result to the text model's hidden size.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        # Minimal config for the ModuleUtilsMixin helpers, which presumably
        # consult `config.is_decoder` when building attention masks.
        self.config = SimpleNamespace(is_decoder=False)

        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim_in = config.audio_config.hidden_size * config.stack_factor

        projector_audio_config = copy.deepcopy(config.audio_config)

        self.ln_pre = RMSNorm(dim_in)
        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)

        self.embed_positions = nn.Embedding(
            projector_audio_config.max_source_positions,
            projector_audio_config.d_model,
        )

        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(projector_audio_config)
                for _ in range(config.num_projector_layers)
            ]
        )

        self.ln_post = RMSNorm(projector_audio_config.d_model)
        self.linear_out = nn.Linear(
            projector_audio_config.d_model, config.text_config.hidden_size
        )

        # Backward compatibility for Transformers v4 where layer_head_mask
        # was a required argument for WhisperEncoderLayer.forward. The layer
        # class is fixed after construction, so inspect its signature once
        # here instead of on every forward pass.
        self._layer_extra_kwargs: dict[str, Any] = {}
        if len(self.layers) > 0 and (
            "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters
        ):
            self._layer_extra_kwargs["layer_head_mask"] = None

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project encoder frames to text-model embeddings.

        Args:
            audio_features: Audio-encoder output, stacked here by
                ``stack_factor`` along time (assumed ``(B, T, C)``).
            audio_token_len: Number of valid stacked frames per row; later
                frames are masked out of attention.

        Returns:
            Tensor of shape ``(B, ceil(T / stack_factor), text_hidden_size)``.
        """
        audio_features = self._pad_and_stack(audio_features)

        # True for stacked frames inside each row's valid token length.
        max_len_stacked = audio_features.shape[1]
        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
            None, :
        ].lt(audio_token_len[:, None])
        # Pass dtype by keyword: in Transformers v4 the third positional
        # parameter of get_extended_attention_mask is the deprecated `device`.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, attention_mask.shape, dtype=audio_features.dtype
        )

        hidden_states = self.ln_pre(audio_features)
        hidden_states = self.linear_in(hidden_states)

        positions = self.embed_positions(
            torch.arange(hidden_states.size(1), device=hidden_states.device)
        )
        hidden_states = hidden_states + positions

        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                **self._layer_extra_kwargs,
            )
            # Accept either a tuple (hidden states first) or a bare tensor.
            hidden_states = (
                layer_outputs
                if isinstance(layer_outputs, torch.Tensor)
                else layer_outputs[0]
            )

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 seconds of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`
    3. padding frames can be masked out of attention via `audio_lens`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False
        # Backward compatibility for Transformers v4 where layer_head_mask
        # was a required argument for WhisperEncoderLayer.forward. Inspect
        # the layer signature once here instead of on every forward pass.
        self._layer_extra_kwargs: dict[str, Any] = {}
        if len(self.layers) > 0 and (
            "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters
        ):
            self._layer_extra_kwargs["layer_head_mask"] = None

    @property
    def max_context_length(self):
        """Maximum number of input mel frames the encoder accepts."""
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def get_attention_mask_by_audio_len(
        self, audio_lens: torch.Tensor | None, hidden_states: torch.Tensor
    ):
        """
        Create an attention mask that hides padding frames.

        For each sample in the batch:
        - Convert its mel-frame length to the feature length after the
          convolutions (``_get_feat_extract_output_lengths``)
        - Create a bool mask: True for valid positions, False for padding
        - Convert it to the additive mask format expected by the transformer
          layers (0 for positions to attend to, a large negative value for
          positions to ignore)
        This keeps inference consistent with training by preventing the model
        from attending to padding. Returns None when ``audio_lens`` is None.
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
            None, :
        ].lt(audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor | None = None,
    ):
        """Encode mel features into hidden states.

        Args:
            input_features: Mel features whose last dimension is the number
                of frames (presumably ``(B, n_mels, frames)``); may be shorter
                than ``max_context_length``.
            audio_lens: Optional per-sample mel-frame lengths used to mask
                padding frames. No mask is applied when None.

        Raises:
            ValueError: If the input has more than ``max_context_length``
                frames.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}."
            )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positions to the actual (possibly shorter) sequence length.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                **self._layer_extra_kwargs,
            )
            # Accept either a tuple (hidden states first) or a bare tensor.
            hidden_states = (
                layer_outputs
                if isinstance(layer_outputs, torch.Tensor)
                else layer_outputs[0]
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


517
518
519
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: Whisper-based audio tower + projector feeding a
    vLLM-registered language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality`` (audio only)."""
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: UltravoxConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra weight sources, registered when the config names a separate
        # audio model id and/or text model id.
        self.secondary_weights = []
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                )
            )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.text_model_id,
                    revision=None,
                    prefix="language_model.",
                )
            )

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
            if config.num_projector_layers > 0:
                self.multi_modal_projector = UltravoxTransformerProjector(config)
            else:
                self.multi_modal_projector = UltravoxFeedForwardProjector(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.wrapped_model_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor,
        audio_token_len: torch.Tensor,
    ) -> torch.Tensor:
        """Encode chunked mel features and project them to text embeddings.

        Chunks are processed ``_MAX_ENCODER_BATCH_SIZE`` at a time and the
        per-batch projector outputs are concatenated along dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(
                audio_features[start:end], audio_lens[start:end]
            )
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(
                batch_features, audio_token_len[start:end]
            )
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> UltravoxAudioInputs | None:
        """Build a validated audio input from ``kwargs``, or None if absent.

        Mel features take precedence over precomputed embeddings.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)
        audio_num_chunks = kwargs.pop("audio_num_chunks", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(
                type="audio_features",
                data=audio_features,
                lens=audio_lens,
                token_len=audio_token_len,
                num_chunks=audio_num_chunks,
            )

        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> NestedTensors | tuple[torch.Tensor, ...]:
        """Turn validated audio input into one embedding tensor per audio.

        Precomputed embeddings are returned unchanged. Mel features are
        encoded chunk by chunk, trimmed to each chunk's ``token_len`` and
        regrouped per audio using ``num_chunks``.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        audio_lens = audio_input["lens"]
        audio_token_len = audio_input["token_len"]

        embeddings = self._audio_features_to_embeddings(
            audio_features, audio_lens, audio_token_len
        )

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1
        )
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio. Copy the per-chunk token lengths
        # to the host once instead of calling `.item()` (a device sync) for
        # every audio.
        chunk_token_lens = audio_token_len.tolist()
        embed_lens: list[int] = []
        offset = 0
        for n_chunks in audio_input["num_chunks"].tolist():
            embed_lens.append(sum(chunk_token_lens[offset : offset + n_chunks]))
            offset += n_chunks
        return flattened_embeddings.split(embed_lens)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or an empty list if no audio given."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio
        (the exact rate depends on the configured stack factor).

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        """
        # On non-first pipeline stages the input comes from
        # intermediate_tensors, so embeddings are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        language_model = self.language_model
        if hasattr(language_model, "language_model"):
            language_model = language_model.language_model

        hidden_states = language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights; unexpected `audio_tower.` keys are ignored."""
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
761
762
763


def pad_and_concat_to_dim3(
764
    features: torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
765
766
767
768
769
770
771
772
773
774
775
776
777
) -> torch.Tensor:
    """
    Pad and concatenate a list of tensors.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
778

779
780
781
782
783
784
785
        return features

    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
786
    # Pad and concatenate:
787
788
789
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)