nemotron_vl.py 32.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
10
import math
11
12
13
14
15
from abc import ABC
from collections.abc import Iterable

import torch
import torch.nn as nn
16
import torchvision.transforms as T
17
18
19
20
21
from PIL import Image
from transformers import AutoModel, PretrainedConfig
from transformers.image_processing_utils_fast import BaseImageProcessorFast

from vllm.config import VllmConfig
22
from vllm.model_executor.layers.linear import ReplicatedLinear
23
from vllm.model_executor.layers.pooler import DispatchPooler
24
25
26
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.internvl import (
27
28
29
30
31
32
33
34
    BaseInternVLDummyInputsBuilder,
    BaseInternVLMultiModalProcessor,
    BaseInternVLProcessingInfo,
    InternVLImageEmbeddingInputs,
    InternVLImageInputs,
    InternVLImagePixelInputs,
    InternVLProcessor,
)
35
from vllm.model_executor.models.module_mapping import MultiModelKeys
36
from vllm.model_executor.models.siglip import SiglipVisionModel
37
from vllm.multimodal import MULTIMODAL_REGISTRY
38
from vllm.multimodal.image import convert_image_mode
39
40
from vllm.multimodal.processing import PromptUpdateDetails
from vllm.sequence import IntermediateTensors
41
from vllm.tokenizers import TokenizerLike
42
from vllm.transformers_utils.processor import cached_image_processor_from_config
43
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
44

45
46
from .interfaces import (
    MultiModalEmbeddings,
47
    SupportsCrossEncoding,
48
49
50
51
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
52
53
54
55
56
57
58
from .interfaces_base import VllmModelForPooling
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
59
60


61
def build_transform(input_size: int):
    """Build the base per-tile image transform.

    The returned pipeline converts the image to RGB, resizes it to
    ``(input_size, input_size)`` with bicubic interpolation, and converts it
    to a tensor with ``T.ToTensor()``. No normalization is applied here.

    Args:
        input_size: Side length, in pixels, of the square output tile.
    """
    return T.Compose(
        [
            T.Lambda(lambda img: convert_image_mode(img, "RGB")),
            T.Resize(
                (input_size, input_size), interpolation=T.InterpolationMode.BICUBIC
            ),
            T.ToTensor(),
        ]
    )
71
72
73
74
75
76
77
78
79
80
81


# adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1
def find_closest_aspect_ratio(
    aspect_ratio: float,
    target_ratios: list[tuple[int, int]],
    *,
    width: int,
    height: int,
    image_size: int,
) -> tuple[int, int]:
    """Choose the tiling grid that best fits an image.

    Each candidate ``(rw, rh)`` is scored as ``size_factor * ratio_closeness``:

    * ``size_factor`` is the tiled area ``rw * rh * image_size**2`` divided by
      the original area ``width * height``, capped at 0.6;
    * ``ratio_closeness`` is ``min(r / a, a / r)`` for candidate ratio
      ``r = rw / rh`` and image ratio ``a``; it equals 1.0 on an exact match.

    Only a strictly higher score replaces the current best, so on ties the
    earliest candidate in ``target_ratios`` wins. If ``target_ratios`` is
    empty, ``(1, 1)`` is returned.

    Args:
        aspect_ratio: Original image ``width / height``.
        target_ratios: Candidate grids; the first element multiplies the
            width and the second the height.
        width: Original image width in pixels.
        height: Original image height in pixels.
        image_size: Side length of one square tile in pixels.

    Returns:
        The best-scoring ``(rw, rh)`` pair.
    """
    best_factor = float("-inf")
    best_ratio = (1, 1)
    area = width * height
    tile_area = image_size * image_size

    for rw, rh in target_ratios:
        target_aspect_ratio = rw / rh
        size_factor = min((rw * rh * tile_area) / area, 0.6)
        ratio_closeness = min(
            target_aspect_ratio / aspect_ratio, aspect_ratio / target_aspect_ratio
        )
        factor = size_factor * ratio_closeness

        if factor > best_factor:
            best_factor = factor
            best_ratio = (rw, rh)

    return best_ratio


def calculate_nemotron_vl_targets(
    *,
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> tuple[int, int, int]:
    """Pick a tiling grid for an image and report the resulting geometry.

    Args:
        orig_width: Original image width in pixels.
        orig_height: Original image height in pixels.
        target_ratios: Candidate ``(cols, rows)`` grids, passed through to
            ``find_closest_aspect_ratio``.
        image_size: Side length of one square tile in pixels.
        use_thumbnail: Whether one extra thumbnail tile is counted when the
            image is split into more than one tile.

    Returns:
        ``(num_tiles, resized_width, resized_height)``. The resized size is
        ``image_size`` times the chosen grid dimensions.
    """
    cols, rows = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )

    num_tiles = cols * rows
    # The thumbnail tile is only counted when the image was actually split.
    if use_thumbnail and num_tiles != 1:
        num_tiles += 1

    return num_tiles, image_size * cols, image_size * rows


def dynamic_preprocess_nemotron_vl(
    image: Image.Image,
    *,
    target_ratios: list[tuple[int, int]],
    image_size: int,
    use_thumbnail: bool,
) -> list[Image.Image]:
    """Split an image into square tiles on the best-fitting grid.

    The image is resized to the grid chosen by
    ``calculate_nemotron_vl_targets`` and cropped into
    ``image_size x image_size`` tiles in row-major order. If ``use_thumbnail``
    is set and more than one tile was produced, a thumbnail of the whole
    image, resized to ``image_size x image_size``, is appended last.

    Args:
        image: Source PIL image.
        target_ratios: Candidate ``(cols, rows)`` grids.
        image_size: Side length of one tile in pixels.
        use_thumbnail: Whether to append the whole-image thumbnail tile.

    Returns:
        The list of tile images.
    """
    orig_width, orig_height = image.size

    # Compute the grid without the thumbnail; the thumbnail is appended below.
    blocks, target_width, target_height = calculate_nemotron_vl_targets(
        orig_width=orig_width,
        orig_height=orig_height,
        target_ratios=target_ratios,
        image_size=image_size,
        use_thumbnail=False,
    )

    resized_img = image.resize((target_width, target_height))
    # Number of tile columns; constant for the whole loop.
    cols = target_width // image_size

    processed_images = []
    for i in range(blocks):
        row, col = divmod(i, cols)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))

    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)

    return processed_images


def get_nemotron_vl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """Enumerate tiling grids whose tile count lies in ``[min_num, max_num]``.

    Returns every ``(i, j)`` pair with ``min_num <= i * j <= max_num``, sorted
    by tile count ``i * j``. The result is empty when ``min_num > max_num``.

    NOTE: the set comprehension is kept in exactly this form on purpose. The
    sort is stable, so pairs with equal tile counts keep the set's iteration
    order. ``find_closest_aspect_ratio`` keeps the first of equally scored
    candidates, so changing how the set is built could change which grid is
    chosen.
    """
    target_ratios = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    return sorted(target_ratios, key=lambda x: x[0] * x[1])


def image_to_pixel_values_nemotron_vl(
    image: Image.Image,
    *,
    input_size: int,
    min_num: int,
    max_num: int,
    use_thumbnail: bool,
    transform: T.Compose | None = None,
) -> torch.Tensor:
    """Tile an image and stack the transformed tiles into one tensor.

    Args:
        image: Source PIL image.
        input_size: Side length of each square tile in pixels.
        min_num: Minimum number of tiles considered for the grid.
        max_num: Maximum number of tiles considered for the grid.
        use_thumbnail: Whether to append a whole-image thumbnail tile.
        transform: Per-tile transform. Defaults to
            ``build_transform(input_size=input_size)``.

    Returns:
        ``torch.stack`` of the transformed tiles, one entry along dim 0 per tile.
    """
    target_ratios = get_nemotron_vl_target_ratios(min_num, max_num)

    if transform is None:
        transform = build_transform(input_size=input_size)

    tiles = dynamic_preprocess_nemotron_vl(
        image,
        target_ratios=target_ratios,
        image_size=input_size,
        use_thumbnail=use_thumbnail,
    )

    # Use a distinct loop name so the ``image`` parameter is not shadowed.
    return torch.stack([transform(tile) for tile in tiles])


212
class NemotronVLProcessor(InternVLProcessor):
    """Prompt/image processor for Nemotron VL chat models.

    Images are tiled with the Nemotron VL aspect-ratio heuristic. Each image's
    features are written into the prompt as ``IMG_START``, then
    ``IMG_CONTEXT`` repeated once per feature token, then ``IMG_END``.
    """

    IMG_START = "<img>"
    IMG_END = "</img>"
    IMG_CONTEXT = "<image>"

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        image_processor: BaseImageProcessorFast | None = None,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> None:
        """Initialize the processor.

        Args:
            config: HF model config. Reads ``force_image_size``,
                ``patch_size`` and ``downsample_ratio``; falls back to
                ``vision_config.image_size`` when ``force_image_size`` is
                unset.
            tokenizer: Tokenizer used to look up ``IMG_CONTEXT``.
            image_processor: Optional image processor. When given, it
                supplies ``max_num_tiles`` and ``use_thumbnail``.
            min_dynamic_patch: Minimum tile count (default 1).
            max_dynamic_patch: Maximum tile count (default
                ``image_processor.max_num_tiles``).
            dynamic_image_size: Whether dynamic tiling is enabled
                (default True).

        Raises:
            ValueError: If ``max_dynamic_patch`` is None and no
                ``image_processor`` is available to provide it.
        """
        # Call ABC.__init__ directly rather than InternVLProcessor.__init__.
        # This skips the parent's initialization; all attributes are set here.
        ABC.__init__(self)
        self.config = config
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        # Same image-size resolution as LlamaNemotronVLChatModel.__init__.
        image_size: int = config.force_image_size or config.vision_config.image_size
        patch_size: int = config.patch_size

        if min_dynamic_patch is None:
            min_dynamic_patch = 1
        assert isinstance(min_dynamic_patch, int)

        if max_dynamic_patch is None:
            if image_processor is None:
                raise ValueError(
                    "max_dynamic_patch must be provided when no image_processor "
                    "is available to supply max_num_tiles"
                )
            max_dynamic_patch = image_processor.max_num_tiles
        assert isinstance(max_dynamic_patch, int)

        if dynamic_image_size is None:
            dynamic_image_size = True
        assert isinstance(dynamic_image_size, bool)

        # Feature tokens per tile: (image_size // patch_size) ** 2 patches,
        # scaled by downsample_ratio ** 2.
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )

        self.image_size = image_size
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.dynamic_image_size = dynamic_image_size

        if image_processor is not None:
            self.use_thumbnail = image_processor.use_thumbnail
        else:
            self.use_thumbnail = getattr(config, "use_thumbnail", True)

    @property
    def image_token_id(self) -> int:
        """Token id of ``IMG_CONTEXT`` in the tokenizer vocabulary."""
        return self.tokenizer.get_vocab()[self.IMG_CONTEXT]

    def _get_transform(self) -> T.Compose:
        """Return the per-tile transform; subclasses may override."""
        return build_transform(input_size=self.image_size)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of feature tokens an image of this size produces."""
        target_ratios = self.resolve_target_ratios(
            use_thumbnail=False,  # thumbnail is counted by calculate_nemotron_vl_targets
        )

        num_patches, _, _ = calculate_nemotron_vl_targets(
            orig_width=image_width,
            orig_height=image_height,
            image_size=self.image_size,
            target_ratios=target_ratios,
            use_thumbnail=self.use_thumbnail,
        )

        return num_patches * self.num_image_token

    def _images_to_pixel_values_lst(
        self,
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> list[torch.Tensor]:
        """Tile and transform each image, one stacked tensor per image."""
        min_num, max_num = self.resolve_min_max_num(
            min_dynamic_patch=min_dynamic_patch,
            max_dynamic_patch=max_dynamic_patch,
            dynamic_image_size=dynamic_image_size,
            use_thumbnail=False,  # thumbnail is added in image_to_pixel_values_nemotron_vl
        )

        # Build the transform once and reuse it for every image.
        transform = self._get_transform()
        return [
            image_to_pixel_values_nemotron_vl(
                image,
                input_size=self.image_size,
                min_num=min_num,
                max_num=max_num,
                use_thumbnail=self.use_thumbnail,
                transform=transform,
            )
            for image in images
        ]

    def _replace_image_tokens(
        self,
        text: list[str],
        pixel_values_lst: list[torch.Tensor],
    ) -> list[str]:
        """Replace <image> placeholders with image tokens.

        Each image, in order, replaces the first remaining ``<image>`` in every
        prompt of ``text``. The inserted features use ``IMG_CONTEXT``, which
        here is also ``<image>``. They are therefore masked with a temporary
        token until all placeholders are consumed, so later images do not
        match tokens that were just inserted.
        """
        placeholder = "<image>"
        temp_context = "<NVL_IMG_CONTEXT>"
        for pixel_values in pixel_values_lst:
            num_patches = pixel_values.shape[0]
            feature_size = num_patches * self.num_image_token
            image_repl = self.get_image_repl(feature_size, num_patches)
            masked_repl = image_repl.full.replace(self.IMG_CONTEXT, temp_context)
            text = [t.replace(placeholder, masked_repl, 1) for t in text]
        return [t.replace(temp_context, self.IMG_CONTEXT) for t in text]

    def _preprocess_image(
        self,
        text: list[str],
        images: list[Image.Image],
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> tuple[list[str], dict[str, torch.Tensor]]:
        """Compute pixel inputs for ``images`` and expand their prompt tokens.

        Returns the updated prompts and a dict containing
        ``pixel_values_flat`` (all tiles concatenated) and
        ``image_num_patches`` (tile count per image). The dict is empty when
        there are no images.
        """
        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values_lst = self._images_to_pixel_values_lst(
                images,
                min_dynamic_patch=min_dynamic_patch,
                max_dynamic_patch=max_dynamic_patch,
                dynamic_image_size=dynamic_image_size,
            )
            image_inputs = {
                "pixel_values_flat": torch.cat(pixel_values_lst),
                "image_num_patches": torch.tensor(
                    [len(item) for item in pixel_values_lst]
                ),
            }

            text = self._replace_image_tokens(text, pixel_values_lst)

        return text, image_inputs

    def get_image_repl(
        self,
        feature_size: int,
        num_patches: int | None,
    ) -> PromptUpdateDetails[str]:
        """Build ``IMG_START + IMG_CONTEXT * feature_size + IMG_END``.

        ``num_patches`` is unused here and is kept for interface compatibility.
        Only the ``IMG_CONTEXT`` tokens are selected as embedding positions.
        """
        repl_features = self.IMG_CONTEXT * feature_size
        repl_full = self.IMG_START + repl_features + self.IMG_END

        return PromptUpdateDetails.select_text(repl_full, self.IMG_CONTEXT)
363
364
365
366
367


class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for Nemotron VL models."""

    def get_hf_processor(self, **kwargs: object) -> NemotronVLProcessor:
        """Create a ``NemotronVLProcessor`` for the current model.

        The processor gets the HF config, the tokenizer and the cached image
        processor. Extra ``kwargs`` are forwarded to ``init_processor``.
        """
        return self.ctx.init_processor(
            NemotronVLProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            image_processor=self.get_image_processor(),
            **kwargs,
        )

    def get_image_processor(self, **kwargs: object):
        """Return the image processor cached for this model config."""
        return cached_image_processor_from_config(
            self.ctx.model_config,
            **kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    BaseInternVLMultiModalProcessor[NemotronVLProcessingInfo],
    info=NemotronVLProcessingInfo,
    dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo],
)
class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Nemotron VL chat model: vision tower, MLP projector and language model.

    Image tiles are encoded by ``vision_model``, pixel-shuffled, and projected
    by ``mlp1`` to the language model's hidden size. The resulting embeddings
    are then merged into the text embeddings of ``language_model``.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for item ``i`` of ``modality``."""
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.model_config = vllm_config.model_config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        # Feature tokens per tile after pixel shuffle.
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = self._init_vision_model(
                config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mlp1 = self._init_mlp1(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.get_text_config(),
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.img_context_token_id = None

        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig | None
    ):
        """Exclude the vision tower from AWQ quantization when the list is empty."""
        # the awq models from OpenGVLab missing `modules_to_not_convert`
        # patch the quant_config to add `modules_to_not_convert` back
        if isinstance(quant_config, AWQConfig):
            text_config = config.get_text_config()
            llm_quant_config = getattr(text_config, "quantization_config", None)
            if (not quant_config.modules_to_not_convert) and (
                llm_quant_config is not None
            ):
                quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        prefix: str,
    ):
        """Build the vision tower from ``config.vision_config`` via ``AutoModel``.

        ``quant_config`` and ``prefix`` are not used here; they are part of
        the signature that subclasses override.
        """
        return AutoModel.from_config(
            config.vision_config,
            trust_remote_code=self.model_config.trust_remote_code,
        )

    def _init_mlp1(
        self,
        config: PretrainedConfig,
        vit_hidden_size: int | None = None,
        vision_projection_hidden_size: int | None = None,
    ) -> nn.Module:
        """Build the vision-to-text projector.

        The projector is LayerNorm, then Linear, then GELU, then Linear. Its
        input width is the ViT hidden size times ``int(1 / downsample_ratio)
        ** 2``, matching the channel growth from ``pixel_shuffle``.
        """
        if vit_hidden_size is None:
            vit_hidden_size = config.vit_hidden_size
        if vision_projection_hidden_size is None:
            vision_projection_hidden_size = config.projector_hidden_size
        llm_hidden_size = config.get_text_config().hidden_size
        in_features = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        return nn.Sequential(
            nn.LayerNorm(in_features, bias=True),
            nn.Linear(in_features, vision_projection_hidden_size, bias=True),
            nn.GELU(),
            nn.Linear(vision_projection_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        Input ``(n, w, h, c)`` becomes ``(n, h * s, w * s, c / s**2)`` for
        ``ps_version == "v1"``. Otherwise the two spatial axes are swapped
        back, giving ``(n, w * s, h * s, c / s**2)``.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        if self.ps_version != "v1":
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Call vision model and return embeddings.

        Override this method in subclasses to handle different vision model
        interfaces (e.g., SigLIP vs C-RADIO).

        The output is cast to the projector's parameter dtype. Earlier code
        hard-coded bfloat16, which broke models running in any other dtype.
        """
        vit_embeds = self.vision_model(x=pixel_values).features
        projector_dtype = next(self.mlp1.parameters()).dtype
        return vit_embeds.to(dtype=projector_dtype)

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode image tiles and project them to the LLM hidden size.

        This reshape assumes the vision tokens per tile form a square grid,
        since ``h = w = sqrt(num_tokens)``.
        """
        # https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/modeling.py#L177
        vit_embeds = self._call_vision_model(pixel_values)

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> InternVLImageInputs | None:
        """Build a typed image input from raw kwargs, or return None if absent.

        ``image_embeds`` takes precedence over ``pixel_values_flat``. For pixel
        inputs, this also records ``image_token_id`` on the model.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        image_token_id = kwargs["image_token_id"]
        if isinstance(image_token_id, torch.Tensor):
            image_token_id = image_token_id.flatten().unique().item()

        assert isinstance(image_token_id, int)
        self.img_context_token_id = image_token_id

        if pixel_values_flat is not None:
            # Same image-size resolution as __init__ (force_image_size first).
            tile_size = (
                self.config.force_image_size or self.config.vision_config.image_size
            )
            return InternVLImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=pixel_values_flat,
                num_patches=image_num_patches,
                resolve_bindings={
                    "h": tile_size,
                    "w": tile_size,
                },
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Turn an image input into per-image embedding tensors.

        Precomputed embeddings are returned as-is. Pixel inputs are encoded,
        flattened to ``(-1, hidden_size)``, and split per image by
        ``tiles_in_image * tokens_per_tile``.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])

        num_patches = image_input["num_patches"]
        hidden_size = self.config.get_text_config().hidden_size

        # Only one image in the current batch
        if len(num_patches) == 1:
            return (image_embeds.view(-1, hidden_size),)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, hidden_size)
        image_feature_sizes = [int(n) * feature_size for n in num_patches]
        return image_embeds.split(image_feature_sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Collect parsed inputs per modality, preserving kwargs order.

        A modality whose parsed input is None (all of its values were None)
        is left out, so callers never receive a None entry.
        """
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values_flat", "image_embeds")
                and "images" not in modalities
            ):
                image_input = self._parse_and_validate_image_input(**kwargs)
                if image_input is not None:
                    modalities["images"] = image_input

        return modalities

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Hook for subclasses. The base model never sets a visual token mask."""
        self.visual_token_mask = None

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return one embedding tensor per multimodal item, in input order."""
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)

        return multimodal_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed token ids and merge in multimodal embeddings when given."""
        if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model backbone and return its hidden states.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs.update({"visual_token_mask": self.visual_token_mask})
            # The mask is consumed once, then cleared.
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping the vision normalization buffers."""
        ## Ignore registered_buffers
        ## see https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/input_conditioner.py#L28 # noqa: E501
        skip_substrs = ["norm_mean", "norm_std"]
        loader = AutoWeightsLoader(self, skip_substrs=skip_substrs)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp1",
            tower_model="vision_model",
        )
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890


# --------------------------------------------------------
# LlamaNemotronVL Embedding Model (nvidia/llama-nemotron-embed-vl-1b-v2)
# Extends LlamaNemotronVLChatModel for embedding/pooling tasks:
#   - SigLIP vision encoder (instead of C-RADIO)
#   - Bidirectional (non-causal) LLaMA language model
#   - Pooler output instead of generative logits
# --------------------------------------------------------

# Per-channel (R, G, B) normalization constants for the SigLIP encoder.
# Applied after T.ToTensor(), (x - 0.5) / 0.5 maps pixel values from [0, 1]
# to [-1, 1].
SIGLIP_MEAN = (0.5,) * 3
SIGLIP_STD = (0.5,) * 3


def build_siglip_transform(input_size: int):
    """Return the base tile transform followed by SigLIP normalization.

    The first stage is ``build_transform(input_size=...)`` (RGB conversion,
    bicubic resize, ToTensor). The second stage normalizes each channel with
    ``SIGLIP_MEAN`` / ``SIGLIP_STD``.
    """
    normalize = T.Normalize(mean=SIGLIP_MEAN, std=SIGLIP_STD)
    return T.Compose([build_transform(input_size=input_size), normalize])


class LlamaNemotronVLEmbedProcessor(NemotronVLProcessor):
    """
    Processor for the LlamaNemotronVL embedding model.

    It specializes NemotronVLProcessor for embedding tasks:
    - tiles are transformed with SigLIP normalization;
    - image features use the ``<IMG_CONTEXT>`` token instead of ``<image>``.
    """

    IMG_CONTEXT = "<IMG_CONTEXT>"

    def __init__(
        self,
        config: PretrainedConfig,
        tokenizer: TokenizerLike,
        processor_config: dict,
        *,
        min_dynamic_patch: int | None = None,
        max_dynamic_patch: int | None = None,
        dynamic_image_size: bool | None = None,
    ) -> None:
        """Resolve tiling options, then initialize without an image processor.

        Each option is taken from the first available source: the explicit
        argument, then ``processor_config``, then the HF ``config`` attribute,
        then a hard-coded fallback.
        """

        def pick(explicit, json_key, config_attr, fallback):
            if explicit is not None:
                return explicit
            return processor_config.get(
                json_key, getattr(config, config_attr, fallback)
            )

        super().__init__(
            config=config,
            tokenizer=tokenizer,
            image_processor=None,
            min_dynamic_patch=pick(
                min_dynamic_patch, "min_input_tiles", "min_dynamic_patch", 1
            ),
            max_dynamic_patch=pick(
                max_dynamic_patch, "max_input_tiles", "max_dynamic_patch", 1
            ),
            dynamic_image_size=pick(
                dynamic_image_size, "dynamic_image_size", "dynamic_image_size", True
            ),
        )

    def _get_transform(self) -> T.Compose:
        """Use the SigLIP-normalized tile transform."""
        return build_siglip_transform(input_size=self.image_size)

    def _replace_image_tokens(
        self,
        text: list[str],
        pixel_values_lst: list[torch.Tensor],
    ) -> list[str]:
        """Expand each ``<image>`` placeholder directly into image tokens.

        A temporary placeholder is unnecessary here. The inserted feature
        token is ``<IMG_CONTEXT>``, which never matches the ``<image>``
        search string.
        """
        for tile_stack in pixel_values_lst:
            tile_count = tile_stack.shape[0]
            expansion = self.get_image_repl(
                tile_count * self.num_image_token, tile_count
            ).full
            text = [prompt.replace("<image>", expansion, 1) for prompt in text]
        return text


class LlamaNemotronVLEmbedProcessingInfo(NemotronVLProcessingInfo):
    """Processing info for the LlamaNemotronVL embedding model."""

    def get_hf_processor(self, **kwargs: object) -> LlamaNemotronVLEmbedProcessor:
        """Override to create embedding-specific processor without image_processor."""
        model_config = self.ctx.model_config

        # processor_config.json is optional; an absent or empty file means "no
        # overrides".
        processor_config = {}
        if model_config.model is not None:
            loaded = get_hf_file_to_dict(
                "processor_config.json",
                model_config.model,
                model_config.revision,
            )
            processor_config = loaded or {}

        return self.ctx.init_processor(
            LlamaNemotronVLEmbedProcessor,
            config=self.get_hf_config(),
            tokenizer=self.get_tokenizer(),
            processor_config=processor_config,
            **kwargs,
        )


@MULTIMODAL_REGISTRY.register_processor(
    BaseInternVLMultiModalProcessor[LlamaNemotronVLEmbedProcessingInfo],
    info=LlamaNemotronVLEmbedProcessingInfo,
    dummy_inputs=BaseInternVLDummyInputsBuilder[LlamaNemotronVLEmbedProcessingInfo],
)
class LlamaNemotronVLForEmbedding(LlamaNemotronVLChatModel, VllmModelForPooling):
    """LlamaNemotronVL specialized for producing embeddings.

    Reuses ``LlamaNemotronVLChatModel`` and changes three things:

    - the vision tower is SigLIP (``SiglipVisionModel``) rather than C-RADIO;
    - the language backbone is bidirectional LLaMA (selected via
      ``llm_config``) rather than causal LLaMA;
    - an embedding pooler is attached instead of generating logits.
    """

    is_pooling_model = True

    # Checkpoint-name -> vLLM-name prefix rewrites. They differ from the
    # parent's because the vision tower has a different module layout.
    weight_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Language model: the checkpoint omits the inner ``model`` level.
            "language_model.layers.": "language_model.model.layers.",
            "language_model.embed_tokens.": "language_model.model.embed_tokens.",
            "language_model.norm.": "language_model.model.norm.",
            # Vision model: SiglipVisionModel nests its layers under an extra
            # ``vision_model`` submodule.
            "vision_model.encoder.": "vision_model.vision_model.encoder.",
            "vision_model.embeddings.": "vision_model.vision_model.embeddings.",
            "vision_model.post_layernorm.": "vision_model.vision_model.post_layernorm.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        model_config = vllm_config.model_config

        # The parent leaves img_context_token_id as None; read it from the HF
        # config instead (None if the config does not define it).
        self.img_context_token_id = getattr(
            model_config.hf_config, "img_context_token_id", None
        )

        # Embedding output is produced by a pooler rather than logits.
        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = DispatchPooler.for_embedding(pooler_config)

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        prefix: str,
    ) -> nn.Module:
        """Build a SigLIP vision tower in place of the parent's C-RADIO.

        The SigLIP head is disabled (``use_head=False``).
        """
        vision_config = config.vision_config
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            prefix=prefix,
            use_head=False,
        )

    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
        """Build the vision projector with embedding-model sizes.

        Passes the vision tower's hidden size as ``vit_hidden_size`` and the
        text backbone's hidden size as ``vision_projection_hidden_size``.
        """
        vit_width = config.vision_config.hidden_size
        text_width = config.get_text_config().hidden_size
        return super()._init_mlp1(
            config,
            vit_hidden_size=vit_width,
            vision_projection_hidden_size=text_width,
        )

    def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the SigLIP tower directly on ``pixel_values``."""
        vision_tower = self.vision_model
        return vision_tower(pixel_values)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them through ``self.weight_mapper``.

        The mapper is looked up on ``self``, so subclasses can override it.

        Returns:
            The set of loaded parameter names reported by ``AutoWeightsLoader``.
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.weight_mapper
        )
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943


class LlamaNemotronVLForSequenceClassification(
    LlamaNemotronVLForEmbedding, SupportsCrossEncoding
):
    """LlamaNemotronVL model variant for sequence classification / reranking.

    Builds on the embedding model by adding two things:

    - a top-level ``score`` linear head (hidden size -> ``num_labels``);
    - a sequence-classification pooler that applies that head, replacing the
      embedding pooler installed by the parent.
    """

    # The reranker checkpoint places base model weights under `model.*`,
    # while `score.*` remains at the top level. The `model.` prefix is
    # stripped and then the embedding model's prefix rewrites are applied.
    # NOTE(review): this relies on `model.` being applied before the
    # embedding prefixes; confirm WeightsMapper applies prefixes in
    # insertion order.
    weight_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) | (
        LlamaNemotronVLForEmbedding.weight_mapper
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        model_config = vllm_config.model_config
        text_config = model_config.hf_config.get_text_config()
        quant_config = vllm_config.quant_config

        # Classification head: maps the hidden state to `num_labels` scores.
        self.score = ReplicatedLinear(
            model_config.get_hidden_size(),
            text_config.num_labels,
            bias=False,
            params_dtype=model_config.head_dtype,
            quant_config=quant_config,
            return_bias=False,
            prefix=maybe_prefix(prefix, "score"),
        )

        # Replace the parent's embedding pooler with one that scores
        # sequences through `self.score`.
        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and fill in the unused inner LM score head.

        The reranker checkpoint omits the inner LM seq-cls head
        (`language_model.score.*`). That head is unused by this outer model,
        but the default loader expects all parameters to be initialized.
        Each such parameter that was not loaded therefore gets a default
        initialization and is reported as loaded.

        Returns:
            The set of loaded parameter names, including the filled-in ones.
        """
        loaded_weights = super().load_weights(weights)

        inner_head_prefix = "language_model.score."
        for name, param in self.named_parameters():
            if not name.startswith(inner_head_prefix) or name in loaded_weights:
                continue

            if name.endswith(".weight") and param.dim() >= 2:
                # Matches torch.nn.Linear's default weight initialization.
                torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            elif name.endswith(".bias"):
                torch.nn.init.zeros_(param)
            else:
                # kaiming_uniform_ cannot compute fan-in for tensors with
                # fewer than 2 dims (it raises ValueError). Such weights, and
                # any other parameter kind, get a small normal init instead.
                torch.nn.init.normal_(param, mean=0.0, std=0.02)

            loaded_weights.add(name)

        return loaded_weights