ultravox.py 28.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6

7
import copy
8
import inspect
9
from collections.abc import Iterable, Mapping, Sequence
10
from types import SimpleNamespace
11
from typing import Annotated, Any, Literal, TypeAlias
12
13
14
15

import torch
from torch import nn
from torch.nn import functional as F
16
from transformers import BatchFeature, ProcessorMixin
17
from transformers.modeling_utils import ModuleUtilsMixin
18
from transformers.models.whisper import WhisperFeatureExtractor
19
20
21
22
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
)
23

24
from vllm.config import VllmConfig
25
from vllm.config.multimodal import BaseDummyOptions
26
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
27
from vllm.model_executor.layers.layernorm import RMSNorm
28
from vllm.model_executor.model_loader import DefaultModelLoader
29
from vllm.model_executor.models.module_mapping import MultiModelKeys
30
from vllm.multimodal import MULTIMODAL_REGISTRY
31
32
33
34
35
36
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
37
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
38
from vllm.multimodal.processing import (
39
    BaseDummyInputsBuilder,
40
41
42
43
44
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
45
from vllm.sequence import IntermediateTensors
46
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
47
from vllm.utils.tensor_schema import TensorSchema, TensorShape
48

49
50
51
52
53
54
55
56
57
58
59
60
61
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
62

63
_AUDIO_PLACEHOLDER_OVERRIDE = "<|audio|>"
64
_MAX_ENCODER_BATCH_SIZE = 16
65
66


67
class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Mel-spectrogram audio inputs, split into chunks.

    A single audio may be split into several chunks, so the leading dimension
    of `data`, `lens` and `token_len` counts chunks across all audios rather
    than audios.

    Dimensions:
    - bn: total number of chunks across all audios
    - nmb: number of mel bins
    - t: time frames (M)
    - n: number of audios (one `num_chunks` entry per audio)
    """

    type: Literal["audio_features"]
    # Mel features, either as one tensor or as (nested) lists of tensors that
    # are padded and concatenated before being fed to the audio tower.
    data: Annotated[
        torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
        TensorShape("bn", "nmb", "t"),
    ]
    # Length of the audio frames per chunk. Used for the attention mask in
    # WhisperEncoder.
    lens: Annotated[torch.Tensor, TensorShape("bn")]
    # Length of the audio tokens per chunk. Used for flattening the audio
    # features.
    token_len: Annotated[torch.Tensor, TensorShape("bn")]
    # Number of chunks per audio. Used for regrouping the flattened audio
    # features into one tensor per audio.
    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
89
90
91


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings, returned as-is without running the audio
    tower or the projector.

    Dimensions:
    - b: batch size
    - na: number of audios
    - afs: audio feature size
    - hs: hidden size
    """

    type: Literal["audio_embeds"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("b", "na", "afs", "hs"),
    ]
104
105


106
107
108
# Either mel features that still need encoding or ready-made embeddings;
# the two variants are distinguished by their `type` field.
UltravoxAudioInputs: TypeAlias = (
    UltravoxAudioFeatureInputs | UltravoxAudioEmbeddingInputs
)
109
110


111
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Access to the Ultravox HF processor, feature extractor and limits."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF processor with the audio placeholder overridden.

        The Ultravox processing definition uses '<|eot_id|>' as the audio
        placeholder, which would be confused with the real end-of-turn
        token, so a reserved token is installed instead, together with the
        model's `audio_token_index` as the replacement id.
        """
        hf_config = self.ctx.model_config.hf_config
        processor = self.ctx.get_hf_processor(**kwargs)

        processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        processor.audio_replacement_token_id = hf_config.audio_token_index
        return processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor used by the HF processor.

        Depending on the checkpoint revision (the layout changed in
        https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/commit/9a3c571b8fdaf1e66dd3ea61bbcb6db5c70a438e),
        `audio_processor` is either the extractor itself or an object that
        exposes it as `.feature_extractor`.
        """
        hf_processor = self.get_hf_processor(**kwargs)
        extractor = hf_processor.audio_processor  # type: ignore
        if not isinstance(extractor, WhisperFeatureExtractor):
            extractor = extractor.feature_extractor  # type: ignore
            assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_target_channels(self) -> int:
        """Return target audio channels for Ultravox models (mono)."""
        return 1

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Allow any number of audio items per prompt (no explicit limit)."""
        return {"audio": None}
142

143

144
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy prompt text and dummy audios for Ultravox."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one `<|audio|>` placeholder per requested audio."""
        return "<|audio|>" * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audios, each `_MAX_ENCODER_BATCH_SIZE` chunks long.

        One chunk is `chunk_length * sampling_rate` samples as reported by
        the feature extractor. Per-modality overrides are taken from
        `mm_options["audio"]` when provided.
        """
        extractor = self.info.get_feature_extractor()
        samples_per_chunk = extractor.chunk_length * extractor.sampling_rate
        overrides = None if not mm_options else mm_options.get("audio")

        dummy_audios = self._get_dummy_audios(
            length=samples_per_chunk * _MAX_ENCODER_BATCH_SIZE,
            num_audios=mm_counts.get("audio", 0),
            overrides=overrides,
        )
        return {"audio": dummy_audios}


173
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Multimodal processor that expands `<|audio|>` into audio token runs."""

    def _get_data_parser(self) -> MultiModalDataParser:
        """Parse audio at the feature extractor's sampling rate, in mono."""
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.info.get_target_channels(),
        )

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor on `prompt` and its audios.

        Text-only prompts are tokenized directly. Otherwise the HF processor
        output is returned with `audio_values` renamed to `audio_features`.
        The caller's `mm_data`, `mm_kwargs` and `tok_kwargs` are not mutated.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # A dict merge (instead of dict(**mm_kwargs, key=...)) lets these
        # values override caller-supplied ones rather than raising
        # "got multiple values for keyword argument".
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
            "include_audio_num_chunks": True,
        }

        item_processor_data = dict(**mm_data, audios=audios)

        # Some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Build a filtered copy: `tok_kwargs` is a read-only Mapping owned by
        # the caller and must not be mutated.
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in ("padding", "truncation")
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output["audio_features"] = output.pop("audio_values")

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto audio items."""
        num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0))
        return dict(
            # to handle longer than 30s audio, each audio might be split
            # into multiple chunks as such, their batch dimension can be
            # higher than the number of audio samples
            audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each `<|audio|>` with one replacement token per audio token.

        The replacement length for an audio is the sum of `audio_token_len`
        over all chunks that belong to it.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(
            num_chunks, dim=0, dtype=torch.int32
        )
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx]
        )

        def get_replacement_ultravox(item_idx: int):
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]
276
277
278
279
280
281
282
283
284
285
286
287
288
289


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    An input of shape [B, T, C] is right-padded with zeros along T up to a
    multiple of `stack_factor`. Every `stack_factor` consecutive frames are
    then concatenated along the channel axis, giving
    [B, ceil(T / stack_factor), C * stack_factor].
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        B, T, C = audio_embeds.shape
        T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        # reshape rather than view: view() raises on inputs whose strides are
        # incompatible with the new shape (e.g. a transposed tensor).
        # reshape() returns a view whenever one is possible.
        return audio_embeds.reshape(
            B, T_pad // self.stack_factor, C * self.stack_factor
        )


299
class UltravoxFeedForwardProjector(nn.Module):
    """MLP projector from stacked audio-encoder frames to the text hidden size.

    Pipeline: stack frames -> ln_pre -> linear_1 -> act -> ln_mid -> linear_2
    -> ln_post. Exactly one of ln_mid / ln_post is an RMSNorm (selected by
    `config.projector_ln_mid`); the other is an identity.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        stack_factor = config.stack_factor
        self._pad_and_stack = StackAudioFrames(stack_factor)

        in_features = config.audio_config.hidden_size * stack_factor
        out_features = config.text_config.hidden_size

        self.ln_pre = RMSNorm(in_features)
        self.linear_1 = nn.Linear(in_features, self.hidden_dim, bias=False)

        # With SwiGLU the activation consumes both halves of linear_1's
        # output, so the tensor entering linear_2 is half as wide.
        is_swiglu = config.projector_act == "swiglu"
        self.act = MulAndSilu() if is_swiglu else get_act_fn(config.projector_act)
        mid_features = self.hidden_dim // 2 if is_swiglu else self.hidden_dim

        self.linear_2 = nn.Linear(mid_features, out_features, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if not config.projector_ln_mid:
            self.ln_mid: nn.Module = nn.Identity()
            self.ln_post = RMSNorm(out_features)
        else:
            self.ln_mid = RMSNorm(mid_features)
            self.ln_post = nn.Identity()

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        # `audio_token_len` is not used by this projector; it is accepted so
        # the caller can invoke either projector type with the same arguments.
        x = self.ln_pre(self._pad_and_stack(audio_features))
        x = self.ln_mid(self.act(self.linear_1(x)))
        return self.ln_post(self.linear_2(x))


340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
    """Projector built from Whisper encoder layers.

    Stacks encoder frames, projects them to the audio config's `d_model`,
    adds learned position embeddings, and runs `config.num_projector_layers`
    WhisperEncoderLayers under a padding mask derived from `audio_token_len`.
    Finally it normalizes and projects to the text model's hidden size.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        # Minimal config for ModuleUtilsMixin.get_extended_attention_mask
        # (presumably only `is_decoder` is read there -- TODO confirm).
        self.config = SimpleNamespace(is_decoder=False)

        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim_in = config.audio_config.hidden_size * config.stack_factor

        projector_audio_config = copy.deepcopy(config.audio_config)

        self.ln_pre = RMSNorm(dim_in)
        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)

        self.embed_positions = nn.Embedding(
            projector_audio_config.max_source_positions,
            projector_audio_config.d_model,
        )

        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(projector_audio_config)
                for _ in range(config.num_projector_layers)
            ]
        )

        self.ln_post = RMSNorm(projector_audio_config.d_model)
        self.linear_out = nn.Linear(
            projector_audio_config.d_model, config.text_config.hidden_size
        )

        # Backward compatibility for Transformers v4 where layer_head_mask
        # was a required argument for WhisperEncoderLayer.forward. The layer
        # type is fixed at construction, so inspect the signature once here
        # instead of on every forward call (and tolerate an empty layer list).
        self._pass_layer_head_mask = bool(self.layers) and (
            "layer_head_mask"
            in inspect.signature(self.layers[0].forward).parameters
        )

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        audio_features = self._pad_and_stack(audio_features)

        # Row i attends only to its first audio_token_len[i] stacked frames.
        max_len_stacked = audio_features.shape[1]
        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
            None, :
        ].lt(audio_token_len[:, None])
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, attention_mask.shape, audio_features.dtype
        )

        hidden_states = self.ln_pre(audio_features)
        hidden_states = self.linear_in(hidden_states)

        positions = self.embed_positions(
            torch.arange(hidden_states.size(1), device=hidden_states.device)
        )
        hidden_states = hidden_states + positions

        layer_kwargs = {"layer_head_mask": None} if self._pass_layer_head_mask else {}

        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                **layer_kwargs,
            )
            # Layers are expected to return a tuple (hidden_states, ...);
            # also accept a bare tensor in case a Transformers version
            # returns one, which [0] would otherwise silently mis-index.
            hidden_states = (
                layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
            )

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False
        # Backward compatibility for Transformers v4 where layer_head_mask
        # was a required argument for WhisperEncoderLayer.forward. The layer
        # type is fixed once the encoder is built, so inspect the signature
        # once here rather than on every forward call.
        self._pass_layer_head_mask = bool(self.layers) and (
            "layer_head_mask"
            in inspect.signature(self.layers[0].forward).parameters
        )

    @property
    def max_context_length(self):
        """Maximum number of input mel frames the encoder accepts."""
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def get_attention_mask_by_audio_len(
        self, audio_lens: torch.Tensor | None, hidden_states: torch.Tensor
    ):
        """
        Create attention mask based on audio lengths to mask out padding tokens
        For each sample in batch:
        - Convert raw audio length to feature length after convolutions
        - Create bool mask: True for valid positions and False for padding
        - Convert to the additive attention mask expected by transformer layers
          via `get_extended_attention_mask` (0.0 where attention is allowed,
          a large negative value where it is masked)
        This masking ensures consistent behavior between training and inference
        by preventing the model from attending to padding tokens in both cases.
        Returns None when `audio_lens` is None.
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
            None, :
        ].lt(audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor | None = None,
    ):
        """Encode mel features that may be shorter than the full context.

        Args:
            input_features: Mel features whose last dimension is time; it may
                be at most `max_context_length` frames.
            audio_lens: Optional per-sample lengths, converted to post-conv
                lengths to mask padding; no mask is built when None.

        Returns:
            Encoder hidden states after the final layer norm.

        Raises:
            ValueError: if the time dimension exceeds `max_context_length`.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}."
            )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positions to the actual length so shorter inputs are allowed.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

        layer_kwargs = {"layer_head_mask": None} if self._pass_layer_head_mask else {}

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                **layer_kwargs,
            )
            # Layers are expected to return a tuple (hidden_states, ...);
            # also accept a bare tensor in case a Transformers version
            # returns one, which [0] would otherwise silently mis-index.
            hidden_states = (
                layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


515
516
517
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: Whisper-based audio tower + projector feeding a text LLM."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for audio; reject other modalities."""
        if not modality.startswith("audio"):
            raise ValueError("Only audio modality is supported")
        return "<|audio|>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config: UltravoxConfig = model_config.hf_config

        self.config = config
        self.multi_modal_config = model_config.multimodal_config
        assert self.multi_modal_config

        # Extra checkpoints to load weights from. These prefixes (note the
        # trailing dots) are used only when loading weights, not for
        # initialization.
        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path=model_id,
                revision=None,
                prefix=weight_prefix,
            )
            for model_id, weight_prefix in (
                (config.audio_model_id, "audio_tower."),
                (config.text_model_id, "language_model."),
            )
            if model_id is not None
        ]

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
            projector_cls = (
                UltravoxTransformerProjector
                if config.num_projector_layers > 0
                else UltravoxFeedForwardProjector
            )
            self.multi_modal_projector = projector_cls(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.wrapped_model_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor,
        audio_token_len: torch.Tensor,
    ) -> torch.Tensor:
        """Encode and project mel features, at most _MAX_ENCODER_BATCH_SIZE
        chunks per call, and concatenate the results along dim 0."""
        tower_dtype = self.audio_tower.dtype
        features = input_features.to(tower_dtype)
        num_rows = features.size(0)

        projected = []
        for lo in range(0, num_rows, _MAX_ENCODER_BATCH_SIZE):
            # Slicing past the end is clamped, so no explicit min() needed.
            hi = lo + _MAX_ENCODER_BATCH_SIZE
            encoded = self.audio_tower(features[lo:hi], audio_lens[lo:hi])
            encoded = encoded.to(tower_dtype)
            projected.append(
                self.multi_modal_projector(encoded, audio_token_len[lo:hi])
            )

        return torch.cat(projected, dim=0)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> UltravoxAudioInputs | None:
        """Build the typed audio input from kwargs; None if no audio given.

        Mel features take precedence over precomputed embeddings.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)
        audio_num_chunks = kwargs.pop("audio_num_chunks", None)

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(
                type="audio_features",
                data=audio_features,
                lens=audio_lens,
                token_len=audio_token_len,
                num_chunks=audio_num_chunks,
            )
        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)
        return None

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> NestedTensors | tuple[torch.Tensor, ...]:
        """Turn the parsed audio input into one embedding tensor per audio.

        Precomputed embeddings are returned unchanged.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        features = pad_and_concat_to_dim3(audio_input["data"])
        token_len = audio_input["token_len"]
        embeddings = self._audio_features_to_embeddings(
            features, audio_input["lens"], token_len
        )

        # Keep the first token_len[i] embeddings of every chunk and flatten,
        # e.g. token_len = [4, 2, 3] gives
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3]).
        frame_idx = torch.arange(embeddings.shape[1], device=embeddings.device)
        keep = frame_idx[None, :] < token_len[:, None]
        flat = embeddings[keep]

        # Regroup the chunks into one tensor per input audio.
        per_audio = token_len.split(audio_input["num_chunks"].tolist())
        return flat.split([group.sum().item() for group in per_audio])

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return audio embeddings for the given kwargs ([] when no audio)."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        return self._process_audio_input(audio_input)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        """Embed token ids, merging in multimodal embeddings when given."""
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        """
        # Non-first pipeline stages consume intermediate tensors instead.
        if intermediate_tensors is not None:
            inputs_embeds = None

        # The wrapped text model may nest its LM under `.language_model`.
        language_model = getattr(
            self.language_model, "language_model", self.language_model
        )

        return language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, tolerating unexpected `audio_tower.` entries."""
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
759
760
761


def pad_and_concat_to_dim3(
762
    features: torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
763
764
765
766
767
768
769
770
771
772
773
774
775
) -> torch.Tensor:
    """
    Pad and concatenate a list of tensors.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
776

777
778
779
780
781
782
783
        return features

    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
784
    # Pad and concatenate:
785
786
787
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)