"examples/vllm_v1/components/frontend.py" did not exist on "4fd4d53da0239e19d6d569634170985d11a32ab6"
gemma3n_mm.py 29.6 KB
Newer Older
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Any, Literal, Optional, Union, cast
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
5

6
import numpy as np
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
7
import torch
8
# yapf: disable
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
9
10
11
12
13
14
15
16
17
from torch import nn
from transformers import AutoModel, BatchFeature
from transformers.models.gemma3n import (Gemma3nAudioConfig,
                                         Gemma3nAudioFeatureExtractor,
                                         Gemma3nConfig, Gemma3nProcessor,
                                         Gemma3nTextConfig,
                                         Gemma3nVisionConfig)
from transformers.models.siglip import SiglipImageProcessorFast

18
19
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
20
21
22
23
24
25
26
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
27
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
28
29
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
30
                                    MultiModalKwargsItems)
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
31
32
33
from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
35
36
                                        BaseProcessingInfo,
                                        MultiModalPromptUpdates,
                                        MultiModalPromptUpdatesApplyResult,
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
37
                                        PlaceholderFeaturesInfo,
38
39
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails,
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
40
41
42
43
                                        replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
44
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
45

46
47
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                         SupportsTranscription)
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
48
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
49
                    init_vllm_registered_model, maybe_prefix)
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
50
51
52
53
54
55
56
57

logger = init_logger(__name__)

# This should be based on model config but we hardcode them for now.
# Per-item placeholder-token budgets reported by
# `Gemma3nProcessingInfo.get_max_tokens_per_item`.
# NOTE(review): presumably these mirror `vision_soft_tokens_per_image` and
# `audio_soft_tokens_per_image` on the HF config — confirm before changing.
TOKENS_PER_IMAGE = 256
TOKENS_PER_AUDIO = 188


58
59
60
61
62
63
64
65
66
67
class Gemma3nImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs for Gemma3n, validated against the declared
    tensor shape.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each patch
        - w: Width of each patch

    NOTE(review): the vision tower consumes these tensors directly, so h/w
    are likely the full (resized) image size rather than a sub-patch —
    confirm against the image processor.
    """
    # Literal tag naming this input variant.
    type: Literal["pixel_values"] = "pixel_values"
    # Declared shape: (bn, 3, h, w).
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
68
69


70
71
72
73
74
75
76
77
78
79
class Gemma3nAudioInputs(TensorSchema):
    """
    Audio feature inputs for Gemma3n, validated against the declared tensor
    shapes.

    Dimensions:
        - bn: Batch size * number of audios
        - s: seq_length
        - f: num_features
    """
    # Literal tag naming this input variant.
    type: Literal["audio"] = "audio"
    # Padded feature frames, declared shape (bn, s, f); padding lets the audio
    # tower run the whole batch at once.
    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
    # Per-frame mask, declared shape (bn, s). The processor unpads features
    # with `features[mask]`, so True marks real (non-padding) frames.
    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
Nicolò Lucchesi's avatar
Nicolò Lucchesi committed
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191


# Pixel values are the only accepted image input format; `image_embeds` are
# rejected when parsing image inputs.
Gemma3nImageInputs = Gemma3nImagePixelInputs


class Gemma3nProcessingInfo(BaseProcessingInfo):
    """Gemma3n processing metadata: HF config/processor access, item limits,
    token budgets and placeholder replacement sequences."""

    def get_hf_config(self):
        """Return the HF config as a ``Gemma3nConfig``."""
        return self.ctx.get_hf_config(Gemma3nConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``Gemma3nProcessor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Both modalities are supported; ``None`` presumably means no fixed
        # per-prompt item limit (confirm against BaseProcessingInfo).
        return {"image": None, "audio": None}

    def get_max_tokens_per_item(
            self, seq_len: int,
            mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
        """Return the fixed placeholder-token budget of each modality.

        ``seq_len`` and ``mm_counts`` do not influence the result: every image
        and every audio clip always expands to a constant number of tokens.
        """
        return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Gemma3nProcessor],
    ) -> PromptUpdateDetails:
        """
        Get the replacement for a single image token.

        For Gemma3n, this is the processor's ``full_image_sequence`` (BOI
        token, repeated image tokens, EOI token), with only the image token
        ids selected as embedding positions.

        ``image_width``/``image_height`` are accepted for interface
        compatibility; the sequence has a fixed length and does not depend
        on them.

        Args:
            processor: HF processor to read the sequence from; fetched via
                ``get_hf_processor()`` when ``None``.

        Returns:
            A ``PromptUpdateDetails`` wrapping the full image sequence.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return PromptUpdateDetails.select_token_id(
            processor.full_image_sequence, processor.image_token_id)

    def get_audio_repl(
        self,
        *,
        processor: Optional[Gemma3nProcessor],
    ) -> PromptUpdateDetails:
        """
        Get the replacement for a single audio token.

        For Gemma3n, this is the processor's ``full_audio_sequence`` (BOA
        token, repeated audio tokens, EOA token), with only the audio token
        ids selected as embedding positions.

        Args:
            processor: HF processor to read the sequence from; fetched via
                ``get_hf_processor()`` when ``None``.

        Returns:
            A ``PromptUpdateDetails`` wrapping the full audio sequence.
        """
        if processor is None:
            processor = self.get_hf_processor()

        # Return the full audio sequence as defined by the processor
        return PromptUpdateDetails.select_token_id(
            processor.full_audio_sequence, processor.audio_token_id)


class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
    """Builds the placeholder prompt text and media used for profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per image followed by one audio token per
        audio clip."""
        hf_processor = self.info.get_hf_processor()
        pieces = [
            hf_processor.image_token * mm_counts.get("image", 0),
            hf_processor.audio_token * mm_counts.get("audio", 0),
        ]
        return "".join(pieces)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized from the image processor (224x224
        fallback) and dummy audio clips of ``fft_length`` samples."""
        hf_processor = self.info.get_hf_processor()
        extractor: Gemma3nAudioFeatureExtractor = hf_processor.feature_extractor
        img_proc: SiglipImageProcessorFast = hf_processor.image_processor
        target_size = img_proc.size

        dummy_images = self._get_dummy_images(
            width=target_size.get("width", 224),
            height=target_size.get("height", 224),
            num_images=mm_counts.get("image", 0),
        )
        dummy_audios = self._get_dummy_audios(
            length=extractor.fft_length,
            num_audios=mm_counts.get("audio", 0),
        )
        return {"image": dummy_images, "audio": dummy_audios}


class Gemma3nMultiModalProcessor(
        BaseMultiModalProcessor[Gemma3nProcessingInfo]):
    """Multimodal processor for Gemma3n.

    Wraps the HF ``Gemma3nProcessor`` and adapts its outputs for vLLM.

    - Audio features are kept both padded, for batched execution of the
      audio tower, and unpadded per item, for the multimodal cache.
    - Newline tokens next to inserted placeholder sequences are merged, and
      later split again, so token ids stay consistent with the tokenizer.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a data parser targeting the feature extractor's sampling
        rate."""
        feature_extractor = self.info.get_hf_processor().feature_extractor
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and post-process the audio features.

        When audio is present, ``input_features_padded`` holds the padded
        batch and ``input_features`` is replaced by a list of per-item
        features with padding frames removed via ``input_features_mask``.
        """
        # HF Transformers audio processor no longer accepts `audios` key.
        # We rename `audios` to `audio` to suppress the warning. Work on a
        # shallow copy: the caller's mapping must not be mutated, and a
        # `Mapping` is not guaranteed to support `pop`.
        if 'audios' in mm_data:
            renamed = dict(mm_data)
            renamed['audio'] = renamed.pop('audios')
            mm_data = renamed
        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )

        if 'input_features' in processed_outputs:
            # Padding enables audio_tower to run in batched mode
            processed_outputs["input_features_padded"] = \
                processed_outputs["input_features"]

            # Unpad features here since we need the output of each item to be
            # independent of other items for the cache to work correctly
            processed_outputs["input_features"] = [
                features[mask] for features, mask in zip(
                    processed_outputs["input_features"],
                    processed_outputs["input_features_mask"],
                )
            ]
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare the per-item (batched) fields of each modality."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            input_features_padded=MultiModalFieldConfig.batched("audio"),
            input_features_mask=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image/audio token with its full placeholder
        sequence, for the modalities present in ``mm_items``."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        prompt_updates = []

        # Handle image tokens
        if "image" in mm_items:
            image_token = hf_processor.image_token

            def get_replacement_image(item_idx: int):
                images = mm_items.get_items("image", ImageProcessorItems)
                image_size = images.get_image_size(item_idx)
                return self.info.get_image_repl(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            prompt_updates.append(
                PromptReplacement(
                    modality="image",
                    target=image_token,
                    replacement=get_replacement_image,
                ))

        # Handle audio tokens
        if "audio" in mm_items:
            audio_token = hf_processor.audio_token

            def get_replacement_audio(item_idx: int):
                return self.info.get_audio_repl(processor=hf_processor)

            prompt_updates.append(
                PromptReplacement(
                    modality="audio",
                    target=audio_token,
                    replacement=get_replacement_audio,
                ))

        return prompt_updates

    def _get_newline_token_ids(self) -> tuple[int, int, int, int]:
        """Return the token ids of "\\n", "\\n\\n", "\\n\\n\\n" and
        "\\n\\n\\n\\n", in that order."""
        vocab = self.info.get_tokenizer().get_vocab()
        return (vocab["\n"], vocab["\n\n"], vocab["\n\n\n"],
                vocab["\n\n\n\n"])

    def _apply_token_matches(
        self,
        prompt: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
        """Apply the prompt updates, then merge adjacent newline tokens."""
        token_ids, res = super()._apply_token_matches(prompt,
                                                      mm_prompt_updates)

        # "\n\n\n" and "\n\n\n\n" are single tokens
        # Since our replacement can insert "\n\n" next to "\n"
        # tokens, we have to combine them to be consistent with
        # the output of the tokenizer
        newline_1, newline_2, newline_3, newline_4 = \
            self._get_newline_token_ids()

        token_ids = replace_token_matches(
            token_ids,
            [newline_1, newline_2],
            [newline_3],
        )
        token_ids = replace_token_matches(
            token_ids,
            [newline_2, newline_1],
            [newline_3],
        )
        token_ids = replace_token_matches(
            token_ids,
            [newline_2, newline_2],
            [newline_4],
        )

        return token_ids, res

    def _find_mm_placeholders(
        self,
        new_token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """Locate placeholders after splitting merged newline tokens.

        The search runs on a token list where "\\n\\n\\n" and "\\n\\n\\n\\n"
        are expanded back into their two-token forms. Each resulting
        ``start_idx`` is then mapped back to an index in ``new_token_ids``.
        """
        # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
        newline_1, newline_2, newline_3, newline_4 = \
            self._get_newline_token_ids()

        def get_repl_toks(tok: int) -> list[int]:
            if tok == newline_3:
                return [newline_1, newline_2]
            if tok == newline_4:
                return [newline_2, newline_2]

            return [tok]

        repl_token_ids = list[int]()
        # repl_orig_idxs[i] is the index in `new_token_ids` that produced
        # repl_token_ids[i].
        repl_orig_idxs = list[int]()
        for orig_idx, orig_tok in enumerate(new_token_ids):
            repl_toks = get_repl_toks(orig_tok)
            repl_token_ids.extend(repl_toks)
            repl_orig_idxs.extend([orig_idx] * len(repl_toks))

        repls = super()._find_mm_placeholders(repl_token_ids,
                                              mm_prompt_updates)

        return {
            modality: [
                PlaceholderFeaturesInfo(
                    modality=p.modality,
                    item_idx=p.item_idx,
                    start_idx=repl_orig_idxs[p.start_idx],
                    tokens=p.tokens,
                    is_embed=p.is_embed,
                ) for p in placeholders
            ]
            for modality, placeholders in repls.items()
        }


class Gemma3nMultimodalEmbedder(nn.Module):
    """Projects multimodal token ids ("hard" tokens) or encoder outputs
    ("soft" tokens) into the text model's hidden size."""

    def __init__(
        self,
        multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
        text_config: Gemma3nTextConfig,
    ):
        super().__init__()

        mm_hidden = multimodal_config.hidden_size
        text_hidden = text_config.hidden_size
        norm_eps = multimodal_config.rms_norm_eps

        self.multimodal_hidden_size = mm_hidden
        self.eps = norm_eps
        self.vocab_offset = multimodal_config.vocab_offset
        self.vocab_size = multimodal_config.vocab_size
        self.text_hidden_size = text_hidden

        # Lookup table for hard token ids (shifted by `vocab_offset` in
        # forward).
        self.embedding = VocabParallelEmbedding(self.vocab_size, mm_hidden)
        # Separate norms for hard-token embeddings and soft-token inputs.
        self.hard_embedding_norm = RMSNorm(mm_hidden, eps=norm_eps)
        self.soft_embedding_norm = RMSNorm(mm_hidden, eps=norm_eps)
        # Multimodal width -> text width, no bias.
        self.embedding_projection = RowParallelLinear(mm_hidden,
                                                      text_hidden,
                                                      bias=False)
        # Final norm created with has_weight=False.
        self.embedding_post_projection_norm = RMSNorm(text_hidden,
                                                      eps=norm_eps,
                                                      has_weight=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed multimodal token ids or soft tokens into language model space.

        Exactly one of the two inputs must be given.

        Args:
            input_ids: Token ids to embed; expected to lie in
                ``[vocab_offset, vocab_offset + vocab_size)``.
            inputs_embeds: Soft tokens (encoder outputs) to embed.

        Returns:
            Normalized embeddings whose last dimension is the text hidden
            size.

        Raises:
            ValueError: if both or neither of the inputs are provided.
        """
        ids_given = input_ids is not None
        embeds_given = inputs_embeds is not None
        if ids_given == embeds_given:
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")

        if embeds_given:
            normed = self.soft_embedding_norm(inputs_embeds)
        else:
            looked_up = self.embedding(input_ids - self.vocab_offset)
            normed = self.hard_embedding_norm(looked_up)

        projected, _ = self.embedding_projection(normed)
        return self.embedding_post_projection_norm(projected)


@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
                                        info=Gemma3nProcessingInfo,
                                        dummy_inputs=Gemma3nDummyInputsBuilder)
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsTranscription):
    """Gemma3n with image and audio inputs.

    Images and audio are encoded by HF vision/audio towers, mapped into the
    text embedding space by ``Gemma3nMultimodalEmbedder`` and fed to a vLLM
    ``Gemma3nForCausalLM``. Also implements the speech-to-text
    (transcription/translation) interface.
    """
    merge_by_field_config = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers
            # v4.52
            "model.embed_audio.": "embed_audio.",
            "model.embed_vision.": "embed_vision.",
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.audio_tower.": "audio_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
            "model": "language_model.model",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config
        self.vocab_size = config.text_config.vocab_size

        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.audio_tower = AutoModel.from_config(config=config.audio_config)
        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
                                                      config.text_config)
        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
                                                     config.text_config)

        self.language_model: nn.Module = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Gemma3nForCausalLM"],
        )
        self.language_model = cast(Gemma3nForCausalLM, self.language_model)
        # NOTE (NickLucche) In order to be compatible with cudagraph, the
        # buffer needs to be consistent, so we pre-allocate here.
        # Shape: (max_num_batched_tokens, num_hidden_layers,
        #         hidden_size_per_layer_input).
        self.per_layer_embeddings = torch.zeros(
            vllm_config.scheduler_config.max_num_batched_tokens,
            self.config.text_config.num_hidden_layers,
            self.config.text_config.hidden_size_per_layer_input,
            device=self.language_model.model.embed_tokens.weight.device,
            dtype=self.language_model.model.embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
        """Build image inputs from ``pixel_values``; None if absent.

        Raises:
            ValueError: if ``image_embeds`` are passed, which Gemma3n does
                not support.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # TODO is this the case?
        # Raise instead of assert so the check survives `python -O`.
        if image_embeds is not None:
            raise ValueError("Gemma3n does not support image_embeds.")
        if pixel_values is None:
            return None

        return Gemma3nImagePixelInputs(pixel_values=pixel_values)

    def _parse_and_validate_audio_input(
            self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
        """Build audio inputs; None unless both padded features and mask are
        present."""
        input_features_padded = kwargs.pop("input_features_padded", None)
        if input_features_padded is None:
            return None

        input_features_mask = kwargs.pop("input_features_mask", None)
        if input_features_mask is None:
            return None

        return Gemma3nAudioInputs(
            input_features_padded=input_features_padded,
            input_features_mask=input_features_mask,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Map each modality present in ``kwargs`` to its parsed input."""
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (input_key in ("pixel_values", "image_embeds")
                    and "image" not in mm_input_by_modality):
                mm_input_by_modality["image"] = \
                    self._parse_and_validate_image_input(**kwargs)
            if (input_key == "input_features_padded"
                    and "audio" not in mm_input_by_modality):
                mm_input_by_modality["audio"] = \
                    self._parse_and_validate_audio_input(**kwargs)
        return mm_input_by_modality

    def _process_image_input(
        self,
        image_input: Gemma3nImageInputs,
    ) -> list[torch.Tensor]:
        """Encode images into one language-model embedding tensor each."""
        assert self.vision_tower is not None

        pixel_values = image_input["pixel_values"]
        vision_outputs = self.vision_tower(pixel_values=pixel_values,
                                           do_pooling=False,
                                           return_dict=True).last_hidden_state
        # TODO try to avoid copy here
        # (batch, channels, height, width) to (batch, height * width, channels)
        vision_outputs = vision_outputs.reshape(
            vision_outputs.shape[0],
            self.config.vision_config.hidden_size,
            self.config.vision_soft_tokens_per_image,
        ).permute(0, 2, 1).contiguous()
        # Normalize and embed the soft tokens into language model space.
        vision_outputs *= self.config.vision_config.hidden_size**0.5
        # Return a list of embeddings instead of a batched tensor
        return self.embed_vision(inputs_embeds=vision_outputs).unbind(0)

    def _process_audio_input(
        self,
        audio_input: Gemma3nAudioInputs,
    ) -> list[torch.Tensor]:
        """Encode audio clips into language-model embeddings.

        Each item is padded to ``config.audio_soft_tokens_per_image``
        embeddings to match the audio placeholder tokens in the prompt.
        """
        assert self.audio_tower is not None
        # Run on padded features to enable batching
        input_features = audio_input["input_features_padded"].squeeze(1)
        input_features_mask = audio_input["input_features_mask"].squeeze(1)
        # The processor's mask is True on real frames; the tower is given the
        # inverted mask (True on padding).
        audio_outputs, audio_mask = self.audio_tower(input_features,
                                                     ~input_features_mask)
        audio_features = self.embed_audio(inputs_embeds=audio_outputs)

        # The Gemma3nProcessor expects all audio will be 30s in length and
        # inserts 188 audio soft tokens into the text to account for this.
        # However, the audio preprocessing and encoder do not guarantee they
        # will produce 188 soft tokens; they will produce at most that many
        # tokens, but they may produce fewer tokens depending on the length
        # of the longest audio input in the batch. When we encounter this
        # situation, we pad the audio feature out to 188 soft tokens with
        # the embedding of the last token in the embed_audio vocab.
        # TODO precompute and cache padding
        audio_padding_toks = torch.tensor([[self.vocab_size - 1]],
                                          dtype=torch.long,
                                          device=audio_features.device)
        audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
        # Positions flagged by the tower's output mask get the padding
        # embedding.
        audio_features = torch.where(audio_mask.unsqueeze(-1),
                                     audio_padding_embs, audio_features)

        audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
        extra_padding_tokens = (self.config.audio_soft_tokens_per_image -
                                audio_seq_len)
        extra_padding_features = audio_padding_embs.expand(
            audio_batch_size, extra_padding_tokens, audio_embed_dim)

        audio_features = torch.cat((audio_features, extra_padding_features),
                                   dim=1)
        # Return a list of embeddings instead of a batched tensor
        return audio_features.unbind(0)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per multimodal item.

        Items are ordered modality by modality, following the order in which
        the modalities first appear in ``kwargs``.
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
            **kwargs)
        if not mm_input_by_modality:
            return []

        multimodal_embeddings: list[torch.Tensor] = []

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality, multimodal_input in mm_input_by_modality.items():
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings.extend(vision_embeddings)
            elif modality == "audio":
                audio_embeddings = self._process_audio_input(multimodal_input)
                multimodal_embeddings.extend(audio_embeddings)
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Return input embeddings, caching per-layer embeddings (PLE).

        As a side effect, the PLEs for ``input_ids`` are written into the
        pre-allocated ``self.per_layer_embeddings`` buffer for ``forward``.
        """
        # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
        # them here, as the model forward has only access to the input_embeds.
        if input_ids is not None:
            per_layer_inputs = (
                self.language_model.model.get_per_layer_input_embeddings(
                    input_ids))
            per_layer_inputs = per_layer_inputs.reshape(
                -1, self.config.text_config.num_hidden_layers,
                self.config.text_config.hidden_size_per_layer_input)
            self.per_layer_embeddings[:per_layer_inputs.shape[0]].copy_(
                per_layer_inputs)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().get_input_embeddings(input_ids)

        return super().get_input_embeddings(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> IntermediateTensors:
        """Run the language model with the cached per-layer embeddings."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE (NickLucche) During profiling, `get_input_embeddings` is not
        # called, hence we don't have input_ids to compute PLEs. We simply
        # select a chunk of pre-allocated PLEs. During normal execution,
        # `get_input_embeddings` is called before forward, hence this slice
        # will contain PLEs computed from the actual input_ids.
        # `inputs_embeds` is None whenever `intermediate_tensors` is given
        # (reset above) or when only token ids are passed, so size the slice
        # from whichever per-token input is available instead of crashing.
        if inputs_embeds is not None:
            num_tokens = inputs_embeds.shape[0]
        elif input_ids is not None:
            num_tokens = input_ids.shape[0]
        else:
            num_tokens = positions.shape[-1]
        per_layer_inputs = self.per_layer_embeddings[:num_tokens]

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            per_layer_inputs=per_layer_inputs,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming HF keys via
        ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower")

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for ``modality``.

        Raises:
            ValueError: for modalities other than "image" and "audio".
        """
        if modality == "image":
            return "<image_soft_token>"
        elif modality == "audio":
            return "<audio_soft_token>"
        else:
            raise ValueError(f"Unsupported modality: {modality}")

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig,
                              model_config: ModelConfig,
                              language: Optional[str],
                              task_type: Literal["transcribe", "translate"],
                              request_prompt: str,
                              to_language: Optional[str]) -> PromptType:
        """
        Gemma3n supports "free-form" transcription.
        We fix its prompt here to standardize transcription/translation
        requests. ``request_prompt`` is not used.
        """
        # Transcribe this audio [into <>] | for transcription
        # Translate this audio [from <> into <>] | for translation
        prompt = "<start_of_turn>user\n"
        prompt += "Transcribe" if task_type == "transcribe" else "Translate"
        prompt += " this audio"

        # We assume the language is a valid ISO 639-1 code.
        full_lang_name = cls.supported_languages.get(language, "")
        # Translation only for now
        full_lang_name_to = cls.supported_languages.get(to_language, "")

        if task_type == "transcribe" and full_lang_name:
            prompt += f" into {full_lang_name}"
        elif task_type == "translate":
            if full_lang_name:
                prompt += f" from {full_lang_name}"
            if full_lang_name_to:
                prompt += f" into {full_lang_name_to}"

        prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"

        audio = (audio, stt_config.sample_rate)
        prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
        return cast(PromptType, prompts_dict)

    @classmethod
    def get_speech_to_text_config(cls, model_config: ModelConfig,
                                  task_type: str) -> SpeechToTextConfig:
        """Return the fixed speech-to-text settings (30 s clips, 16 kHz)."""
        return SpeechToTextConfig(
            # Let's set this to 30 as suggested in the docs for now, although
            # the model is only limited by its context length.
            max_audio_clip_s=30,
            sample_rate=16000,
            # TODO enable chunking after more thorough testing.
            min_energy_split_window_size=None,
        )