ultravox.py 28.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
6

7
import copy
8
import inspect
9
from collections.abc import Iterable, Mapping, Sequence
10
from types import SimpleNamespace
11
from typing import Annotated, Any, Literal, TypeAlias
12
13
14
15

import torch
from torch import nn
from torch.nn import functional as F
16
from transformers import BatchFeature, ProcessorMixin
17
from transformers.modeling_utils import ModuleUtilsMixin
18
from transformers.models.whisper import WhisperFeatureExtractor
19
20
21
22
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
)
23

24
from vllm.config import VllmConfig
25
from vllm.config.multimodal import BaseDummyOptions
26
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
27
from vllm.model_executor.layers.layernorm import RMSNorm
28
from vllm.model_executor.model_loader import DefaultModelLoader
29
from vllm.model_executor.models.module_mapping import MultiModelKeys
30
from vllm.multimodal import MULTIMODAL_REGISTRY
31
32
33
34
35
36
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
    NestedTensors,
)
37
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
38
from vllm.multimodal.processing import (
39
    BaseDummyInputsBuilder,
40
41
42
43
44
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
45
from vllm.renderers import TokenizeParams
46
from vllm.sequence import IntermediateTensors
47
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
48
from vllm.utils.tensor_schema import TensorSchema, TensorShape
49

50
51
52
53
54
55
56
57
58
59
60
61
62
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
63

64
# Reserved token that replaces the HF processor's default audio placeholder
# (assigned in UltravoxProcessingInfo.get_hf_processor).
_AUDIO_PLACEHOLDER_OVERRIDE: str = "<|audio|>"
# Maximum number of audio chunks passed through the audio tower per call;
# also used to size the dummy audio built for profiling inputs.
_MAX_ENCODER_BATCH_SIZE: int = 16
66
67


68
class UltravoxAudioFeatureInputs(TensorSchema):
    """
    Audio passed as mel features, where each audio may be split into
    several fixed-length chunks.

    Dimensions:
        - bn: Total number of audio chunks across all audios in the batch
        - nmb: Number of mel bins
        - t: Time frames (M)
        - n: Number of audios (`num_chunks` holds one entry per audio)
    """

    type: Literal["audio_features"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
        TensorShape("bn", "nmb", "t"),
    ]
    lens: Annotated[torch.Tensor, TensorShape("bn")]
    """
    Length of the audio frames per chunk. Used for attention mask in WhisperEncoder.
    """
    token_len: Annotated[torch.Tensor, TensorShape("bn")]
    """Length of the audio tokens per chunk. Used for flattening the audio features."""
    num_chunks: Annotated[torch.Tensor, TensorShape("n")]
    """Number of chunks per audio. Used for flattening the audio features."""
90
91
92


class UltravoxAudioEmbeddingInputs(TensorSchema):
    """
    Audio passed as precomputed embeddings, bypassing the audio tower.

    Dimensions:
        - b: Batch size
        - na: Number of audios
        - afs: Audio feature size
        - hs: Hidden size
    """

    type: Literal["audio_embeds"]
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("b", "na", "afs", "hs"),
    ]
105
106


107
108
109
# Either accepted audio input form; the variants are told apart by their
# ``type`` literal ("audio_features" or "audio_embeds").
UltravoxAudioInputs: TypeAlias = (
    UltravoxAudioFeatureInputs
    | UltravoxAudioEmbeddingInputs
)
110
111


112
class UltravoxProcessingInfo(BaseProcessingInfo):
    """Processor, feature-extractor and parser accessors for Ultravox."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """Return the HF Ultravox processor with its audio placeholder patched.

        NOTE: Ultravox processing definition uses '<|eot_id|>' as the
        placeholder, which is easily confused with the actual end-of-turn
        token, so it is replaced here by a reserved token whose id comes
        from the model config's ``audio_token_index``.
        """
        hf_config = self.ctx.model_config.hf_config
        processor = self.ctx.get_hf_processor(**kwargs)

        processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
        processor.audio_replacement_token_id = hf_config.audio_token_index
        return processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the Whisper feature extractor owned by the HF processor.

        Where it lives changed in
        https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/commit/9a3c571b8fdaf1e66dd3ea61bbcb6db5c70a438e
        so both layouts are accepted: ``audio_processor`` itself, or its
        ``feature_extractor`` attribute.
        """
        candidate = self.get_hf_processor(**kwargs).audio_processor  # type: ignore
        if not isinstance(candidate, WhisperFeatureExtractor):
            candidate = candidate.feature_extractor  # type: ignore
            assert isinstance(candidate, WhisperFeatureExtractor)
        return candidate

    def get_default_tok_params(self) -> TokenizeParams:
        """Default tokenization parameters, without special tokens."""
        base_params = super().get_default_tok_params()
        return base_params.with_kwargs(add_special_tokens=False)

    def get_data_parser(self):
        """Build the multimodal data parser using the feature extractor's
        sampling rate and this model's target channel count."""
        return MultiModalDataParser(
            target_sr=self.get_feature_extractor().sampling_rate,
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_target_channels(self) -> int:
        """Return target audio channels for Ultravox models (mono)."""
        mono = 1
        return mono

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only audio is supported, with no per-prompt limit."""
        limits: dict[str, int | None] = {"audio": None}
        return limits
155

156

157
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]):
    """Builds the dummy prompt text and dummy audio data for Ultravox."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Emit one ``<|audio|>`` placeholder per requested audio."""
        return "".join("<|audio|>" for _ in range(mm_counts.get("audio", 0)))

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy audios, each ``_MAX_ENCODER_BATCH_SIZE`` feature
        extractor chunks long (``chunk_length`` seconds per chunk)."""
        extractor = self.info.get_feature_extractor()
        samples_per_chunk = extractor.chunk_length * extractor.sampling_rate

        dummy_audios = self._get_dummy_audios(
            length=samples_per_chunk * _MAX_ENCODER_BATCH_SIZE,
            num_audios=mm_counts.get("audio", 0),
            overrides=mm_options.get("audio"),
        )
        return {"audio": dummy_audios}


188
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
    """Runs the HF Ultravox processor and expands ``<|audio|>`` placeholders."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize ``prompt`` and extract features for its audios.

        Text-only prompts are tokenized directly, since the composite
        processor does not support them. For audio prompts, the processor's
        ``audio_values`` output is renamed to ``audio_features``.

        Neither ``mm_kwargs`` nor ``tok_kwargs`` is mutated.
        """
        # Text-only input not supported in composite processor
        if not mm_data.get("audios", []):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        assert isinstance(audios, list)

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        # A dict literal (rather than `dict(**mm_kwargs, sampling_rate=...)`)
        # lets these values override caller-supplied ones instead of raising
        # "got multiple values for keyword argument".
        mm_kwargs = {
            **mm_kwargs,
            "sampling_rate": feature_extractor.sampling_rate,
            "include_audio_num_chunks": True,
        }

        item_processor_data = dict(**mm_data, audios=audios)

        # Some tokenizer kwargs are incompatible with UltravoxProcessor.
        # Filter them into a new dict instead of popping from the caller's
        # mapping, which must not be modified.
        incompatible_tok_kwargs = {"add_special_tokens", "padding", "truncation"}
        tok_kwargs = {
            key: value
            for key, value in tok_kwargs.items()
            if key not in incompatible_tok_kwargs
        }

        output = super()._call_hf_processor(
            prompt=prompt,
            mm_data=item_processor_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        output["audio_features"] = output.pop("audio_values")

        return output

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output field maps onto audio items."""
        num_chunks = hf_inputs.get("audio_num_chunks", torch.zeros(0))
        return dict(
            # to handle longer than 30s audio, each audio might be split
            # into multiple chunks as such, their batch dimension can be
            # higher than the number of audio samples
            audio_features=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_token_len=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            audio_lens=MultiModalFieldConfig.flat_from_sizes("audio", num_chunks),
            # num_chunks can convert audio_chunked to audio batch dimension
            audio_num_chunks=MultiModalFieldConfig.batched("audio"),
            audio_embeds=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each ``<|audio|>`` with one replacement token per audio
        token, summed over all chunks of that audio."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        replacement_id = hf_processor.audio_replacement_token_id  # type: ignore

        # Each audio can be split into multiple chunks.
        # chunks_start_idx[i] indicates the start index of the chunks
        # belonging to the i-th audio.
        out_mm_data = out_mm_kwargs.get_data()
        num_chunks = out_mm_data.get("audio_num_chunks", torch.zeros(0))
        chunks_start_idx: torch.Tensor = torch.cumsum(
            num_chunks, dim=0, dtype=torch.int32
        )
        chunks_start_idx = torch.cat(
            [torch.tensor([0], dtype=torch.int32), chunks_start_idx]
        )

        def get_replacement_ultravox(item_idx: int):
            start = chunks_start_idx[item_idx]
            end = chunks_start_idx[item_idx + 1]
            audio_token_len = out_mm_data["audio_token_len"][start:end].sum()
            return [replacement_id] * int(audio_token_len)  # type: ignore

        return [
            PromptReplacement(
                modality="audio",
                target="<|audio|>",
                replacement=get_replacement_ultravox,
            )
        ]
285
286
287
288
289
290
291
292
293
294
295
296
297
298


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    An input of shape ``(B, T, C)`` is zero-padded at the end of the time axis
    to a multiple of ``stack_factor`` and reshaped to
    ``(B, ceil(T / stack_factor), C * stack_factor)``.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Fold every ``stack_factor`` consecutive frames into one frame.

        Args:
            audio_embeds: tensor of shape ``(B, T, C)``; may be non-contiguous.

        Returns:
            Tensor of shape ``(B, ceil(T / stack_factor), C * stack_factor)``.
        """
        B, T, C = audio_embeds.shape
        # Round T up to the next multiple of stack_factor.
        T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
        # Zero-pad only the end of the time axis (dim 1); channels untouched.
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        # `reshape` instead of `view`: a zero-copy view for contiguous inputs,
        # but it also accepts non-contiguous ones (e.g. when no padding was
        # needed and the input was a transposed tensor) instead of raising.
        return audio_embeds.reshape(
            B, T_pad // self.stack_factor, C * self.stack_factor
        )


308
class UltravoxFeedForwardProjector(nn.Module):
    """Two-layer MLP mapping stacked audio-encoder frames to the text model's
    hidden size."""

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)

        stacked_dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(stacked_dim)
        self.linear_1 = nn.Linear(stacked_dim, self.hidden_dim, bias=False)

        use_swiglu = config.projector_act == "swiglu"
        if use_swiglu:
            # MulAndSilu halves the feature width, so linear_2 receives
            # hidden_dim // 2 features.
            self.act = MulAndSilu()
            inner_dim = self.hidden_dim // 2
        else:
            self.act = get_act_fn(config.projector_act)
            inner_dim = self.hidden_dim

        text_dim = config.text_config.hidden_size
        self.linear_2 = nn.Linear(inner_dim, text_dim, bias=False)

        # Ultravox v0.4.1 and below use layer_norm after the second linear layer
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if not config.projector_ln_mid:
            self.ln_mid: nn.Module = nn.Identity()
            self.ln_post = RMSNorm(text_dim)
        else:
            self.ln_mid = RMSNorm(inner_dim)
            self.ln_post = nn.Identity()

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project audio encoder outputs into the text embedding space.

        ``audio_token_len`` is unused here; it is accepted so this projector
        is call-compatible with UltravoxTransformerProjector.
        """
        stacked = self.ln_pre(self._pad_and_stack(audio_features))
        mid = self.ln_mid(self.act(self.linear_1(stacked)))
        return self.ln_post(self.linear_2(mid))


349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
    """Maps stacked audio-encoder frames to the text model's hidden size
    through a stack of Whisper encoder layers.

    Selected instead of the feed-forward projector when
    ``config.num_projector_layers > 0``.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        # Minimal config stand-in exposing only `is_decoder`, presumably for
        # ModuleUtilsMixin.get_extended_attention_mask used in forward.
        self.config = SimpleNamespace(is_decoder=False)

        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim_in = config.audio_config.hidden_size * config.stack_factor

        projector_audio_config = copy.deepcopy(config.audio_config)

        self.ln_pre = RMSNorm(dim_in)
        self.linear_in = nn.Linear(dim_in, projector_audio_config.d_model)

        self.embed_positions = nn.Embedding(
            projector_audio_config.max_source_positions,
            projector_audio_config.d_model,
        )

        self.layers = nn.ModuleList(
            [
                WhisperEncoderLayer(projector_audio_config)
                for _ in range(config.num_projector_layers)
            ]
        )

        self.ln_post = RMSNorm(projector_audio_config.d_model)
        self.linear_out = nn.Linear(
            projector_audio_config.d_model, config.text_config.hidden_size
        )

        # Backward compatibility for Transformers v4 where layer_head_mask
        # was a required argument for WhisperEncoderLayer.forward. The layer
        # signature is fixed once the layers exist, so inspect it once here
        # rather than on every forward call (and skip it when there are no
        # layers instead of failing on `self.layers[0]`).
        self._layer_kwargs: dict[str, Any] = {}
        if len(self.layers) > 0 and (
            "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters
        ):
            self._layer_kwargs["layer_head_mask"] = None

    def forward(
        self, audio_features: torch.Tensor, audio_token_len: torch.Tensor
    ) -> torch.Tensor:
        """Project audio encoder outputs into the text embedding space.

        Args:
            audio_features: audio encoder output of shape ``(B, T, C)``;
                frames are stacked by ``stack_factor`` first.
            audio_token_len: number of valid stacked frames per sample;
                positions at or beyond it are masked out of attention.
        """
        audio_features = self._pad_and_stack(audio_features)

        max_len_stacked = audio_features.shape[1]
        attention_mask = torch.arange(max_len_stacked, device=audio_features.device)[
            None, :
        ].lt(audio_token_len[:, None])
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, attention_mask.shape, audio_features.dtype
        )

        hidden_states = self.ln_pre(audio_features)
        hidden_states = self.linear_in(hidden_states)

        positions = self.embed_positions(
            torch.arange(hidden_states.size(1), device=hidden_states.device)
        )
        hidden_states = hidden_states + positions

        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=extended_attention_mask,
                **self._layer_kwargs,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers'
    Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained`
       directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less
           than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.config.is_decoder = False

        # Backward compatibility for Transformers v4 where layer_head_mask
        # was a required argument for WhisperEncoderLayer.forward. The layer
        # signature cannot change after construction, so inspect it once here
        # rather than on every forward pass (skipping it if there are no
        # layers instead of failing on `self.layers[0]`).
        self._layer_kwargs: dict[str, Any] = {}
        if len(self.layers) > 0 and (
            "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters
        ):
            self._layer_kwargs["layer_head_mask"] = None

    @property
    def max_context_length(self):
        """Longest accepted mel input: max encoder positions times the
        strides of both convolutions."""
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def get_attention_mask_by_audio_len(
        self, audio_lens: torch.Tensor | None, hidden_states: torch.Tensor
    ):
        """
        Create an attention mask that hides padding positions.

        For each sample in the batch:
        - Convert its mel-frame length to the encoder sequence length after
          the convolutions (`_get_feat_extract_output_lengths`)
        - Build a bool mask: True for valid positions, False for padding
        - Convert it with `get_extended_attention_mask` into the additive
          format expected by the transformer layers: 0.0 for positions to
          attend to and a large negative value for positions to ignore

        This masking keeps inference consistent with training by preventing
        the model from attending to padding in both cases.

        Returns None when `audio_lens` is None (no masking).
        """
        if audio_lens is None:
            return None

        audio_feature_len = self._get_feat_extract_output_lengths(audio_lens)
        max_seq_len = hidden_states.shape[1]
        attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
            None, :
        ].lt(audio_feature_len.view(-1, 1))
        attention_mask = self.get_extended_attention_mask(
            attention_mask,
            None,
            dtype=hidden_states.dtype,
        )
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor | None = None,
    ):
        """Encode mel features, allowing inputs shorter than the full window.

        Args:
            input_features: mel features whose last dim (frames) must not
                exceed `max_context_length`.
            audio_lens: optional per-sample mel-frame lengths used to mask
                padding; when None no attention mask is applied.

        Returns:
            Encoder hidden states after the final layer norm.

        Raises:
            ValueError: if the input is longer than `max_context_length`.
        """
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{expected_seq_length} or less, but found "
                f"{input_features.shape[-1]}. Make sure to pad the input mel "
                f"features to {expected_seq_length}."
            )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice positional embeddings to the (possibly shorter) input length.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                **self._layer_kwargs,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states


524
525
526
@MULTIMODAL_REGISTRY.register_processor(
    UltravoxMultiModalProcessor,
    info=UltravoxProcessingInfo,
    dummy_inputs=UltravoxDummyInputsBuilder,
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Ultravox: Whisper audio tower + projector feeding a vLLM text model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for audio; other modalities raise."""
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: UltravoxConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        self.configure_mm_token_handling(
            self.config.vocab_size,
            [self.config.audio_token_index],
        )

        self.secondary_weights = []
        if config.audio_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.audio_model_id,
                    revision=None,
                    prefix="audio_tower.",
                )
            )
        if config.text_model_id is not None:
            # this prefix is not for initialization, but for loading weights
            # note the trailing dot
            self.secondary_weights.append(
                DefaultModelLoader.Source(
                    model_or_path=config.text_model_id,
                    revision=None,
                    prefix="language_model.",
                )
            )

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
            if config.num_projector_layers > 0:
                self.multi_modal_projector = UltravoxTransformerProjector(config)
            else:
                self.multi_modal_projector = UltravoxFeedForwardProjector(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.wrapped_model_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _audio_features_to_embeddings(
        self,
        input_features: torch.Tensor,
        audio_lens: torch.Tensor,
        audio_token_len: torch.Tensor,
    ) -> torch.Tensor:
        """Run audio chunks through the audio tower and projector.

        Chunks are processed in slices of at most ``_MAX_ENCODER_BATCH_SIZE``
        and the per-slice results are concatenated along dim 0.
        """
        audio_features = input_features.to(self.audio_tower.dtype)
        batch_size = audio_features.size(0)
        audio_embeddings = []

        # Process audio features in batches to keep memory usage predictable
        for start in range(0, batch_size, _MAX_ENCODER_BATCH_SIZE):
            end = min(start + _MAX_ENCODER_BATCH_SIZE, batch_size)
            # Process through audio tower
            batch_features = self.audio_tower(
                audio_features[start:end], audio_lens[start:end]
            )
            batch_features = batch_features.to(self.audio_tower.dtype)

            # Process through projector
            batch_embeddings = self.multi_modal_projector(
                batch_features, audio_token_len[start:end]
            )
            audio_embeddings.append(batch_embeddings)

        # Concatenate results
        audio_embeddings = torch.cat(audio_embeddings, dim=0)
        return audio_embeddings

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> UltravoxAudioInputs | None:
        """Build the typed audio input from kwargs, or None if no audio.

        ``audio_features`` takes precedence over ``audio_embeds``.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)
        audio_lens = kwargs.pop("audio_lens", None)
        audio_token_len = kwargs.pop("audio_token_len", None)
        audio_num_chunks = kwargs.pop("audio_num_chunks", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            return UltravoxAudioFeatureInputs(
                type="audio_features",
                data=audio_features,
                lens=audio_lens,
                token_len=audio_token_len,
                num_chunks=audio_num_chunks,
            )

        if audio_embeds is not None:
            return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self,
        audio_input: UltravoxAudioInputs,
    ) -> NestedTensors | tuple[torch.Tensor, ...]:
        """Convert audio input into one embedding tensor per audio.

        Precomputed embeddings are returned unchanged. Mel features are
        encoded chunk-wise, stripped of padding using ``token_len``, and
        regrouped per audio using ``num_chunks``.
        """
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        # Pad and concatenate audio features
        # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
        audio_features = pad_and_concat_to_dim3(audio_input["data"])

        audio_lens = audio_input["lens"]
        audio_token_len = audio_input["token_len"]

        embeddings = self._audio_features_to_embeddings(
            audio_features, audio_lens, audio_token_len
        )

        # We should flatten and concatenate embeddings based on token lengths
        # For example, with token_len = [4, 2, 3], flattened_embeddings will be
        # concat(embeddings[0][:4], embeddings[1][:2], embeddings[2][:3])

        # Create a mask of valid indices based on token lengths
        max_len = embeddings.shape[1]
        indices = torch.arange(max_len, device=embeddings.device).expand(
            embeddings.shape[0], -1
        )
        mask = indices < audio_token_len[:, None]
        # Apply mask and flatten
        flattened_embeddings = embeddings[mask]

        # Return one tensor per input audio. The per-audio sums are computed
        # on a host copy of `audio_token_len`: one device->host transfer
        # instead of one `.item()` synchronization per audio.
        num_chunks = audio_input["num_chunks"].tolist()
        embed_lens = [
            int(chunk_lens.sum())
            for chunk_lens in audio_token_len.cpu().split(num_chunks)
        ]
        return flattened_embeddings.split(embed_lens)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-audio embeddings, or an empty list without audio."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeddings = self._process_audio_input(audio_input)
        return audio_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed token ids, merging multimodal embeddings when provided."""
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and the attention metadata are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.
        """
        # When intermediate tensors are supplied (presumably a later pipeline
        # stage), input embeddings are not used.
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Unwrap a nested language model wrapper if the text model has one.
        language_model = self.language_model
        if hasattr(language_model, "language_model"):
            language_model = language_model.language_model

        hidden_states = language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping HF audio-tower prefixes and
        tolerating unexpected ``audio_tower.`` weights."""
        loader = AutoWeightsLoader(self, ignore_unexpected_prefixes=["audio_tower."])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
770
771
772


def pad_and_concat_to_dim3(
773
    features: torch.Tensor | list[torch.Tensor] | list[list[torch.Tensor]],
774
775
776
777
778
779
780
781
782
783
784
785
786
) -> torch.Tensor:
    """
    Pad and concatenate a list of tensors.

    output:
        Tensor of shape [B, C, M] where M is the maximum length of the input
        tensors, B is the sum of the batch sizes of the input tensors.
        C must be the same for all input tensors.
    """
    if isinstance(features, torch.Tensor):
        if features.ndim > 3:
            # Flatten [B, N, 80, M] -> [B * N, 80, M]
            features = flatten_bn(features)
787

788
789
790
791
792
793
794
        return features

    features = [pad_and_concat_to_dim3(f) for f in features]

    max_len = max(f.shape[-1] for f in features)
    # Ensure all features have dim=3
    features = [f.view(-1, *f.shape[-2:]) for f in features]
795
    # Pad and concatenate:
796
797
    # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
    features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
zhuwenwen's avatar
zhuwenwen committed
798
    return torch.cat(features)