ultravox.py 28.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6

7
import copy
8
import inspect
9
from collections.abc import Iterable, Mapping, Sequence
10
from types import SimpleNamespace
11
from typing import Annotated, Any, Literal, TypeAlias
12
13
14
15

import torch
from torch import nn
from torch.nn import functional as F
16
from transformers import BatchFeature, ProcessorMixin
17
from transformers.modeling_utils import ModuleUtilsMixin
18
from transformers.models.whisper import WhisperFeatureExtractor
19
20
21
22
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
)
23

24
from vllm.config import VllmConfig
25
from vllm.config.multimodal import BaseDummyOptions
26
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
27
from vllm.model_executor.layers.layernorm import RMSNorm
28
from vllm.model_executor.model_loader import DefaultModelLoader
29
from vllm.model_executor.models.module_mapping import MultiModelKeys
30
from vllm.multimodal import MULTIMODAL_REGISTRY
31
32
33
34
35
36
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
37
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
38
39
40
41
42
43
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
44
from vllm.multimodal.profiling import BaseDummyInputsBuilder
45
from vllm.sequence import IntermediateTensors
46
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
47
from vllm.utils.tensor_schema import TensorSchema, TensorShape
48

49
50
51
52
53
54
55
56
57
58
59
60
61
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
62

63
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
64
_MAX_ENCODER_BATCH_SIZE = 16
65
66


67
class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Mel-spectrogram audio inputs. Each audio may be split into several chunks,
    so the per-chunk fields are indexed by the total chunk count.

    Dimensions:
        - bn: Total number of chunks across all audios in the batch
        - n: Number of audios (`num_chunks` holds one entry per audio)
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    type: Literal["audio_features"]

    # Mel features for every chunk.
    data: Annotated[
        torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
        TensorShape("bn", "nmb", "t"),
    ]

    # Length of the audio frames per chunk; used to build the attention mask
    # in the Whisper encoder.
    lens: Annotated[torch.Tensor, TensorShape("bn")]

    # Length of the audio tokens per chunk; used to drop padded positions
    # when flattening the projected audio embeddings.
    token_len: Annotated[torch.Tensor, TensorShape("bn")]

    # Number of chunks per audio; used to regroup the flattened embeddings
    # into one tensor per audio.
    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
89
90
91


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings; these are returned as-is and skip the
    audio tower and projector.

    Dimensions:
        - b: Batch size
        - na: Number of audios
        - afs: Audio feature size
        - hs: Hidden size
    """

    type: Literal["audio_embeds"]

    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("b", "na", "afs", "hs"),
    ]
104
105


106
107
108
# Audio reaches the model either as mel features (encoded by the audio tower
# and projector) or as precomputed embeddings; the "type" field tells which.
UltravoxAudioInputs: TypeAlias = (
    UltravoxAudioFeatureInputs | UltravoxAudioEmbeddingInputs
)
109
110


111
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Access to the Ultravox HF processor, its feature extractor and limits."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF processor with its audio placeholder overridden.

        The stock Ultravox processor uses '<|eot_id|>' as the audio
        placeholder, which is easily confused with the real end-of-turn
        token. It is therefore swapped for a reserved token, and the
        replacement id is taken from the model config's `audio_token_index`.
        """
        hf_config = self.ctx.model_config.hf_config
        processor = self.ctx.get_hf_processor(**kwargs)

        processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        processor.audio_replacement_token_id = hf_config.audio_token_index

        return processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the WhisperFeatureExtractor used by the HF processor.

        The processor layout changed in
        https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/commit/9a3c571b8fdaf1e66dd3ea61bbcb6db5c70a438e
        so `audio_processor` is either the feature extractor itself or an
        object exposing it through `.feature_extractor`.
        """
        processor = self.get_hf_processor(**kwargs)
        extractor = processor.audio_processor  # type: ignore

        if not isinstance(extractor, WhisperFeatureExtractor):
            extractor = extractor.feature_extractor  # type: ignore
            assert isinstance(extractor, WhisperFeatureExtractor)

        return extractor

    def get_target_channels(self) -> int:
        """Return target audio channels for Ultravox models (mono)."""
        return 1

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No fixed cap on the number of audio items per prompt.
        return {"audio": None}
142

143

144
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Build placeholder prompts and synthetic audio for Ultravox dummy inputs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one `<|audio|>` placeholder per requested audio item."""
        placeholder_count = mm_counts.get("audio", 0)
        return "<|audio|>" * placeholder_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return `mm_counts["audio"]` dummy audios.

        Each dummy audio spans `_MAX_ENCODER_BATCH_SIZE` full feature-extractor
        chunks. `seq_len` is not used here.
        """
        extractor = self.info.get_feature_extractor()
        sampling_rate = extractor.sampling_rate

        samples_per_chunk = extractor.chunk_length * sampling_rate
        dummy_audio_len = samples_per_chunk * _MAX_ENCODER_BATCH_SIZE

        audio_overrides = None if not mm_options else mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=dummy_audio_len,
                num_audios=mm_counts.get("audio", 0),
                overrides=audio_overrides,
            )
        }


173
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that expands each `<|audio|>` placeholder into one
    replacement token per audio embedding."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Parse audio using the feature extractor's sampling rate and the
        model's channel count as targets."""
        feature_extractor = self.info.get_feature_extractor()

        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.info.get_target_channels(),
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF Ultravox processor on `prompt` and its audios.

        Text-only prompts are tokenized directly. For prompts with audio, the
        processor is asked to report per-audio chunk counts, and its
        `audio_values` output is renamed to `audio_features`.

        None of the input mappings are mutated.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # A dict display lets these keys override caller-supplied ones.
        # `dict(**mm_kwargs, sampling_rate=...)` would raise TypeError on a
        # duplicate key.
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
            "include_audio_num_chunks": True,
        }

        item_processor_data = dict(**mm_data, audios=audios)

        # Some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Filter them into a new dict instead of popping from the caller's
        # mapping, which may be shared or immutable.
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in ("padding", "truncation")
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output["audio_features"] = output.pop("audio_values")

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0))
        return dict(
            # to handle longer than 30s audio, each audio might be split
            # into multiple chunks as such, their batch dimension can be
            # higher than the number of audio samples
            audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(
            num_chunks, dim=0, dtype=torch.int32
        )
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx]
        )

        def get_replacement_ultravox(item_idx: int):
            # One replacement token per audio token, summed over all chunks
            # of this audio.
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]
276
277
278
279
280
281
282
283
284
285
286
287
288
289


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    The time axis is zero-padded up to a multiple of `stack_factor`. Each group
    of `stack_factor` consecutive frames is then concatenated along the
    channel axis:
    [B, T, C] -> [B, ceil(T / stack_factor), C * stack_factor].
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        B, T, C = audio_embeds.shape
        # Number of zero frames needed to reach a multiple of stack_factor.
        pad = -T % self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, pad))
        # Use reshape, not view. When no padding is needed, F.pad may return a
        # clone that keeps a non-contiguous input's strides, which view rejects.
        return audio_embeds.reshape(
            B, (T + pad) // self.stack_factor, C * self.stack_factor
        )


299
class UltravoxFeedForwardProjector(nn.Module):
    """Map stacked audio-encoder frames to the text model's hidden size.

    Pipeline: stack frames -> RMSNorm -> Linear -> activation -> Linear, with
    one extra RMSNorm placed either after the activation or after the final
    Linear depending on `config.projector_ln_mid`.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        stack_factor = config.stack_factor
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(stack_factor)

        in_features = config.audio_config.hidden_size * stack_factor
        self.ln_pre = RMSNorm(in_features)
        self.linear_1 = nn.Linear(in_features, self.hidden_dim, bias=False)

        # MulAndSilu halves the feature dimension, so with swiglu the second
        # linear layer takes hidden_dim // 2 inputs.
        uses_swiglu = config.projector_act == "swiglu"
        self.act = (
            MulAndSilu() if uses_swiglu else get_act_fn(config.projector_act)
        )
        mid_features = self.hidden_dim // 2 if uses_swiglu else self.hidden_dim

        out_features = config.text_config.hidden_size
        self.linear_2 = nn.Linear(mid_features, out_features, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if not config.projector_ln_mid:
            self.ln_mid: nn.Module = nn.Identity()
            self.ln_post = RMSNorm(out_features)
        else:
            self.ln_mid = RMSNorm(mid_features)
            self.ln_post = nn.Identity()

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        # `audio_token_len` is unused here. It is accepted so this projector
        # can be called the same way as UltravoxTransformerProjector.
        stacked = self.ln_pre(self._pad_and_stack(audio_features))
        mid = self.ln_mid(self.act(self.linear_1(stacked)))
        return self.ln_post(self.linear_2(mid))


340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
    """Map stacked audio-encoder frames to the text model's hidden size using
    a small stack of Whisper encoder layers.

    Pipeline: stack frames -> RMSNorm -> Linear(d_model) -> + learned position
    embeddings -> `config.num_projector_layers` WhisperEncoderLayers ->
    RMSNorm -> Linear(text hidden size).
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        # ModuleUtilsMixin.get_extended_attention_mask reads
        # self.config.is_decoder; this stub supplies just that attribute.
        self.config = SimpleNamespace(is_decoder=False)

        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim_in = config.audio_config.hidden_size * config.stack_factor

        projector_audio_config = copy.deepcopy(config.audio_config)

        self.ln_pre = RMSNorm(dim_in)
        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)

        self.embed_positions = nn.Embedding(
            projector_audio_config.max_source_positions,
            projector_audio_config.d_model,
        )

        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(projector_audio_config)
                for _ in range(config.num_projector_layers)
            ]
        )

        self.ln_post = RMSNorm(projector_audio_config.d_model)
        self.linear_out = nn.Linear(
            projector_audio_config.d_model, config.text_config.hidden_size
        )

        # Backward compatibility for Transformers v4, where layer_head_mask
        # was a required argument of WhisperEncoderLayer.forward. Inspect
        # the signature once here instead of on every forward pass.
        self._layer_kwargs: dict[str, Any] = {}
        if self.layers and (
            "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters
        ):
            self._layer_kwargs["layer_head_mask"] = None

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project audio-tower outputs.

        `audio_token_len[i]` is the number of valid stacked positions in
        sample i; later positions are masked out of attention.
        """
        audio_features = self._pad_and_stack(audio_features)

        # True for the first audio_token_len[i] stacked positions of sample i.
        max_len_stacked = audio_features.shape[1]
        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
            None, :
        ].lt(audio_token_len[:, None])
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, attention_mask.shape, audio_features.dtype
        )

        hidden_states = self.ln_pre(audio_features)
        hidden_states = self.linear_in(hidden_states)

        positions = self.embed_positions(
            torch.arange(hidden_states.size(1), device=hidden_states.device)
        )
        hidden_states = hidden_states + positions

        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                **self._layer_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False

        # Backward compatibility for Transformers v4, where layer_head_mask
        # was a required argument of WhisperEncoderLayer.forward. Inspect
        # the signature once here instead of on every forward pass.
        self._encoder_layer_kwargs: dict[str, Any] = {}
        if self.layers and (
            "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters
        ):
            self._encoder_layer_kwargs["layer_head_mask"] = None

    @property
    def max_context_length(self):
        """Largest accepted mel-frame length: max positions times both conv
        strides."""
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def get_attention_mask_by_audio_len(
        self, audio_lens: torch.Tensor | None, hidden_states: torch.Tensor
    ):
        """
        Create an attention mask that hides padded frames.

        For each sample in the batch:
        - Convert the raw audio length to the feature length after the
          convolutions (`_get_feat_extract_output_lengths`).
        - Build a bool mask: True for valid positions, False for padding.
        - Expand it with `get_extended_attention_mask` into the additive form
          expected by the transformer layers: 0.0 for positions to attend to,
          a large negative value for positions to ignore.

        Masking padding keeps behavior consistent between training and
        inference, since the model never attends to padding in either case.

        Returns None when `audio_lens` is None, i.e. no masking.
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
            None, :
        ].lt(audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor | None = None,
    ):
        """Encode mel features whose last (time) dimension may be shorter than
        `max_context_length`; `audio_lens` optionally masks padded frames."""
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}."
            )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positions to the actual (possibly shorter than 30 s) length.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                **self._encoder_layer_kwargs,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


515
516
517
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: a Whisper-style audio encoder plus a projector whose outputs
    replace audio placeholder tokens in a vLLM-registered language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Checkpoint weights for the encoder live under
    # "audio_tower.model.encoder."; the module here is just "audio_tower.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if not modality.startswith("audio"):
            raise ValueError("Only audio modality is supported")
        return "<|audio|>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config: UltravoxConfig = model_config.hf_config
        multimodal_config = model_config.multimodal_config

        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        self.secondary_weights = []

        def add_secondary_source(model_id: str | None, weight_prefix: str) -> None:
            # The prefix (note the trailing dot) is only used when loading
            # weights, not when initializing modules.
            if model_id is None:
                return
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=model_id,
                    revision=None,
                    prefix=weight_prefix,
                )
            )

        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        add_secondary_source(config.audio_model_id, "audio_tower.")

        projector_cls = (
            UltravoxTransformerProjector
            if config.num_projector_layers > 0
            else UltravoxFeedForwardProjector
        )
        self.multi_modal_projector = projector_cls(config)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.wrapped_model_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        add_secondary_source(config.text_model_id, "language_model.")

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor,
        audio_token_len: torch.Tensor,
    ) -> torch.Tensor:
        """Run chunks through the audio tower and projector in slices of at
        most `_MAX_ENCODER_BATCH_SIZE`, keeping peak memory bounded."""
        tower_dtype = self.audio_tower.dtype
        features = input_features.to(tower_dtype)
        total = features.size(0)

        projected = []
        for begin in range(0, total, _MAX_ENCODER_BATCH_SIZE):
            stop = min(begin + _MAX_ENCODER_BATCH_SIZE, total)
            encoded = self.audio_tower(features[begin:stop], audio_lens[begin:stop])
            encoded = encoded.to(tower_dtype)
            projected.append(
                self.multi_modal_projector(encoded, audio_token_len[begin:stop])
            )

        return torch.cat(projected, dim=0)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> UltravoxAudioInputs | None:
        """Wrap the audio kwargs in the matching schema, or return None when
        the request carries no audio."""
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)
        audio_num_chunks = kwargs.pop("audio_num_chunks", None)

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(
                type="audio_features",
                data=audio_features,
                lens=audio_lens,
                token_len=audio_token_len,
                num_chunks=audio_num_chunks,
            )

        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)

        return None

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> NestedTensors | tuple[torch.Tensor, ...]:
        """Return one embedding tensor per input audio."""
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        features = pad_and_concat_to_dim3(audio_input["data"])
        token_len = audio_input["token_len"]

        embeddings = self._audio_features_to_embeddings(
            features, audio_input["lens"], token_len
        )

        # Keep the first token_len[i] embeddings of chunk i and flatten. For
        # example, token_len = [4, 2, 3] yields
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3]).
        position = torch.arange(embeddings.shape[1], device=embeddings.device)
        keep = position.unsqueeze(0) < token_len.unsqueeze(1)
        flat = embeddings[keep]

        # Regroup the per-chunk embeddings into one tensor per audio.
        per_audio_lens = [
            lens.sum().item()
            for lens in token_len.split(audio_input["num_chunks"].tolist())
        ]
        return flat.split(per_audio_lens)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        return self._process_audio_input(audio_input)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        # The text-only call is split out to satisfy the type checker for
        # each overload of the base implementation.
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Ultravox.

        `input_ids` already reserves positions for the audio embeddings that
        will be inserted, at roughly 6.25 tokens per second of audio. As a
        result, `positions` and the attention metadata line up with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # If the registered text model exposes its own `language_model`,
        # run that inner model instead.
        inner_lm = getattr(self.language_model, "language_model", self.language_model)

        return inner_lm.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
758
759
760


def pad_and_concat_to_dim3(
761
    features: torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
762
763
764
765
766
767
768
769
770
771
772
773
774
) -> torch.Tensor:
    """
    Pad and concatenate a list of tensors.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
775

776
777
778
779
780
781
782
        return features

    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
783
    # Pad and concatenate:
784
785
786
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)