"tests/planner/scaling/disagg_planner_throughput.yaml" did not exist on "157714aa9dd3c651afe5300c679a432cf1c96ba8"
musicflamingo.py 15.4 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2026 The vLLM team.
# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
19

20
21
22
from collections.abc import Callable, Mapping, Sequence
from math import pi
from typing import Annotated, Any, Optional, TypeAlias
23

24
25
26
27
28
29
30
import torch
from torch import Tensor, broadcast_tensors, nn
from transformers import BatchFeature
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.models.musicflamingo import (
    MusicFlamingoConfig,
    MusicFlamingoProcessor,
31
32
)

33
34
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
35
from vllm.multimodal import MULTIMODAL_REGISTRY
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityData,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.utils.tensor_schema import TensorShape
54
55
56

from .audioflamingo3 import (
    AudioFlamingo3DummyInputsBuilder,
57
58
59
    AudioFlamingo3EmbeddingInputs,
    AudioFlamingo3Encoder,
    AudioFlamingo3FeatureInputs,
60
    AudioFlamingo3ForConditionalGeneration,
61
    AudioFlamingo3MultiModalDataParser,
62
    AudioFlamingo3MultiModalProcessor,
63
64
65
66
    AudioFlamingo3MultiModalProjector,
    AudioFlamingo3ProcessingInfo,
    _audioflamingo3_field_config,
    _count_audio_tokens_from_mask,
67
68
)

69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

def rotate_half(x):
    """Rotate interleaved feature pairs by 90 degrees.

    The last dimension is viewed as consecutive ``(x1, x2)`` pairs and each
    pair is mapped to ``(-x2, x1)``; the last dimension must be even.
    """
    x = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


def apply_rotary_time_emb(hidden_states, cos, sin):
    """Apply a rotary time embedding to the leading features of ``hidden_states``.

    Only the first ``cos.shape[-1]`` features are rotated; any remaining
    features are passed through unchanged. The rotation is computed in
    float64 and the result is cast back to the input dtype.

    Args:
        hidden_states: Tensor whose last dimension holds the features.
        cos: Cosine table; its last dimension is the rotary width and it must
            broadcast against ``hidden_states[..., :rot_dim]``.
        sin: Sine table, expected to match ``cos`` in shape.

    Returns:
        Tensor with the same shape and dtype as ``hidden_states``.

    Raises:
        ValueError: If the rotary width exceeds the feature dimension.
    """
    original_dtype = hidden_states.dtype
    # Upcast for the rotation; cos/sin follow the upcast tensor's dtype/device.
    hidden_states = hidden_states.to(torch.float64)
    cos = cos.to(hidden_states)
    sin = sin.to(hidden_states)
    rot_dim = cos.shape[-1]
    if rot_dim > hidden_states.shape[-1]:
        raise ValueError(
            f"feature dimension {hidden_states.shape[-1]} is not of "
            f"sufficient size to rotate in all the positions {rot_dim}"
        )

    rotated = hidden_states[..., :rot_dim]
    passthrough = hidden_states[..., rot_dim:]
    rotated = (rotated * cos) + (rotate_half(rotated) * sin)
    return torch.cat((rotated, passthrough), dim=-1).to(original_dtype)
class MusicFlamingoRotaryEmbedding(nn.Module):
    """Timestamp-driven rotary embedding for MusicFlamingo audio features.

    ``forward`` builds cos/sin tables by concatenating two frequency halves:
    one derived from each row's index in ``timestamps`` and one taken from a
    position-angle table precomputed for ``config.max_position_embeddings``
    positions. Both halves are then scaled by ``-2*pi*timestamps``.
    """

    inv_freq: torch.Tensor

    def __init__(self, config: MusicFlamingoConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: derived from the config, so they are kept
        # out of the module's state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        position_angles = self._compute_position_angles(self.inv_freq)
        self.register_buffer("position_angles", position_angles, persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: MusicFlamingoConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """Compute standard RoPE inverse frequencies from ``config``.

        The base is ``rope_parameters["rope_theta"]`` and the rotary dimension
        is ``head_dim`` (falling back to ``hidden_size // num_attention_heads``).
        ``seq_len`` is accepted (presumably to match the signature of the
        ``ROPE_INIT_FUNCTIONS`` entries) and ignored.

        Returns:
            ``(inv_freq, attention_factor)``: ``inv_freq`` is a float32 tensor
            with ``ceil(dim / 2)`` entries and ``attention_factor`` is 1.0.
        """
        del seq_len
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or (
            config.hidden_size // config.num_attention_heads
        )
        attention_factor = 1.0

        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, dim, 2, dtype=torch.int64).to(
                    device=device,
                    dtype=torch.float,
                )
                / dim
            )
        )
        return inv_freq, attention_factor

    def _compute_position_angles(self, inv_freq):
        """Return the per-position angle table.

        Positions ``0..max_seq_len_cached-1`` are normalized to ``[0, 2*pi)``,
        multiplied by ``inv_freq`` and each frequency is repeated twice, giving
        shape ``(max_seq_len_cached, 2 * len(inv_freq))``.
        """
        positions = torch.arange(
            int(self.max_seq_len_cached),
            device=inv_freq.device,
            dtype=inv_freq.dtype,
        )
        positions = positions / self.max_seq_len_cached * (2 * pi)
        position_angles = positions.unsqueeze(-1) * inv_freq
        position_angles = torch.repeat_interleave(position_angles, 2, dim=-1)
        return position_angles.to(dtype=inv_freq.dtype)

    @torch.no_grad()
    def forward(self, timestamps: Tensor, seq_len: int) -> tuple[Tensor, Tensor]:
        """Build ``(cos, sin)`` tables for ``seq_len`` time steps.

        Args:
            timestamps: Row-per-chunk timestamp tensor; expected shape
                ``(num_chunks, seq_len)`` so it broadcasts against the
                frequency table (TODO confirm against the HF processor).
            seq_len: Number of time steps taken from the position table.

        Returns:
            ``(cos, sin)``, each with last dimension ``4 * len(inv_freq)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed table length.
        """
        max_positions = self.position_angles.shape[0]
        if seq_len > max_positions:
            # Slicing below would silently truncate and fail later with an
            # opaque broadcast error; report the real cause instead.
            raise ValueError(
                f"seq_len={seq_len} exceeds the {max_positions} positions "
                "precomputed from config.max_position_embeddings."
            )

        batch_positions = torch.arange(
            timestamps.shape[0],
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        batch_positions = batch_positions / self.max_seq_len_cached
        batch_freqs = batch_positions.unsqueeze(-1) * self.inv_freq
        batch_freqs = torch.repeat_interleave(batch_freqs, 2, dim=-1)

        batch_freqs = batch_freqs[:, None, :]
        time_freqs = self.position_angles[:seq_len][None, :, :]
        batch_freqs, time_freqs = broadcast_tensors(batch_freqs, time_freqs)
        freqs = torch.cat((batch_freqs, time_freqs), dim=-1)
        angle = (-timestamps * 2 * pi).to(freqs)
        freqs = freqs * angle.unsqueeze(-1)
        return freqs.cos(), freqs.sin()


class MusicFlamingoFeatureInputs(AudioFlamingo3FeatureInputs):
    """AudioFlamingo3 feature inputs extended with per-chunk timestamps.

    Dimensions:
        - num_chunks: Number of audio chunks (flattened across audio items).
        - num_audio_time_steps: Time steps per chunk (declared dynamic).
    """

    # Timestamps used to build the rotary time embedding that the model
    # applies to the audio encoder outputs; one row per audio chunk.
    rote_timestamps: Annotated[
        torch.Tensor,
        TensorShape(
            "num_chunks",
            "num_audio_time_steps",
            dynamic_dims={"num_audio_time_steps"},
        ),
    ]


# Precomputed audio embeddings reuse the AudioFlamingo3 schema as-is.
MusicFlamingoEmbeddingInputs = AudioFlamingo3EmbeddingInputs

# Either raw audio features (with timestamps) or precomputed embeddings.
MusicFlamingoInputs: TypeAlias = (
    MusicFlamingoFeatureInputs
    | MusicFlamingoEmbeddingInputs
)


class MusicFlamingoEncoder(AudioFlamingo3Encoder):
    """MusicFlamingo audio encoder; inherits AudioFlamingo3Encoder unchanged."""


class MusicFlamingoMultiModalProjector(AudioFlamingo3MultiModalProjector):
    """MusicFlamingo projector; inherits the AudioFlamingo3 projector unchanged."""


class MusicFlamingoProcessingInfo(AudioFlamingo3ProcessingInfo):
    """Processing info bound to the MusicFlamingo HF config and processor."""

    def get_hf_config(self) -> MusicFlamingoConfig:
        """Return the model's HF config typed as ``MusicFlamingoConfig``."""
        return self.ctx.get_hf_config(MusicFlamingoConfig)

    def get_hf_processor(self, **kwargs: object) -> MusicFlamingoProcessor:
        """Return the HF ``MusicFlamingoProcessor``; ``kwargs`` are forwarded."""
        return self.ctx.get_hf_processor(MusicFlamingoProcessor, **kwargs)

    def get_data_parser(self) -> MultiModalDataParser:
        """Build the MusicFlamingo data parser.

        The parser targets the feature extractor's sampling rate and the
        expected embedding hidden size.
        """
        feature_extractor = self.get_feature_extractor()
        return MusicFlamingoMultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Allow at most one audio item per prompt."""
        return {"audio": 1}


class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):
    """Builds dummy text and audio inputs for MusicFlamingo."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one audio placeholder token per requested audio item."""
        num_audios = mm_counts.get("audio", 0)
        hf_processor = self.info.get_hf_processor()
        return hf_processor.audio_token * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy audio clips of the processor's maximum length.

        The clip length in samples is ``max_audio_len * sampling_rate``.
        ``seq_len`` is not used. Any ``"audio"`` entry in ``mm_options`` is
        forwarded as overrides.
        """
        hf_processor = self.info.get_hf_processor()
        feature_extractor = self.info.get_feature_extractor()
        sampling_rate = feature_extractor.sampling_rate
        audio_len = int(hf_processor.max_audio_len * sampling_rate)
        num_audios = mm_counts.get("audio", 0)
        audio_overrides = mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=audio_len,
                num_audios=num_audios,
                overrides=audio_overrides,
            )
        }


def _musicflamingo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    """Extend the AudioFlamingo3 field config with a ``rote_timestamps`` entry.

    When ``chunk_counts`` is present, the timestamps are split per audio item
    along dim 0 using those counts. Otherwise they are configured as batched
    per audio item.
    """
    config = dict(_audioflamingo3_field_config(hf_inputs))
    sizes = hf_inputs.get("chunk_counts")
    config["rote_timestamps"] = (
        MultiModalFieldConfig.batched("audio")
        if sizes is None
        else MultiModalFieldConfig.flat_from_sizes("audio", sizes, dim=0)
    )
    return config


class MusicFlamingoMultiModalDataParser(AudioFlamingo3MultiModalDataParser):
    """Audio data parser that also accepts precomputed embeddings as a dict."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        # Non-dict inputs are handled by the AudioFlamingo3 parser.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        # A dict is treated as embedding items: ``audio_embeds`` is required
        # and field layouts come from the MusicFlamingo field config.
        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_musicflamingo_field_config,
        )


class MusicFlamingoMultiModalProcessor(AudioFlamingo3MultiModalProcessor):
    """Multimodal processor adding per-audio chunk counts and timestamp fields."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: dict[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the parent processor, then attach per-audio ``chunk_counts``.

        Each audio's count is its sample length divided by the window size
        (``sampling_rate * chunk_length``), rounded up. The count is at least 1
        and at most ``max_audio_len // chunk_length``.

        Raises:
            KeyError: If audio was supplied but the output lacks
                ``rote_timestamps``.
        """
        batch = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        raw_audio = mm_data.get("audio")
        if raw_audio is None:
            return batch
        audios = raw_audio if isinstance(raw_audio, list) else [raw_audio]
        if not audios:
            return batch

        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        extractor = hf_processor.feature_extractor
        rate = extractor.sampling_rate
        seconds_per_window = extractor.chunk_length
        samples_per_window = int(rate * seconds_per_window)
        window_cap = int(hf_processor.max_audio_len // seconds_per_window)

        def _windows_for(clip) -> int:
            # Ceil-divide the clip length by the window size, clamped to
            # the range [1, window_cap].
            num_samples = len(clip) if isinstance(clip, list) else clip.shape[0]
            windows = (num_samples + samples_per_window - 1) // samples_per_window
            return min(max(1, windows), window_cap)

        batch["chunk_counts"] = torch.tensor(
            [_windows_for(clip) for clip in audios], dtype=torch.long
        )

        if "rote_timestamps" not in batch:
            raise KeyError(
                "MusicFlamingoProcessor output must include `rote_timestamps`."
            )
        return batch

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Delegate to the module-level MusicFlamingo field config."""
        return _musicflamingo_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio placeholder with BOS + N audio tokens + EOS.

        N comes from ``feature_attention_mask`` (via
        ``_count_audio_tokens_from_mask``) when present. Otherwise it is the
        first dimension of the item's ``audio_embeds``. Token ids are looked up
        in the tokenizer vocabulary, falling back to the processor's ids.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        vocab = self.info.get_tokenizer().get_vocab()

        audio_token = hf_processor.audio_token
        audio_id = vocab.get(audio_token, hf_processor.audio_token_id)
        bos_id = vocab.get(
            hf_processor.audio_bos_token, hf_processor.audio_bos_token_id
        )
        eos_id = vocab.get(
            hf_processor.audio_eos_token, hf_processor.audio_eos_token_id
        )

        item_data = out_mm_kwargs.get_data()
        mask = item_data.get("feature_attention_mask")
        chunk_counts = item_data.get("chunk_counts")

        def _audio_replacement(item_idx: int):
            if mask is None:
                num_tokens = item_data["audio_embeds"][item_idx].shape[0]
            else:
                num_tokens = _count_audio_tokens_from_mask(
                    mask, chunk_counts, item_idx
                )
            if num_tokens == 0:
                raise ValueError("Audio is too short")

            token_ids = [bos_id] + [audio_id] * int(num_tokens) + [eos_id]
            return PromptUpdateDetails.select_token_id(
                token_ids, embed_token_id=audio_id
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=_audio_replacement,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    MusicFlamingoMultiModalProcessor,
    info=MusicFlamingoProcessingInfo,
    dummy_inputs=MusicFlamingoDummyInputsBuilder,
)
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
    """vLLM MusicFlamingo model aligned with HF modular_musicflamingo.

    Reuses the AudioFlamingo3 pipeline. It also applies a rotary time
    embedding, driven by ``rote_timestamps``, to the audio encoder outputs
    before the multimodal projector.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Rebind the audio modules to the MusicFlamingo classes (presumably
        # replacing modules built by the parent constructor — confirm).
        self.audio_tower = MusicFlamingoEncoder(self.config.audio_config)
        self.multi_modal_projector = MusicFlamingoMultiModalProjector(self.config)
        self.pos_emb = MusicFlamingoRotaryEmbedding(self.config)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MusicFlamingoInputs | None:
        """Parse audio kwargs, attaching ``rote_timestamps`` to feature inputs.

        ``rote_timestamps`` is popped before delegating so the parent parser
        never sees it. Embedding inputs and ``None`` are returned unchanged.
        """
        rote_timestamps = kwargs.pop("rote_timestamps", None)
        audio_input = super()._parse_and_validate_audio_input(**kwargs)
        if audio_input is None or audio_input["type"] == "audio_embeds":
            return audio_input

        return MusicFlamingoFeatureInputs(
            type="audio_features",
            input_features=audio_input["input_features"],
            feature_attention_mask=audio_input["feature_attention_mask"],
            chunk_counts=audio_input["chunk_counts"],
            rote_timestamps=rote_timestamps,
        )

    def _process_audio_input(
        self, audio_input: MusicFlamingoInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode audio features, apply the rotary time embedding, and project.

        Embedding inputs are delegated to the parent implementation.

        Raises:
            ValueError: If feature inputs lack ``rote_timestamps``.
        """
        if audio_input["type"] == "audio_embeds":
            return super()._process_audio_input(audio_input)

        rote_timestamps = audio_input["rote_timestamps"]
        if rote_timestamps is None:
            raise ValueError(
                "MusicFlamingo audio feature inputs must include `rote_timestamps`."
            )
        # Per-item tensors are flattened along the chunk dimension.
        if isinstance(rote_timestamps, list):
            rote_timestamps = torch.cat(rote_timestamps, dim=0)

        (
            input_features,
            feature_attention_mask,
            chunk_counts,
        ) = self._normalize_audio_feature_inputs(audio_input)
        hidden_states = self._encode_audio_features(
            input_features,
            feature_attention_mask,
        )
        # The time axis of the encoder output sets the rotary table length.
        cos, sin = self.pos_emb(
            rote_timestamps.to(hidden_states.device),
            seq_len=hidden_states.shape[-2],
        )
        hidden_states = apply_rotary_time_emb(hidden_states, cos, sin)
        audio_features = self.multi_modal_projector(hidden_states)

        return self._group_audio_embeddings(
            audio_features,
            feature_attention_mask,
            chunk_counts,
        )