ultravox.py 26.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
6
from collections.abc import Iterable, Mapping, Sequence
7
from functools import cached_property
8
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
9
10
11
12
13

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
14
from transformers import BatchFeature, ProcessorMixin
15
16
17
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

18
from vllm import envs
19
from vllm.config import VllmConfig
20
from vllm.forward_context import get_forward_context
21
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
22
from vllm.model_executor.layers.layernorm import RMSNorm
Joe Runde's avatar
Joe Runde committed
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
from vllm.model_executor.model_loader.loader import DefaultModelLoader
25
from vllm.model_executor.models.module_mapping import MultiModelKeys
26
from vllm.model_executor.sampling_metadata import SamplingMetadata
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
29
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
30
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
31
from vllm.multimodal.processing import (BaseMultiModalProcessor,
32
33
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
34
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
35
from vllm.sequence import IntermediateTensors
36
37
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

38
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
39
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
40
                    init_vllm_registered_model, maybe_prefix,
41
                    merge_multimodal_embeddings,
42
                    merge_multimodal_embeddings_from_map)
43

44
45
# Placeholder string installed on the HF processor in place of its default
# audio placeholder (see UltravoxProcessingInfo.get_hf_processor).
_AUDIO_PLACEHOLDER_OVERRIDE = "<|reserved_special_token_0|>"
# Token id that audio embeddings are merged into. Presumably the id of
# _AUDIO_PLACEHOLDER_OVERRIDE in the text tokenizer -- verify per checkpoint.
_AUDIO_PLACEHOLDER_TOKEN = 128002
# Number of audio tokens produced per second of audio.
_AUDIO_TOKENS_PER_SECOND = 6.25
# Maximum number of audio chunks pushed through the audio tower at once.
_MAX_ENCODER_BATCH_SIZE = 16
48
49
50
51


class UltravoxAudioFeatureInputs(TypedDict):
    """Audio given as (possibly chunked) features for the audio tower."""

    # Discriminator for the UltravoxAudioInputs union.
    type: Literal["audio_features"]

    # Shape: `(batch_size, num_chunks, 80, M)`
    data: NestedTensors

    # Length of the audio frames. Used for attention mask in WhisperEncoder.
    # Shape: `(batch_size, num_chunks)`
    lens: NestedTensors

    # Length of the audio tokens. Used for flattening the audio features.
    # Shape: `(batch_size, num_chunks)`
    token_len: NestedTensors
64
65
66
67


class UltravoxAudioEmbeddingInputs(TypedDict):
    """Audio given as precomputed embeddings (bypasses the audio tower)."""

    # Discriminator for the UltravoxAudioInputs union.
    type: Literal["audio_embeds"]

    # Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`
    data: NestedTensors


# Either accepted form of audio input.
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                            UltravoxAudioEmbeddingInputs]


76
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Exposes the Ultravox HF processor, its Whisper feature extractor and
    the per-item audio token budget."""

    def get_hf_processor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
        **kwargs: object,
    ) -> ProcessorMixin:
        """Return the HF processor with its audio placeholder overridden."""
        processor = self.ctx.get_hf_processor(**kwargs)

        # NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        # placeholder that will cause confusion with the actual end of turn
        # token, thus we override placeholder with a reserved special
        # token.
        processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        processor.audio_replacement_token_id = _AUDIO_PLACEHOLDER_TOKEN
        return processor

    def get_feature_extractor(
        self,
        *,
        # Ignored in initialization
        sampling_rate: Optional[int] = None,
    ) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor owned by the HF processor."""
        processor = self.get_hf_processor(sampling_rate=sampling_rate)
        extractor = processor.audio_processor.feature_extractor  # type: ignore
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed cap on the number of audio items per prompt.
        return dict(audio=None)

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Upper bound on audio tokens per item: one full-length chunk's
        tokens times the maximum encoder batch of chunks."""
        chunk_seconds = self.get_feature_extractor().chunk_length
        tokens_per_chunk = math.ceil(chunk_seconds * _AUDIO_TOKENS_PER_SECOND)
        return {"audio": tokens_per_chunk * _MAX_ENCODER_BATCH_SIZE}
120

121
122
123
124

class UltravoxDummyInputsBuilder(
        BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds dummy audio prompts for multimodal profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        extractor = self.info.get_feature_extractor()
        rate = extractor.sampling_rate
        # Enough samples to fill _MAX_ENCODER_BATCH_SIZE full-length chunks.
        dummy_len = extractor.chunk_length * rate * _MAX_ENCODER_BATCH_SIZE
        audio_count = mm_counts.get("audio", 0)

        dummy_audios = self._get_dummy_audios(length=dummy_len,
                                              num_audios=audio_count)
        return ProcessorInputs(
            prompt_text="<|audio|>" * audio_count,
            mm_data={"audio": dummy_audios},
        )
146

147

148
149
class UltravoxMultiModalProcessor(
        BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Runs the Ultravox HF processor and expands each `<|audio|>`
    placeholder into the matching number of audio tokens."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # Audio is parsed at the Whisper feature extractor's sampling rate.
        target_rate = self.info.get_feature_extractor().sampling_rate
        return MultiModalDataParser(target_sr=target_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        # Text-only input not supported in composite processor
        has_audio = bool(mm_data) and bool(mm_data.get("audios", []))
        if not has_audio:
            tokenizer = self.info.get_tokenizer()
            token_ids = tokenizer.encode(prompt, add_special_tokens=False)
            token_ids = self._apply_hf_processor_tokens_only(token_ids)
            return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

        remaining = dict(mm_data)
        audios = remaining.pop("audios", [])
        assert isinstance(audios, list)

        rate = self.info.get_feature_extractor().sampling_rate
        processor_kwargs = dict(
            **mm_kwargs,
            sampling_rate=rate,
            include_audio_num_chunks=True,
        )
        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=dict(**remaining, audios=audios),
            mm_kwargs=processor_kwargs,
        )

        # The HF processor names the features "audio_values".
        output['audio_features'] = output.pop('audio_values')
        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        num_chunks = hf_inputs.get('audio_num_chunks', torch.zeros(0))
        # To handle longer than 30s audio, each audio might be split into
        # multiple chunks, so these per-chunk fields can have a batch
        # dimension larger than the number of audio samples; `num_chunks`
        # maps the chunks back to their audio item.
        per_chunk_fields = {
            field: MultiModalFieldConfig.flat_from_sizes("audio", num_chunks)
            for field in ("audio_features", "audio_token_len", "audio_lens")
        }
        return dict(
            **per_chunk_fields,
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks. chunk_offsets[i] is
        # the index of the first chunk of the i-th audio and
        # chunk_offsets[i + 1] is one past its last chunk.
        num_chunks = out_mm_kwargs.get("audio_num_chunks", torch.zeros(0))
        running_total = torch.cumsum(num_chunks, dim=0, dtype=torch.int32)
        chunk_offsets = torch.cat(
            [torch.tensor([0], dtype=torch.int32), running_total])

        def replacement_for_item(item_idx: int):
            first = chunk_offsets[item_idx]
            stop = chunk_offsets[item_idx + 1]
            item_token_lens = out_mm_kwargs["audio_token_len"][first:stop]
            return [replacement_id] * int(item_token_lens.sum())  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=replacement_for_item,
            )
        ]
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272


class StackAudioFrames(nn.Module):
    """
    Shorten an audio embedding sequence by `stack_factor`, concatenating each
    group of `stack_factor` consecutive frames along the channel dimension.

    `(B, T, C)` becomes `(B, ceil(T / stack_factor), C * stack_factor)`; the
    time axis is zero-padded at the end up to a multiple of `stack_factor`.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        factor = self.stack_factor
        batch, frames, channels = audio_embeds.shape
        # Number of zero frames needed to reach the next multiple of factor.
        pad_frames = -frames % factor
        padded = F.pad(audio_embeds, (0, 0, 0, pad_frames))
        return padded.view(batch, (frames + pad_frames) // factor,
                           channels * factor)


class UltravoxProjector(nn.Module):
    """Maps stacked audio-encoder frames to the text model's hidden size."""

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)

        # Stacking multiplies the encoder width by stack_factor.
        in_features = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(in_features)
        self.linear_1 = nn.Linear(in_features, self.hidden_dim, bias=False)

        # The swiglu activation halves the feature width, so linear_2 then
        # takes hidden_dim // 2 inputs.
        use_swiglu = config.projector_act == "swiglu"
        self.act = (MulAndSilu()
                    if use_swiglu else get_act_fn(config.projector_act))
        mid_features = (self.hidden_dim // 2
                        if use_swiglu else self.hidden_dim)

        out_features = config.text_config.hidden_size
        self.linear_2 = nn.Linear(mid_features, out_features, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(mid_features)
            self.ln_post = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(out_features)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Stack, normalize and project audio frames; the result's last
        dimension is `config.text_config.hidden_size`."""
        stacked = self.ln_pre(self._pad_and_stack(audio_features))
        mid = self.ln_mid(self.act(self.linear_1(stacked)))
        return self.ln_post(self.linear_2(mid))


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 seconds of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`
    3. optional `audio_lens` masks padding frames out of self-attention

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The encoder attends bidirectionally. HF's get_extended_attention_mask
        # builds a causal mask when `config.is_decoder` is set, so force it off.
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        """Maximum number of input frames accepted by `forward`: the number
        of encoder positions times the combined stride of conv1 and conv2."""
        return (self.config.max_source_positions * self.conv1.stride[0] *
                self.conv2.stride[0])

    def get_attention_mask_by_audio_len(self,
                                        audio_lens: Optional[torch.Tensor],
                                        hidden_states: torch.Tensor):
        """
        Create an additive attention mask that hides padding positions.

        For each sample in the batch:
        - convert the raw audio length (in input frames) to the feature
          length after the convolutions (`_get_feat_extract_output_lengths`)
        - build a bool mask: True for valid positions, False for padding
        - convert it with `get_extended_attention_mask` into the additive
          format expected by the encoder layers: 0.0 for positions to attend
          to and a large negative value (torch.finfo(dtype).min in HF's
          implementation) for positions to ignore.

        This masking ensures consistent behavior between training and inference
        by preventing the model from attending to padding tokens in both cases.

        Returns None (no masking) when `audio_lens` is None.
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        attention_mask = torch.arange(max_seq_len,
                                      device=hidden_states.device)[None, :].lt(
                                          audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: Optional[torch.Tensor] = None,
    ):
        """Encode mel input features that may be shorter than 30 seconds.

        Args:
            input_features: mel features; the last dimension (frames) must
                not exceed `max_context_length`.
            audio_lens: optional per-sample frame counts used to mask out
                padding in self-attention.

        Returns:
            Hidden states after the final layer norm.

        Raises:
            ValueError: if `input_features` has more than
                `max_context_length` frames.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}.")

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positional embeddings to the (possibly shorter) sequence.
        embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states,
                                              p=self.dropout,
                                              training=self.training)

        attention_mask = self.get_attention_mask_by_audio_len(
            audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                layer_head_mask=None,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


403
404
405
406
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: a Whisper-based audio tower and projector whose outputs are
    merged into the input embeddings of a vLLM text model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Checkpoint weights under "audio_tower.model.encoder." are loaded into
    # this module's "audio_tower.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        # Extra weight sources, added when the config names separate
        # audio/text checkpoints.
        self.secondary_weights = []
        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                ))
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(model_or_path=config.text_model_id,
                                          revision=None,
                                          prefix="language_model."))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor,
            audio_lens: torch.Tensor) -> torch.Tensor:
        """Encode audio chunks with the audio tower and projector.

        Chunks are processed in slices of at most `_MAX_ENCODER_BATCH_SIZE`
        and the per-slice outputs are concatenated along dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(audio_features[start:end],
                                              audio_lens[start:end])
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(batch_features)
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        """Extract audio inputs from `kwargs`.

        Returns None when neither `audio_features` nor `audio_embeds` is
        present. Features take precedence over embeddings.

        Raises:
            ValueError: on an unexpected input type, or when `audio_features`
                is given without `audio_lens` / `audio_token_len`.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")
            # Both are required by _process_audio_input (attention mask and
            # flattening); fail here with a clear message rather than deep
            # inside torch.cat / flatten_bn.
            if audio_lens is None or audio_token_len is None:
                raise ValueError(
                    "audio_lens and audio_token_len must be provided "
                    "together with audio_features.")

            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features,
                                              lens=audio_lens,
                                              token_len=audio_token_len)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
            self, audio_input: UltravoxAudioInputs) -> NestedTensors:
        """Convert parsed audio input into embeddings to merge.

        Precomputed embeddings are returned unchanged. Feature input is
        encoded chunk by chunk; each chunk's embedding is truncated to its
        `token_len`, and the kept rows of all chunks are concatenated.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        if isinstance(audio_input['lens'], list):
            # [B1, B2] -> [B1+B2]
            audio_lens = torch.cat(audio_input['lens'])
            audio_token_len = torch.cat(audio_input['token_len'])
        else:
            audio_lens = flatten_bn(audio_input['lens'])
            audio_token_len = flatten_bn(audio_input['token_len'])

        embeddings = self._audio_features_to_embeddings(
            audio_features, audio_lens)

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1)
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        return flattened_embeddings

    def get_multimodal_embeddings(
        self, **kwargs
    ) -> Optional[Union[list[torch.Tensor], torch.Tensor,
                        tuple[torch.Tensor, ...]]]:
        """Return audio embeddings for the audio inputs in `kwargs`, or None
        when there are none."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return None
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` and, if given, merge audio embeddings into the
        audio placeholder positions."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:

            # TODO(ywang96): remove this block after v0 is deprecated.
            if not envs.VLLM_USE_V1:
                attn_metadata = get_forward_context().attn_metadata
                merge_multimodal_embeddings_from_map(
                    inputs_embeds, multimodal_embeddings,
                    attn_metadata.multi_modal_placeholder_index_maps["audio"])
            else:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, multimodal_embeddings,
                    _AUDIO_PLACEHOLDER_TOKEN)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Token ids with audio placeholders already expanded.
            positions: Positions matching `input_ids`.
            intermediate_tensors: When given, `inputs_embeds` is ignored and
                these are passed through to the language model.
            inputs_embeds: Precomputed input embeddings (always provided by
                the V1 model runner).
            **kwargs: Multimodal inputs, e.g.
                audio_features: A batch of audio input chunks [B, N, 80, M].
                audio_lens: Length of audio frames for each audio chunk.
                audio_token_len: Number of audio tokens for each audio chunk.

        Returns:
            The language model's hidden states (or intermediate tensors).
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)

            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      multimodal_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Unexpected weights under "audio_tower." are skipped rather than
        # treated as an error.
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672


def pad_and_concat_to_dim3(
    features: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]]
) -> torch.Tensor:
    """
    Pad and concatenate a (possibly nested) list of tensors.

    A tensor with more than 3 dims has all of its leading dims flattened,
    e.g. [B, N, 80, M] -> [B * N, 80, M]. A tensor with at most 3 dims is
    returned unchanged.

    For list input, each element is reduced recursively, reshaped to
    [-1, C, M_i], zero-padded on the right of the last dim up to max(M_i),
    and the results are concatenated along dim 0.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.

    Raises:
        ValueError: if `features` is an empty list.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, ..., 80, M] -> [B * N * ..., 80, M]
            features = features.flatten(0, features.ndim - 3)
        return features

    if len(features) == 0:
        raise ValueError(
            "pad_and_concat_to_dim3 expects at least one tensor, got an "
            "empty list")

    tensors = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(t.shape[-1] for t in tensors)
    # Ensure all features have dim=3; reshape also copes with
    # non-contiguous inputs.
    tensors = [t.reshape(-1, *t.shape[-2:]) for t in tensors]
    # Pad and concatenate:
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    tensors = [F.pad(t, (0, max_len - t.shape[-1])) for t in tensors]
    return torch.cat(tensors)