nemotron_vl.py 22 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import math
4
5
6
7
8
9
10
from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import AutoModel, PretrainedConfig

from vllm.config import VllmConfig
11
from vllm.model_executor.layers.linear import ReplicatedLinear
12
from vllm.model_executor.layers.pooler import DispatchPooler
13
14
15
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.internvl import (
16
17
18
19
20
21
22
    BaseInternVLDummyInputsBuilder,
    BaseInternVLMultiModalProcessor,
    BaseInternVLProcessingInfo,
    InternVLImageEmbeddingInputs,
    InternVLImageInputs,
    InternVLImagePixelInputs,
)
23
from vllm.model_executor.models.module_mapping import MultiModelKeys
24
from vllm.model_executor.models.siglip import SiglipVisionModel
25
26
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
27
from vllm.transformers_utils.processor import cached_image_processor_from_config
28
from vllm.transformers_utils.processors.nemotron_vl import (
29
30
31
    LlamaNemotronNanoVLImageProcessor,
    LlamaNemotronNanoVLProcessor,
    LlamaNemotronVLEmbedImageProcessor,
32
33
    LlamaNemotronVLEmbedProcessor,
)
34
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
35

36
37
from .interfaces import (
    MultiModalEmbeddings,
38
    SupportsCrossEncoding,
39
40
41
42
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
43
44
45
46
47
48
49
from .interfaces_base import VllmModelForPooling
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
50
51
52
53
54


class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for Nemotron VL models."""

    def get_image_processor(
        self, **kwargs: object
    ) -> LlamaNemotronNanoVLImageProcessor:
        """Build the image processor used for Nemotron VL inputs.

        The checkpoint's own image processor is loaded (using the merged
        multimodal kwargs) only to read its ``image_size``,
        ``max_num_tiles`` and ``use_thumbnail`` attributes. Those values are
        forwarded to a ``LlamaNemotronNanoVLImageProcessor`` with dynamic
        image sizing enabled and ``min_dynamic_patch`` fixed to 1.

        Args:
            **kwargs: Per-request overrides, merged with the context's
                multimodal kwargs before loading the original processor.

        Returns:
            The configured ``LlamaNemotronNanoVLImageProcessor``.
        """
        kwargs = self.ctx.get_merged_mm_kwargs(kwargs)
        orig_processor = cached_image_processor_from_config(
            self.ctx.model_config, **kwargs
        )

        return LlamaNemotronNanoVLImageProcessor(
            image_size=orig_processor.image_size,
            min_dynamic_patch=1,
            max_dynamic_patch=orig_processor.max_num_tiles,
            dynamic_image_size=True,
            use_thumbnail=orig_processor.use_thumbnail,
        )

    def get_hf_processor(self, **kwargs: object) -> LlamaNemotronNanoVLProcessor:
        """Build the combined tokenizer + image processor.

        Args:
            **kwargs: Forwarded unchanged to ``get_image_processor``.

        Returns:
            A ``LlamaNemotronNanoVLProcessor`` wired with this model's
            tokenizer, image processor and per-tile sequence length.
        """
        config = self.get_hf_config()
        vision_config = config.vision_config

        image_processor = self.get_image_processor(**kwargs)
        image_size = image_processor.image_size
        patch_size = vision_config.patch_size
        downsample_ratio = config.downsample_ratio
        # Patches per side, squared, then scaled by downsample_ratio**2
        # (presumably the number of image tokens per tile after pixel
        # shuffle - TODO confirm against the processor).
        image_seq_length = int((image_size // patch_size) ** 2 * (downsample_ratio**2))

        return LlamaNemotronNanoVLProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=image_processor,
            image_seq_length=image_seq_length,
        )


@MULTIMODAL_REGISTRY.register_processor(
    BaseInternVLMultiModalProcessor[NemotronVLProcessingInfo],
    info=NemotronVLProcessingInfo,
    dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo],
)
class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """Nemotron VL chat model.

    Composed of three sub-modules created in ``__init__``: a vision tower
    (``vision_model``), a projector (``mlp1``) and a language model
    (``language_model``). Subclasses can swap the vision tower and projector
    by overriding ``_init_vision_model``, ``_init_mlp1`` and
    ``_call_vision_model``.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for the ``i``-th item of ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Create the vision tower, projector and language model.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Module-name prefix passed to the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.model_config = vllm_config.model_config
        self.multimodal_config = multimodal_config
        self._patch_quant_config(config, quant_config)

        # `force_image_size` wins over the vision config when it is set.
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.num_image_token = int(
            (image_size // patch_size) ** 2 * (config.downsample_ratio**2)
        )
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = self._init_vision_model(
                config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_model"),
            )
            self.mlp1 = self._init_mlp1(config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.get_text_config(),
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Filled in by `_parse_and_validate_image_input` for pixel inputs.
        self.img_context_token_id = None

        self.visual_token_mask = None
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _patch_quant_config(
        self, config: PretrainedConfig, quant_config: QuantizationConfig | None
    ) -> None:
        """Exclude the vision tower from AWQ conversion when unspecified.

        Only acts when ``quant_config`` is an ``AWQConfig`` whose
        ``modules_to_not_convert`` is empty and the text config carries a
        ``quantization_config``; any other value (including ``None``) is
        left untouched.
        """
        # The AWQ models from OpenGVLab are missing `modules_to_not_convert`;
        # patch the quant_config to add `modules_to_not_convert` back.
        if isinstance(quant_config, AWQConfig):
            text_config = config.get_text_config()
            llm_quant_config = getattr(text_config, "quantization_config", None)
            if (not quant_config.modules_to_not_convert) and (
                llm_quant_config is not None
            ):
                quant_config.modules_to_not_convert.append("vision_model")

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        prefix: str,
    ) -> nn.Module:
        """Instantiate the vision tower from ``config.vision_config``.

        ``quant_config`` and ``prefix`` are accepted for the benefit of
        overriding subclasses; this implementation does not use them.
        """
        return AutoModel.from_config(
            config.vision_config,
            trust_remote_code=self.model_config.trust_remote_code,
        )

    def _init_mlp1(
        self,
        config: PretrainedConfig,
        vit_hidden_size: int | None = None,
        vision_projection_hidden_size: int | None = None,
    ) -> nn.Module:
        """Build the projector from vision features to LLM hidden size.

        Args:
            config: HF config of the full model.
            vit_hidden_size: Vision feature width; defaults to
                ``config.vit_hidden_size``.
            vision_projection_hidden_size: Width of the projector's hidden
                layer; defaults to ``config.projector_hidden_size``.

        Returns:
            ``LayerNorm -> Linear -> GELU -> Linear`` as an ``nn.Sequential``.
        """
        if vit_hidden_size is None:
            vit_hidden_size = config.vit_hidden_size
        if vision_projection_hidden_size is None:
            vision_projection_hidden_size = config.projector_hidden_size
        llm_hidden_size = config.get_text_config().hidden_size

        # Pixel shuffle multiplies the channel count by (1 / ratio) ** 2.
        shuffled_dim = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        return nn.Sequential(
            nn.LayerNorm(shuffled_dim, bias=True),
            nn.Linear(
                shuffled_dim,
                vision_projection_hidden_size,
                bias=True,
            ),
            nn.GELU(),
            nn.Linear(vision_projection_hidden_size, llm_hidden_size),
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels on an ``(N, W, H, C)`` tensor.

        Both spatial dimensions are multiplied by ``scale_factor`` and the
        channel dimension is divided by ``scale_factor ** 2``. Unless
        ``self.ps_version`` is ``"v1"``, the two spatial axes are swapped
        once more at the end.
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(
            n,
            int(h * scale_factor),
            int(w * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        # "v1" keeps the layout produced above; other versions swap back.
        if self.ps_version != "v1":
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Call vision model and return embeddings.

        Override this method in subclasses to handle different vision model
        interfaces (e.g., SigLIP vs C-RADIO).
        """
        vit_embeds = self.vision_model(x=pixel_values).features
        # NOTE(review): the target dtype is hard-coded to bfloat16 -
        # confirm this is intended when the model runs in another dtype.
        return vit_embeds.to(dtype=torch.bfloat16)

    def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower, pixel shuffle and projector on image tiles.

        Returns:
            Projected embeddings with the spatial grid flattened back into a
            single sequence dimension.
        """
        # https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/modeling.py#L177
        vit_embeds = self._call_vision_model(pixel_values)

        # Treat dim 1 as a square grid of tokens (assumes its size is a
        # perfect square - TODO confirm for every supported vision tower).
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> InternVLImageInputs | None:
        """Turn raw multimodal kwargs into a typed image-input object.

        Precomputed ``image_embeds`` take precedence over
        ``pixel_values_flat``. For pixel inputs the ``image_token_id`` kwarg
        is required and is cached on ``self.img_context_token_id``.

        Returns:
            ``None`` when neither pixels nor embeddings were supplied,
            otherwise the matching ``InternVLImage*Inputs`` instance.
        """
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values_flat is None and image_embeds is None:
            return None

        if image_embeds is not None:
            return InternVLImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        image_token_id = kwargs["image_token_id"]
        if isinstance(image_token_id, torch.Tensor):
            # `.item()` only succeeds when all entries share one value.
            image_token_id = image_token_id.flatten().unique().item()

        assert isinstance(image_token_id, int)
        self.img_context_token_id = image_token_id

        if pixel_values_flat is not None:
            return InternVLImagePixelInputs(
                type="pixel_values",
                pixel_values_flat=pixel_values_flat,
                num_patches=image_num_patches,
                resolve_bindings={
                    "h": self.config.force_image_size,
                    "w": self.config.force_image_size,
                },
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: InternVLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """Convert a parsed image input into per-image embedding tensors.

        Precomputed embeddings are returned as-is; pixel inputs are encoded
        with ``extract_feature`` and then split per image according to
        ``num_patches``.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_embeds = self.extract_feature(image_input["pixel_values_flat"])

        num_patches = image_input["num_patches"]
        hidden_size = self.config.get_text_config().hidden_size

        # Only one image in the current batch
        if len(num_patches) == 1:
            return (image_embeds.view(-1, hidden_size),)

        # NOTE: Image embeddings are split into separate tensors for each image
        # by the size of each embedding.
        feature_size = image_embeds.shape[1]
        image_embeds = image_embeds.view(-1, hidden_size)
        # Distinct loop name so the `num_patches` sequence is not shadowed.
        image_feature_sizes = [
            patch_count * feature_size for patch_count in num_patches
        ]
        return image_embeds.split(image_feature_sizes)

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Collect parsed inputs per modality, keyed by modality name.

        Only the ``"images"`` modality is recognised; it is parsed once, the
        first time an image-related key is seen in ``kwargs``.
        """
        modalities = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values_flat", "image_embeds")
                and "images" not in modalities
            ):
                modalities["images"] = self._parse_and_validate_image_input(**kwargs)

        return modalities

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        """Reset ``visual_token_mask`` to ``None`` (``input_ids`` is unused)."""
        self.visual_token_mask = None

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute embeddings for every multimodal item in ``kwargs``.

        Returns:
            An empty list when no multimodal input is present, otherwise a
            tuple with one tensor per image.
        """
        modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not modalities:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                image_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(image_embeddings)

        return multimodal_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, delegating the merge to the parent class.

        The visual token mask is reset first whenever multimodal embeddings
        are supplied.
        """
        if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
            self._set_visual_token_mask(input_ids)

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        """Run the inner language model and return its output.

        ``inputs_embeds`` is dropped whenever ``intermediate_tensors`` is
        provided. Extra ``kwargs`` are accepted but not forwarded.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        forward_kwargs = {
            "input_ids": input_ids,
            "positions": positions,
            "intermediate_tensors": intermediate_tensors,
            "inputs_embeds": inputs_embeds,
        }

        # Only required if the model is mono-architecture
        if self.visual_token_mask is not None:
            forward_kwargs.update({"visual_token_mask": self.visual_token_mask})
            # The mask is consumed once and cleared for the next step.
            self.visual_token_mask = None

        hidden_states = self.language_model.model(**forward_kwargs)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping the normalisation buffers.

        Returns:
            The set of names reported as loaded by ``AutoWeightsLoader``.
        """
        ## Ignore registered_buffers
        ## see https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/input_conditioner.py#L28 # noqa: E501
        skip_substrs = ["norm_mean", "norm_std"]
        loader = AutoWeightsLoader(self, skip_substrs=skip_substrs)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="mlp1",
            tower_model="vision_model",
        )
395
396
397
398
399
400
401
402
403
404
405


# --------------------------------------------------------
# LlamaNemotronVL Embedding Model (nvidia/llama-nemotron-embed-vl-1b-v2)
# Extends LlamaNemotronVLChatModel for embedding/pooling tasks:
#   - SigLIP vision encoder (instead of C-RADIO)
#   - Bidirectional (non-causal) LLaMA language model
#   - Pooler output instead of generative logits
# --------------------------------------------------------


406
class LlamaNemotronVLEmbedProcessingInfo(BaseInternVLProcessingInfo):
    """Processing info for LlamaNemotronVL embedding model."""

    def get_image_processor(
        self, **kwargs: object
    ) -> LlamaNemotronVLEmbedImageProcessor:
        """Build the image processor for the embedding model.

        Tiling settings are resolved with this precedence (highest first):
        explicit/merged multimodal kwargs, the checkpoint's
        ``processor_config.json``, attributes on the HF config, and finally
        built-in defaults.

        Args:
            **kwargs: Per-request overrides, merged with the context's
                multimodal kwargs.

        Returns:
            The configured ``LlamaNemotronVLEmbedImageProcessor``.
        """
        model_config = self.ctx.model_config

        config = self.get_hf_config()
        # Fall back to an empty mapping when the lookup returns nothing.
        processor_config = (
            get_hf_file_to_dict(
                "processor_config.json",
                model_config.model,
                model_config.revision,
            )
            or {}
        )

        min_dynamic_patch = processor_config.get(
            "min_input_tiles",
            getattr(config, "min_dynamic_patch", 1),
        )
        max_dynamic_patch = processor_config.get(
            "max_input_tiles",
            getattr(config, "max_dynamic_patch", 1),
        )
        dynamic_image_size = processor_config.get(
            "dynamic_image_size",
            getattr(config, "dynamic_image_size", True),
        )

        # `setdefault` keeps any value the caller already supplied.
        kwargs = self.ctx.get_merged_mm_kwargs(kwargs)
        kwargs.setdefault("image_size", config.force_image_size)
        kwargs.setdefault("min_dynamic_patch", min_dynamic_patch)
        kwargs.setdefault("max_dynamic_patch", max_dynamic_patch)
        kwargs.setdefault("dynamic_image_size", dynamic_image_size)
        kwargs.setdefault("use_thumbnail", True)

        return LlamaNemotronVLEmbedImageProcessor(**kwargs)

    def get_hf_processor(self, **kwargs: object) -> LlamaNemotronVLEmbedProcessor:
        """Build the combined tokenizer + image processor.

        Args:
            **kwargs: Forwarded unchanged to ``get_image_processor``.

        Returns:
            A ``LlamaNemotronVLEmbedProcessor`` wired with this model's
            tokenizer, image processor and per-tile sequence length.
        """
        config = self.get_hf_config()
        vision_config = config.vision_config

        image_processor = self.get_image_processor(**kwargs)
        image_size = image_processor.image_size
        patch_size = vision_config.patch_size
        downsample_ratio = config.downsample_ratio
        # Patches per side, squared, then scaled by downsample_ratio**2
        # (presumably the number of image tokens per tile - TODO confirm).
        image_seq_length = int((image_size // patch_size) ** 2 * (downsample_ratio**2))

        return LlamaNemotronVLEmbedProcessor(
            tokenizer=self.get_tokenizer(),
            image_processor=image_processor,
            image_seq_length=image_seq_length,
        )


@MULTIMODAL_REGISTRY.register_processor(
    BaseInternVLMultiModalProcessor[LlamaNemotronVLEmbedProcessingInfo],
    info=LlamaNemotronVLEmbedProcessingInfo,
    dummy_inputs=BaseInternVLDummyInputsBuilder[LlamaNemotronVLEmbedProcessingInfo],
)
class LlamaNemotronVLForEmbedding(LlamaNemotronVLChatModel, VllmModelForPooling):
    """
    LlamaNemotronVL model for embeddings.

    Inherits from LlamaNemotronVLChatModel and specializes it for embedding tasks:
    - Uses SigLIP vision encoder instead of C-RADIO
    - Uses bidirectional LLaMA (via llm_config) instead of causal LLaMA
    - Adds pooler for embedding output instead of generating logits
    """

    is_pooling_model = True

    # Weight mapping from checkpoint format to vLLM format
    # Different from parent class due to different vision model structure
    weight_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Language model mapping
            "language_model.layers.": "language_model.model.layers.",
            "language_model.embed_tokens.": "language_model.model.embed_tokens.",
            "language_model.norm.": "language_model.model.norm.",
            # Vision model mapping (SiglipVisionModel has nested vision_model)
            "vision_model.encoder.": "vision_model.vision_model.encoder.",
            "vision_model.embeddings.": "vision_model.vision_model.embeddings.",
            "vision_model.post_layernorm.": "vision_model.vision_model.post_layernorm.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the base model, then attach the embedding pooler.

        Args:
            vllm_config: Engine configuration; its ``pooler_config`` must be
                set.
            prefix: Module-name prefix passed to the parent constructor.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        config = vllm_config.model_config.hf_config

        # Override: get img_context_token_id from config (parent sets None)
        self.img_context_token_id = getattr(config, "img_context_token_id", None)

        # Initialize pooler for embedding output
        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler = DispatchPooler.for_embedding(pooler_config)

    def _init_vision_model(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        prefix: str,
    ) -> nn.Module:
        """Override to use SigLIP instead of C-RADIO."""
        return SiglipVisionModel(
            config.vision_config,
            quant_config=quant_config,
            prefix=prefix,
            use_head=False,
        )

    def _init_mlp1(
        self,
        config: PretrainedConfig,
        vit_hidden_size: int | None = None,
        vision_projection_hidden_size: int | None = None,
    ) -> nn.Module:
        """Override to use different MLP structure for embedding model.

        The signature mirrors the parent's so the override stays
        call-compatible; only the defaults differ.

        Args:
            config: HF config of the full model.
            vit_hidden_size: Vision feature width; defaults to
                ``config.vision_config.hidden_size``.
            vision_projection_hidden_size: Projector hidden width; defaults
                to the text config's ``hidden_size``.
        """
        if vit_hidden_size is None:
            vit_hidden_size = config.vision_config.hidden_size
        if vision_projection_hidden_size is None:
            vision_projection_hidden_size = config.get_text_config().hidden_size

        return super()._init_mlp1(
            config,
            vit_hidden_size=vit_hidden_size,
            vision_projection_hidden_size=vision_projection_hidden_size,
        )

    def _call_vision_model(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Override to handle SigLIP interface."""
        # Unlike the parent, the output is used directly: no `.features`
        # access and no dtype cast.
        return self.vision_model(pixel_values)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Override to use different weight mapping for SigLIP."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.weight_mapper)
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590


class LlamaNemotronVLForSequenceClassification(
    LlamaNemotronVLForEmbedding, SupportsCrossEncoding
):
    """Reranking / sequence-classification flavour of LlamaNemotronVL.

    Adds a bias-free ``score`` projection on top of the embedding model and
    swaps the pooler for a sequence-classification one that uses it.
    """

    # In the reranker checkpoint the base model lives under `model.*` while
    # `score.*` stays at the top level; the extra prefix rule is combined
    # with the embedding model's mapper.
    weight_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) | (
        LlamaNemotronVLForEmbedding.weight_mapper
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the embedding model, the ``score`` head and the pooler.

        Args:
            vllm_config: Engine configuration; its ``pooler_config`` must be
                set.
            prefix: Module-name prefix, also used to name the ``score`` layer.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        model_config = vllm_config.model_config
        num_labels = model_config.hf_config.get_text_config().num_labels

        self.score = ReplicatedLinear(
            model_config.get_hidden_size(),
            num_labels,
            bias=False,
            params_dtype=model_config.head_dtype,
            quant_config=vllm_config.quant_config,
            return_bias=False,
            prefix=maybe_prefix(prefix, "score"),
        )

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the checkpoint and fill in the missing inner LM head.

        Returns:
            Names of all parameters that were loaded or freshly initialised.
        """
        loaded = super().load_weights(weights)

        # The reranker checkpoint has no `language_model.score.*` tensors.
        # This outer model never uses that inner head, but the default
        # loader expects every parameter to be initialised, so give the
        # leftovers a value here.
        inner_head_prefix = "language_model.score."
        pending = [
            (param_name, tensor)
            for param_name, tensor in self.named_parameters()
            if param_name.startswith(inner_head_prefix) and param_name not in loaded
        ]

        for param_name, tensor in pending:
            if param_name.endswith(".weight"):
                torch.nn.init.kaiming_uniform_(tensor, a=math.sqrt(5))
            elif param_name.endswith(".bias"):
                torch.nn.init.zeros_(tensor)
            else:
                torch.nn.init.normal_(tensor, mean=0.0, std=0.02)

            loaded.add(param_name)

        return loaded