ultravox.py 17.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""

import itertools
import math
from array import array
from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union, cast)

import librosa
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights,
                                              init_vllm_registered_model,
                                              merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

# Token id marking where audio embeddings go in the prompt; positions holding
# this id are filled with audio embeddings by `merge_multimodal_embeddings`.
# NOTE(review): presumably a reserved special token of the text model's
# tokenizer -- confirm against the checkpoint's vocabulary.
_AUDIO_PLACEHOLDER_TOKEN = 128_002
# Approximate number of audio embeddings produced per second of audio.
_AUDIO_TOKENS_PER_SECOND = 6.25

logger = init_logger(__name__)


class UltravoxAudioFeatureInputs(TypedDict):
    """Log-mel audio features that still need to go through the encoder."""

    type: Literal["audio_features"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size, 80, M)` for a single tensor, or a list of
    `(80, M)` tensors (presumably when `M` differs between items)."""


class UltravoxAudioEmbeddingInputs(TypedDict):
    """Precomputed audio embeddings that are used as-is, skipping the encoder."""

    # Tag that `UltravoxModel._process_audio_input` dispatches on.
    type: Literal["audio_embeds"]
    # NOTE(review): shape is not fixed here; presumably already in the text
    # model's hidden size -- confirm with callers.
    data: torch.Tensor


# Either form of audio input accepted by `UltravoxModel`; the "type" key
# tells them apart.
UltravoxAudioInputs = Union[
    UltravoxAudioFeatureInputs,
    UltravoxAudioEmbeddingInputs,
]


@lru_cache
def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
    """Load the Whisper feature extractor for `model_id`, memoized per id."""
    extractor = WhisperFeatureExtractor.from_pretrained(model_id)
    return extractor


def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
    """Return the cached feature extractor of the configured audio model."""
    hf_config = ctx.get_hf_config(UltravoxConfig)
    return cached_feature_extractor(hf_config.audio_model_id)


def get_ultravox_max_audio_tokens(ctx: InputContext):
    """Upper bound on the number of audio placeholder tokens for one clip.

    Computed as the feature extractor's `chunk_length` multiplied by
    `_AUDIO_TOKENS_PER_SECOND`, rounded up.
    """
    chunk_length = whisper_feature_extractor(ctx).chunk_length
    return math.ceil(chunk_length * _AUDIO_TOKENS_PER_SECOND)


def dummy_data_for_ultravox(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
):
    """Build dummy prompt token ids and audio inputs for profiling.

    The prompt begins with the maximum number of audio placeholder tokens
    for every requested clip and is filled up to `seq_len` with token id 0.
    Each dummy clip is `chunk_length` zero samples tagged with a sample rate
    of 1; the input mapper resamples it to `chunk_length * sampling_rate`
    samples.

    Returns:
        A `(SequenceData, {"audio": ...})` pair; the audio entry is a single
        `(samples, sample_rate)` tuple when one clip is requested, else a
        list of such tuples.
    """
    feature_extractor = whisper_feature_extractor(ctx)
    num_audios = mm_counts["audio"]

    num_placeholders = get_ultravox_max_audio_tokens(ctx) * num_audios
    placeholder_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                            [_AUDIO_PLACEHOLDER_TOKEN]) * num_placeholders
    filler_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - len(placeholder_ids))

    silent_clip = (np.array([0.0] * feature_extractor.chunk_length), 1)
    if num_audios == 1:
        audio_entry = silent_clip
    else:
        audio_entry = [silent_clip] * num_audios

    return (SequenceData(placeholder_ids + filler_ids), {"audio": audio_entry})


def input_mapper_for_ultravox(ctx: InputContext, data: object):
    """Turn raw `(audio, sample_rate)` data into Whisper input features.

    Audio is resampled to the extractor's sampling rate when it differs,
    and right-padded with zeros to at least `n_fft // 2 + 1` samples before
    feature extraction.

    Raises:
        NotImplementedError: if `data` is not a tuple.
    """
    if not isinstance(data, tuple):
        raise NotImplementedError(f"Unsupported data type: {type(data)}")

    audio, sr = cast(Tuple[np.ndarray, Union[float, int]], data)
    extractor = whisper_feature_extractor(ctx)
    target_sr = extractor.sampling_rate

    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        sr = target_sr

    # Clips that are too short get zero-padded on the right.
    shortfall = (extractor.n_fft // 2 + 1) - len(audio)
    if shortfall > 0:
        audio = np.pad(audio, (0, shortfall))

    features = extractor(audio,
                         sampling_rate=sr,
                         padding="longest",
                         return_tensors="pt")["input_features"]
    return MultiModalInputs({"audio_features": features})


def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
    """Expand the audio placeholder token to the expected audio length.

    The token count is estimated from the raw sample count (adjusted for
    resampling). There is one feature frame per `hop_length` samples, and
    the frame count is then divided by `stack_factor * 2` and rounded up.
    The result is clamped to `[1, get_ultravox_max_audio_tokens(ctx)]`.
    Inputs without audio are returned unchanged.
    """
    mm_data = llm_inputs.get("multi_modal_data")
    if mm_data is None or "audio" not in mm_data:
        return llm_inputs

    extractor = whisper_feature_extractor(ctx)
    audio, sample_rate = mm_data["audio"]

    num_samples = audio.shape[0]
    if sample_rate != extractor.sampling_rate:
        # Account for resampling.
        ratio = extractor.sampling_rate / sample_rate
        num_samples = math.ceil(ratio * num_samples)

    # Equivalent to ceil((num_samples - (hop_length - 1)) / hop_length) for
    # non-negative integer sample counts.
    num_frames = num_samples // extractor.hop_length

    hf_config = ctx.get_hf_config(UltravoxConfig)
    estimate = math.ceil(num_frames / (hf_config.stack_factor * 2))
    num_audio_tokens = min(max(1, estimate),
                           get_ultravox_max_audio_tokens(ctx))

    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
    prompt, token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
        repeat_count=num_audio_tokens,
    )

    # Build a new LLMInputs instead of mutating the caller's mapping.
    return LLMInputs(prompt_token_ids=token_ids,
                     prompt=prompt,
                     multi_modal_data=mm_data)


class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    The time axis is right-padded with zeros to a multiple of `stack_factor`.
    Each run of `stack_factor` consecutive frames is then concatenated along
    the channel axis:
    `(B, T, C) -> (B, ceil(T / stack_factor), C * stack_factor)`.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Pad and stack `audio_embeds` of shape `(B, T, C)`.

        `reshape` is used instead of `view` so that non-contiguous inputs
        also work. When no padding is needed, `F.pad` may return a copy that
        keeps the input's (possibly non-contiguous) strides, and `view`
        would then raise.
        """
        B, T, C = audio_embeds.shape
        # Round T up to the next multiple of stack_factor.
        T_pad = (T + self.stack_factor -
                 1) // self.stack_factor * self.stack_factor
        # Pad only the time axis (second-to-last dim) on the right.
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        return audio_embeds.reshape(B, T_pad // self.stack_factor,
                                    C * self.stack_factor)


class FlippedSiluAndMul(SiluAndMul):
    """Ultravox is trained with SwiGLU with flipped halves.

    The two halves of the last dimension are swapped before delegating to
    `SiluAndMul.forward`.
    """

    def forward(self, x: torch.Tensor):
        first_half, second_half = x.chunk(2, dim=-1)
        swapped = torch.cat([second_half, first_half], dim=-1)
        return super().forward(swapped)


class UltravoxProjector(nn.Module):
    """Map stacked audio-encoder frames into the text embedding space.

    Pipeline: stack frames -> RMSNorm -> linear -> activation -> linear ->
    RMSNorm.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)

        stacked_dim = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(stacked_dim)
        self.linear_1 = nn.Linear(stacked_dim, self.hidden_dim, bias=False)

        # SwiGLU consumes two halves of its input, so its output is half as
        # wide; other activations keep the width.
        use_swiglu = config.projector_act == "swiglu"
        self.act = (FlippedSiluAndMul()
                    if use_swiglu else get_act_fn(config.projector_act))
        act_out_dim = self.hidden_dim // 2 if use_swiglu else self.hidden_dim

        text_dim = config.text_config.hidden_size
        self.linear_2 = nn.Linear(act_out_dim, text_dim, bias=False)
        self.ln_post = RMSNorm(text_dim)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        x = self._pad_and_stack(audio_features)
        x = self.ln_pre(x)
        x = self.linear_1(x)
        x = self.act(x)
        x = self.linear_2(x)
        return self.ln_post(x)


class ModifiedWhisperEncoder(WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    A lightly modified copy of the HF Transformers Whisper encoder:
    1. `base_model_prefix` is changed so that `.from_pretrained` can be
       called directly on the encoder.
    2. Inputs padded to less than 30 seconds of audio are accepted:
        - the `input_features` length check requires the length to be at
          most `expected_seq_length` instead of exactly equal to it;
        - `embed_pos` is sliced to the length of `inputs_embeds`.

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    See commentary: https://github.com/huggingface/transformers/issues/25744
    """

    base_model_prefix = "model.encoder"

    def forward(
        self,
        input_features,
    ):
        max_len = (self.config.max_source_positions *
                   self.conv1.stride[0] * self.conv2.stride[0])
        got_len = input_features.shape[-1]
        if got_len > max_len:
            raise ValueError(
                f"Whisper expects the mel input features to be of length "
                f"{max_len} or less, but found "
                f"{got_len}. Make sure to pad the input mel "
                f"features to {max_len}.")

        hidden_states = F.gelu(self.conv1(input_features))
        hidden_states = F.gelu(self.conv2(hidden_states))
        # Move the time axis before the channel axis.
        hidden_states = hidden_states.permute(0, 2, 1)

        # Only take as many positional embeddings as there are frames.
        num_frames = hidden_states.size(-2)
        hidden_states = hidden_states + self.embed_positions.weight[:num_frames]
        hidden_states = F.dropout(hidden_states,
                                  p=self.dropout,
                                  training=self.training)

        for layer in self.layers:
            hidden_states = layer(hidden_states, None,
                                  layer_head_mask=None)[0]

        return self.layer_norm(hidden_states)


@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal):
    """Ultravox: Whisper audio encoder + projector + vLLM text model.

    Audio is encoded by `audio_tower` and mapped into the text embedding
    space by `multi_modal_projector`. The result is merged into the prompt
    embeddings at `_AUDIO_PLACEHOLDER_TOKEN` positions before running
    `language_model`.
    """

    def __init__(self,
                 config: UltravoxConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional["QuantizationConfig"] = None):
        super().__init__()
        self.config = config
        self.multi_modal_config = multimodal_config
        assert self.multi_modal_config

        if config.audio_model_id is not None:
            # Encoder weights come from the separate pretrained audio model.
            self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
                config.audio_model_id)
        else:
            # NOTE(review): built from `audio_config` without pretrained
            # weights, and `load_weights` below does not populate
            # `audio_tower`; confirm how these weights are meant to load.
            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
        self.multi_modal_projector = UltravoxProjector(config)
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

    def _audio_features_to_embeddings(
            self, input_features: torch.Tensor) -> torch.Tensor:
        """Run features through the audio encoder and then the projector."""
        audio_input = input_features.to(self.audio_tower.dtype)
        audio_features = self.audio_tower(audio_input)
        audio_features = audio_features.to(self.audio_tower.dtype)
        audio_embeddings = self.multi_modal_projector(audio_features)
        return audio_embeddings

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[UltravoxAudioInputs]:
        """Extract and validate audio inputs from the forward kwargs.

        Returns None when neither `audio_features` nor `audio_embeds` is
        present. If both are present, `audio_features` wins.

        Raises:
            ValueError: if the provided input has an unexpected type.
        """
        audio_features = kwargs.pop("audio_features", None)
        audio_embeds = kwargs.pop("audio_embeds", None)

        if audio_features is None and audio_embeds is None:
            return None

        if audio_features is not None:
            if not isinstance(audio_features, (torch.Tensor, list)):
                raise ValueError("Incorrect type of audio features. "
                                 f"Got type: {type(audio_features)}")

            # Remove the N dimension until multiple audios are supported.
            if isinstance(audio_features, torch.Tensor):
                audio_features = audio_features.squeeze(1)
            else:
                audio_features = [t.squeeze(0) for t in audio_features]

            return UltravoxAudioFeatureInputs(type="audio_features",
                                              data=audio_features)

        if audio_embeds is not None:
            if not isinstance(audio_embeds, torch.Tensor):
                raise ValueError("Incorrect type of audio embeds. "
                                 f"Got type: {type(audio_embeds)}")

            # Remove the N dimension until multiple audios are supported.
            audio_embeds = audio_embeds.squeeze(1)

            return UltravoxAudioEmbeddingInputs(type="audio_embeds",
                                                data=audio_embeds)

        raise AssertionError("This line should be unreachable.")

    def _process_audio_input(
        self, audio_input: UltravoxAudioInputs
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Return audio embeddings, encoding features when necessary."""
        if audio_input["type"] == "audio_embeds":
            return audio_input["data"]

        audio_features = audio_input["data"]
        if isinstance(audio_features, list):
            # TODO: Batch these through the encoder/projector instead of
            # serializing them.
            return [
                self._audio_features_to_embeddings(
                    features.unsqueeze(0)).squeeze(0)
                for features in audio_features
            ]
        else:
            return self._audio_features_to_embeddings(audio_features)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[torch.Tensor],
                **kwargs) -> torch.Tensor:
        """Run forward pass for Ultravox

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted audio embeddings. The to-be-inserted
        audio has a size that is essentially 6.25 tokens per second of audio.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Prompt token ids with the audio placeholder already
                expanded to the audio's token count.
            positions, kv_caches, attn_metadata, intermediate_tensors:
                Passed through unchanged to the language model.
            **kwargs: May carry `audio_features` (a batch of audio feature
                tensors, e.g. `[1, 80, M]` per audio) or `audio_embeds`
                (precomputed embeddings).

        Returns:
            The hidden states produced by the language model.
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is not None:
            audio_embeddings = self._process_audio_input(audio_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, audio_embeddings,
                _AUDIO_PLACEHOLDER_TOKEN)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the projector and language model.

        Weights whose first name component is `multi_modal_projector` are
        loaded into `self.multi_modal_projector`. Weights under
        `language_model` are stripped of that component and streamed to
        `self.language_model.load_weights`. Any other weights are not loaded
        here.

        This makes one pass over `weights`. Splitting the stream with
        `itertools.tee` and exhausting one branch first would keep every
        checkpoint tensor referenced in tee's buffer at the same time.
        """
        projector_params_dict = dict(
            self.multi_modal_projector.named_parameters())

        def route_weights() -> Iterable[Tuple[str, torch.Tensor]]:
            for full_name, loaded_weight in weights:
                # "<prefix>.<rest>" -> ("<prefix>", "<rest>")
                prefix, _, name = full_name.partition(".")
                if prefix == "multi_modal_projector":
                    param = projector_params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
                elif prefix == "language_model":
                    yield name, loaded_weight

        llm_weights = route_weights()
        self.language_model.load_weights(llm_weights)
        # If the LLM loader stopped early, still load the remaining
        # projector weights.
        for _ in llm_weights:
            pass