ultravox.py 26.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6
from collections.abc import Iterable, Mapping, Sequence
7
from typing import Any, Literal, Optional, TypedDict, Union
8
9
10
11

import torch
from torch import nn
from torch.nn import functional as F
12
from transformers import BatchFeature, ProcessorMixin
13
14
15
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

16
from vllm import envs
17
from vllm.config import VllmConfig
18
from vllm.forward_context import get_forward_context
19
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
20
from vllm.model_executor.layers.layernorm import RMSNorm
21
from vllm.model_executor.model_loader import DefaultModelLoader
22
from vllm.model_executor.models.module_mapping import MultiModelKeys
23
from vllm.model_executor.sampling_metadata import SamplingMetadata
24
from vllm.multimodal import MULTIMODAL_REGISTRY
25
26
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, NestedTensors)
27
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
28
from vllm.multimodal.processing import (BaseMultiModalProcessor,
29
30
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
31
from vllm.multimodal.profiling import BaseDummyInputsBuilder
32
from vllm.sequence import IntermediateTensors
33
34
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

35
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
36
                         SupportsMultiModal, SupportsPP)
37
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
38
                    init_vllm_registered_model, maybe_prefix,
39
                    merge_multimodal_embeddings,
40
                    merge_multimodal_embeddings_from_map)
41

42
43
# Replacement for the HF processor's default audio placeholder string; see
# UltravoxProcessingInfo.get_hf_processor for why it is overridden.
_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
# Token id written into the prompt for every audio embedding position and
# matched when merging audio embeddings into the text embeddings.
# NOTE(review): assumed to be the id of _AUDIO_PLACEHOLDER_OVERRIDE in the
# backbone's tokenizer (Llama-3 vocabulary) — confirm for other backbones.
_AUDIO_TOKENS_PER_SECOND = 6.25
_AUDIO_PLACEHOLDER_TOKEN = 128002
# Maximum number of audio chunks sent through the audio tower in one call;
# also scales the length of the dummy audio used for profiling.
_MAX_ENCODER_BATCH_SIZE = 16
46
47
48
49


class UltravoxAudioFeatureInputs(TypedDict):
    """Raw audio inputs (mel features) that still need the audio tower."""
    type: Literal["audio_features"]
    data: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
    """Shape: `(batch_size, num_chunks, 80, M)`"""
    lens: Union[torch.Tensor, list[torch.Tensor]]
    """
    Length of the audio frames. Used for attention mask in WhisperEncoder.
    Shape: `(batch_size, num_chunks)`
    """
    token_len: Union[torch.Tensor, list[torch.Tensor]]
    """
    Length of the audio tokens. Used for flattening the audio features.
    Shape: `(batch_size, num_chunks)`
    """
62
63
64
65


class UltravoxAudioEmbeddingInputs(TypedDict):
    """Precomputed audio embeddings, passed through without the audio tower."""
    type: Literal["audio_embeds"]
    data: NestedTensors
    """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""


# Either raw audio features or precomputed embeddings, discriminated by the
# "type" key.
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                            UltravoxAudioEmbeddingInputs]


74
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Gives access to the Ultravox HF processor, feature extractor and limits."""

    def get_hf_processor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
        **kwargs: object,
    ) -> ProcessorMixin:
        """Return the HF processor with its audio placeholder overridden.

        Args:
            sampling_rate: Accepted for interface compatibility; it is not
                forwarded to the underlying processor.
            **kwargs: Forwarded to ``self.ctx.get_hf_processor``.

        Note that the processor object returned by the context is modified
        in place.
        """
        hf_processor = self.ctx.get_hf_processor(**kwargs)

        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        # placeholder that will cause confusion with the actual end of turn
        # token, thus we override placeholder with a reserved special
        # token.
        hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        hf_processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
        return hf_processor

    def get_feature_extractor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor wrapped by the HF processor."""
        hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
        audio_processor = hf_processor.audio_processor  # type: ignore
        feature_extractor = audio_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Only audio is supported; ``None`` presumably means no fixed limit."""
        return {"audio": None}
107

108
109
110
111

class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
                                 ):
    """Builds dummy prompt text and audio for Ultravox profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<|audio|>`` placeholder per requested audio item."""
        num_audios = mm_counts.get("audio", 0)

        return "<|audio|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["audio"]`` dummy audio clips.

        Each clip spans ``_MAX_ENCODER_BATCH_SIZE`` feature-extractor chunks,
        i.e. ``chunk_length * sampling_rate`` per chunk (presumably samples).
        ``seq_len`` is not used.
        """
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = (feature_extractor.chunk_length * sampling_rate *
                     _MAX_ENCODER_BATCH_SIZE)
        num_audios = mm_counts.get("audio", 0)

        return {
            "audio":
            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }


135
136
class UltravoxMultiModalProcessor(
        BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Runs the Ultravox HF processor and expands ``<|audio|>`` placeholders.

    Each ``<|audio|>`` in the prompt is replaced by one replacement token id
    per audio token produced by that audio's chunks.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        # Audio is resampled to the rate the feature extractor expects.
        feature_extractor = self.info.get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and rename ``audio_values`` to ``audio_features``.

        Prompts without audio are tokenized directly, because the composite
        HF processor does not support text-only input.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor()
        # A dict literal lets the feature extractor's rate override a
        # caller-supplied ``sampling_rate``. ``dict(**mm_kwargs, ...)`` would
        # raise "got multiple values for keyword argument" instead. The audio
        # was already resampled to this rate by the data parser.
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
            "include_audio_num_chunks": True,
        }

        item_processor_data = dict(**mm_data, audios=audios)

        # Some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Filter into a new dict: ``tok_kwargs`` belongs to the caller and
        # may be read-only, so it must not be mutated with ``pop``.
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in ("padding", "truncation")
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output['audio_features'] = output.pop('audio_values')

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output maps onto audio items."""
        num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0))
        return dict(
            # To handle audio longer than 30s, each audio may be split into
            # multiple chunks, so the batch dimension of these fields can be
            # larger than the number of audio samples.
            audio_features=MultiModalFieldConfig.flat_from_sizes(
                "audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes(
                "audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes(
                "audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``<|audio|>`` with that audio's token-length worth of ids."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(num_chunks,
                                                      dim=0,
                                                      dtype=torch.int32)
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx])

        def get_replacement_ultravox(item_idx: int):
            # Sum the token lengths of every chunk belonging to this audio.
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            audio_token_len = out_mm_kwargs["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265


class StackAudioFrames(nn.Module):
    """
    Shorten an audio embedding sequence by concatenating every
    ``stack_factor`` consecutive frames along the channel dimension.

    An input of shape ``(B, T, C)`` becomes
    ``(B, ceil(T / stack_factor), C * stack_factor)``. The time axis is
    zero-padded at the end first, so that its length divides evenly.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        factor = self.stack_factor
        batch, frames, channels = audio_embeds.shape
        # Zero frames needed to reach the next multiple of ``factor``.
        missing = -frames % factor
        padded = F.pad(audio_embeds, (0, 0, 0, missing))
        stacked_len = padded.shape[1] // factor
        return padded.view(batch, stacked_len, channels * factor)


class UltravoxProjector(nn.Module):
    """Project stacked audio-encoder features into the text embedding space.

    Pipeline: stack ``stack_factor`` frames -> RMSNorm -> linear_1 ->
    activation -> ln_mid -> linear_2 -> ln_post. Exactly one of
    ``ln_mid``/``ln_post`` is an RMSNorm; the other is an identity.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        # Stacking concatenates ``stack_factor`` frames along the channel dim.
        dim_in = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim_in)
        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
        dim_mid = self.hidden_dim

        if config.projector_act == "swiglu":
            self.act = MulAndSilu()
            # The gated activation's output is half as wide as its input.
            dim_mid = dim_mid // 2
        else:
            self.act = get_act_fn(config.projector_act)

        dim_out = config.text_config.hidden_size
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(dim_mid)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(dim_out)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Map encoder output ``(B, T, C)`` to ``(B, ceil(T/stack), text_dim)``."""
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.ln_mid(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        """Maximum number of input mel frames accepted by ``forward``.

        Equal to the number of positional embeddings times the combined
        stride of the two convolutions.
        """
        return (self.config.max_source_positions * self.conv1.stride[0] *
                self.conv2.stride[0])

    def get_attention_mask_by_audio_len(self,
                                        audio_lens: Optional[torch.Tensor],
                                        hidden_states: torch.Tensor):
        """
        Create an attention mask from audio lengths so padding is masked out.

        For each sample in the batch:
        - Convert the raw audio length to the feature length after the
          convolutions.
        - Create a bool mask: True for valid positions, False for padding.
        - Convert it with ``get_extended_attention_mask`` into the additive
          mask format used by the transformer layers: 0 for positions to
          attend to, a large negative value for positions to ignore.

        This masking keeps behavior consistent between training and
        inference, because the model never attends to padding tokens in
        either case.

        Returns ``None`` when ``audio_lens`` is ``None``, meaning no masking.
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        attention_mask = torch.arange(max_seq_len,
                                      device=hidden_states.device)[None, :].lt(
                                          audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: Optional[torch.Tensor] = None,
    ):
        """Encode mel features; inputs shorter than the max length are allowed.

        Raises:
            ValueError: If the last dim of ``input_features`` exceeds
                ``max_context_length``.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}.")

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # (batch, channels, time) -> (batch, time, channels)
        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positional embeddings to the actual (possibly shorter) length.
        embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        attention_mask = self.get_attention_mask_by_audio_len(
            audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


396
397
398
399
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: Whisper audio encoder + projector + vLLM language model.

    Audio chunks are encoded by ``audio_tower``, mapped to the text hidden
    size by ``multi_modal_projector`` and merged into the language model's
    input embeddings at the audio placeholder positions.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Rename checkpoint tensors ``audio_tower.model.encoder.*`` to
    # ``audio_tower.*``.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"audio"``.
        """
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra checkpoints to load weights from when the audio or text
        # sub-model is referenced by id in the config.
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                ))
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(model_or_path=config.text_model_id,
                                          revision=None,
                                          prefix="language_model."))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor,
            audio_lens: torch.Tensor) -> torch.Tensor:
        """Encode mel-feature chunks and project them to the text hidden size.

        Chunks are processed in slices of at most ``_MAX_ENCODER_BATCH_SIZE``
        and the results are concatenated along dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(audio_features[start:end],
                                              audio_lens[start:end])
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(batch_features)
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        """Pop the audio kwargs and wrap them in the matching TypedDict.

        Returns ``None`` when neither ``audio_features`` nor ``audio_embeds``
        is present. Features take precedence over embeddings.

        Raises:
            ValueError: If a present audio kwarg is not a tensor or list.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")
            # Report the type of the value that actually failed the check.
            if not isinstance(audio_lens, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio_lens. "
                                 f"Got type: {type(audio_lens)}")
            if not isinstance(audio_token_len, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio_token_len. "
                                 f"Got type: {type(audio_token_len)}")

            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features,
                                              lens=audio_lens,
                                              token_len=audio_token_len)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> Union[NestedTensors, tuple[torch.Tensor, ...]]:
        """Turn parsed audio input into one embedding tensor per audio.

        Precomputed embeddings are returned unchanged. Features are encoded,
        trimmed to each chunk's token length, and split per audio.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        # [B1, B2] -> [B1+B2]
        audio_lens = flatten_bn(audio_input['lens'], concat=True)
        audio_token_len = flatten_bn(audio_input['token_len'], concat=True)

        embeddings = self._audio_features_to_embeddings(
            audio_features, audio_lens)

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1)
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio
        embed_lens = [
            token_len_item.sum().item()
            for token_len_item in audio_input['token_len']
        ]
        return flattened_embeddings.split(embed_lens)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or ``[]`` when no audio was passed."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge audio embeddings at placeholders."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):

            # TODO(ywang96): remove this block after v0 is deprecated.
            if not envs.VLLM_USE_V1:
                attn_metadata = get_forward_context().attn_metadata
                merge_multimodal_embeddings_from_map(
                    inputs_embeds, multimodal_embeddings,
                    attn_metadata.multi_modal_placeholder_index_maps["audio"])
            else:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, multimodal_embeddings,
                    _AUDIO_PLACEHOLDER_TOKEN)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            audio_features: A batch of audio input chunks [B, N, 80, M].
            audio_lens: Length of audio frames for each audio chunk [B].
            audio_token_len: Length of audio tokens for each audio chunk [B'].
                Note: batch dim is different from batch dim in audio chunks.

        """

        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights, renaming via ``hf_to_vllm_mapper``.

        Checkpoint tensors under ``audio_tower.`` that this model does not
        define are presumably skipped rather than treated as errors.
        """
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
643
644
645


def pad_and_concat_to_dim3(
    features: Union[torch.Tensor, list[torch.Tensor], list[list[torch.Tensor]]]
) -> torch.Tensor:
    """
    Pad and concatenate a (possibly nested) list of tensors.

    A tensor input is returned as-is, except that a tensor with more than
    3 dims has its leading dims flattened via ``flatten_bn``
    (e.g. ``[B, N, 80, M] -> [B * N, 80, M]``). A list is processed
    recursively. Each element is reshaped to 3 dims, right-padded with zeros
    on the last dim to the longest element, and concatenated along dim 0.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.

    Raises:
        ValueError: If ``features`` (or a nested list) is empty.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
        return features

    if len(features) == 0:
        raise ValueError("pad_and_concat_to_dim3 requires at least one tensor")

    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
    # Pad and concatenate:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
    return torch.cat(features)