radio.py 24.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import math
from collections.abc import Iterable
from itertools import repeat
14
from typing import TypeAlias
15
16
17
18
19
20
21
22
23

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
24
25
26
27
28
from vllm.model_executor.models.intern_vit import (
    InternParallelAttention,
    InternVisionEncoder,
    InternVisionEncoderLayer,
)
29

30
31
input_dim_t: TypeAlias = int | tuple[int, int]
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


def _ntuple(n):
    """Return a converter that coerces a value into a tuple.

    The converter passes any non-string iterable through ``tuple()`` (keeping
    its own length, which is NOT checked against ``n``), and repeats any
    other value -- strings included -- exactly ``n`` times.
    """

    def parse(x):
        # Strings are iterable, but here they are treated as scalars.
        treat_as_sequence = isinstance(x, Iterable) and not isinstance(x, str)
        return tuple(x) if treat_as_sequence else tuple(repeat(x, n))

    return parse


# Pre-built converters for the common arities; ``to_ntuple(k)`` builds others.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


50
51
52
53
54
55
56
57
58
def calc_seq_len(size: tuple[int, int], patch_size: int) -> int:
    """Return the number of whole ``patch_size`` patches tiling an image.

    ``size`` is unpacked as ``(h, w)`` in pixels; partial patches at the
    bottom/right edges are discarded by floor division.
    """
    rows, cols = (extent // patch_size for extent in size)
    return rows * cols


def calc_seq_lens(sizes: list[tuple[int, int]], patch_size: int) -> list[int]:
    """Apply :func:`calc_seq_len` to every ``(h, w)`` entry of ``sizes``."""
    return list(map(lambda size: calc_seq_len(size, patch_size), sizes))


59
60
61
62
63
64
class ClsToken(nn.Module):
    def __init__(
        self,
        ndim: int,
        num_tokens: int = 1,
        enabled: bool = True,
65
66
        register_multiple: int | None = None,
        num_registers: int | None = None,
67
68
69
70
71
72
73
74
75
76
77
    ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if num_registers:
                self.num_registers = num_registers
            elif register_multiple:
78
79
80
                self.num_registers = register_multiple - (
                    num_tokens % register_multiple
                )
81
82
83

            scale = ndim**-0.5
            self.token = nn.Parameter(
84
85
                torch.randn(num_tokens + self.num_registers, ndim) * scale
            )
86
87
88
89
90
91
92
93
94
95
96

        else:
            self.token = None

        self.num_patches = self.num_tokens + self.num_registers

    def forward(self, x: torch.Tensor):
        if self.token is None:
            return x

        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
97
98
99
100
101
102
103
        x = torch.cat(
            [
                token,
                x,
            ],
            dim=1,
        )
104
105
106
107
108
109
110
111
112
113
114
115
116
117

        return x


class ViTPatchGenerator(nn.Module):
    """Convert pixels into the token sequence consumed by the ViT encoder.

    Pipeline: patch embedding (:class:`ViTPatchLinear`) -> optional learned
    absolute position embeddings -> optional class/register tokens
    (:class:`ClsToken`) -> optional ``LayerNorm``.

    :meth:`forward` has two modes:

    * fixed-size (``imgs_sizes is None``): ``x`` is an image batch that is
      split into patches by :class:`Im2Patches`;
    * dynamic resolution (``imgs_sizes`` given): ``x`` already holds the
      flattened patches of several images concatenated along dim 1, and
      position embeddings / class tokens are applied per image.

    The learned position grid covers ``max_input_dims``. When that differs
    from ``input_dims`` the module is in "CPE" mode (``self.cpe_mode``) and
    resamples the grid for each requested size (see
    :meth:`_get_pos_embeddings`).
    """

    def __init__(
        self,
        patch_size: int,
        embed_dim: int,
        input_dims: input_dim_t,
        abs_pos: bool = True,
        normalize_patches: bool = False,
        cls_token: bool = False,
        max_input_dims: input_dim_t | None = None,
        pos_dropout: float = 0.0,
        return_pos_enc: bool = False,
        num_cls_tokens: int = 1,
        register_multiple: int | None = None,
        num_registers: int | None = None,
        patch_bias: bool = False,
        device=None,
        dtype=None,
    ):
        """
        Args:
            patch_size: Side length of a square patch, in pixels.
            embed_dim: Token embedding width.
            input_dims: Nominal input size in pixels (int or ``(h, w)``).
            abs_pos: Whether to learn absolute position embeddings.
            normalize_patches: Apply ``LayerNorm`` to the final tokens.
            cls_token: Whether to prepend class/register tokens.
            max_input_dims: Size covered by the position grid; defaults to
                ``input_dims`` and is rounded up to a multiple of
                ``patch_size``.
            pos_dropout: Per-sample probability of dropping the position
                embedding during training (fixed-size path only).
            return_pos_enc: Make :meth:`forward` also return the position
                embeddings.
            num_cls_tokens, register_multiple, num_registers: Forwarded to
                :class:`ClsToken`.
            patch_bias: Whether the patch embedder has a bias.
            device, dtype: Factory kwargs for the created parameters.
        """
        super().__init__()
        if isinstance(input_dims, int):
            input_dims = (input_dims, input_dims)

        if max_input_dims is None:
            max_input_dims = input_dims
        if isinstance(max_input_dims, int):
            max_input_dims = (max_input_dims, max_input_dims)

        # Round the maximum size up to a whole number of patches.
        max_input_dims = tuple(
            int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims
        )

        # CPE mode: the stored position grid differs from the nominal input
        # size and is resampled per request in _get_pos_embeddings().
        self.cpe_mode = max_input_dims != input_dims
        self.pos_dropout = pos_dropout
        self.return_pos_enc = return_pos_enc

        factory = dict(device=device, dtype=dtype)

        self.patch_size = patch_size
        self.abs_pos = abs_pos
        self.embed_dim = embed_dim

        # Grid dimensions are measured in patches, not pixels.
        self.num_rows = max_input_dims[0] // patch_size
        self.num_cols = max_input_dims[1] // patch_size
        self.input_dims = tuple(d // patch_size for d in input_dims)
        self.num_patches = self.num_rows * self.num_cols
        self.max_input_dims = max_input_dims

        self.im_to_patches = Im2Patches(patch_size)
        self.embedder = ViTPatchLinear(
            patch_size, embed_dim, bias=patch_bias, **factory
        )

        if abs_pos:
            scale = embed_dim**-0.5
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.num_patches, embed_dim, **factory) * scale
            )

        self.cls_token = ClsToken(
            embed_dim,
            num_tokens=num_cls_tokens,
            enabled=cls_token,
            register_multiple=register_multiple,
            num_registers=num_registers,
        )

        self.patch_normalizer = (
            nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
        )

    def forward(
        self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Embed ``x`` into a token sequence.

        Args:
            x: Fixed-size mode: an image batch ``(B, C, H, W)``. Dynamic mode:
                already-patchified pixels whose last dimension is the
                embedder's input width (``3 * patch_size**2``).
            imgs_sizes: Per-image ``(height, width)`` in pixels of the images
                packed into ``x``; enables dynamic-resolution mode.

        Returns:
            The tokens, or ``(tokens, pos_enc)`` when ``return_pos_enc`` was
            set. ``pos_enc`` is ``None`` when ``abs_pos`` is disabled.
        """
        if imgs_sizes is not None:
            patches = self.embedder(x)
            patches, pos_enc = self.apply_pos_enc_dynamic(
                patches, imgs_sizes=imgs_sizes
            )
            patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
        else:
            patches = self.embed_patches(x)
            patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
            patches = self.cls_token(patches)

        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    def apply_pos_enc_dynamic(
        self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Add per-image position embeddings to a packed patch sequence.

        ``patches`` holds the patches of each image in ``imgs_sizes`` back to
        back along dim 1; each image's slice receives embeddings resampled to
        its own grid. Any trailing patches beyond the listed images are kept
        unchanged.

        Returns:
            ``(patches_with_pos, pos_enc)`` where ``pos_enc`` concatenates the
            per-image embeddings along dim 1, or ``(patches, None)`` when
            ``abs_pos`` is disabled.
        """
        if not self.abs_pos:
            return patches, None

        batch_size = patches.shape[0]
        segments = []
        pos_enc_list = []
        offset = 0

        for size in imgs_sizes:
            seq_length = calc_seq_len(size, self.patch_size)
            pos_enc = self.get_pos_enc(batch_size, input_size=size)
            segments.append(patches[:, offset : offset + seq_length, :] + pos_enc)
            pos_enc_list.append(pos_enc)
            offset += seq_length

        # Keep whatever follows the last listed image untouched; collecting
        # segments and concatenating once avoids re-copying the whole
        # sequence for every image.
        segments.append(patches[:, offset:, :])
        patches = torch.cat(segments, dim=1)

        full_pos_enc = torch.cat(pos_enc_list, dim=1) if pos_enc_list else None
        return patches, full_pos_enc

    def cls_token_dynamic(
        self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
    ) -> torch.Tensor:
        """Prepend the class/register tokens to each image in a packed sequence.

        Only the patches covered by ``imgs_sizes`` are kept in the output.
        Returns ``patches`` unchanged when class tokens are disabled.
        """
        if not self.cls_token.enabled:
            return patches

        # The prefix is identical for every image, so build it once.
        class_token = self.cls_token.token.unsqueeze(0).expand(
            patches.shape[0], -1, -1
        )

        out = []
        current_length = 0
        for seq_len in calc_seq_lens(imgs_sizes, self.patch_size):
            out.append(class_token)
            out.append(patches[:, current_length : current_length + seq_len, :])
            current_length += seq_len

        return torch.cat(out, dim=1)

    @property
    def apply_cls_token(self):
        """Whether class/register tokens are prepended."""
        return self.cls_token.enabled

    @property
    def num_cls_tokens(self):
        """Number of class tokens (excluding registers)."""
        return self.cls_token.num_tokens

    @property
    def num_cls_patches(self):
        """Total prefix length: class tokens plus registers."""
        return self.cls_token.num_patches

    @property
    def num_registers(self):
        """Number of register tokens."""
        return self.cls_token.num_registers

    @property
    def num_skip(self):
        """Prefix tokens that precede each image's patch tokens."""
        return self.num_cls_tokens + self.num_registers

    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
        """Copy ``src_embed`` into ``targ_embed``, resizing if shapes differ.

        A mismatched ``src_embed`` must be ``(b, N, c)`` with a square ``N``;
        it is bicubically resized to ``(num_rows, num_cols)``.
        """
        if src_embed.shape != targ_embed.shape:
            src_size = int(math.sqrt(src_embed.shape[1]))

            assert src_size**2 == src_embed.shape[1], (
                "Unable to interpolate non-square embedding"
            )

            src_embed = rearrange(
                src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size
            )
            src_embed = F.interpolate(
                src_embed,
                size=(self.num_rows, self.num_cols),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_embed = rearrange(src_embed, "b c h w -> b (h w) c")
        targ_embed.data.copy_(src_embed)

    def _load_projection(
        self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor
    ):
        """Copy a patch-projection weight, resizing its patch size if needed.

        A mismatched ``src_proj_weight`` must be ``(out, 3 * p * p)`` for a
        square patch ``p``; it is bicubically resized to ``patch_size``.
        """
        if src_proj_weight.shape != targ_proj_weight.shape:
            src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))

            assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], (
                "Unable to interpolate non-square patch size"
            )

            src_proj_weight = rearrange(
                src_proj_weight,
                "b (c h w) -> b c h w",
                c=3,
                h=src_patch_size,
                w=src_patch_size,
            )
            src_proj_weight = F.interpolate(
                src_proj_weight,
                size=(self.patch_size, self.patch_size),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)")
        targ_proj_weight.data.copy_(src_proj_weight)

    def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify an image batch and project each patch to ``embed_dim``."""
        patches = self.im_to_patches(x)
        patches = self.embedder(patches)
        return patches

    def apply_pos_enc(
        self,
        patches: torch.Tensor,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Add absolute position embeddings to a fixed-size patch batch.

        Args:
            patches: Embedded patches.
            patch_idxs: Optional per-sample indices selecting which grid
                positions the patches correspond to.
            input_size: Input size in pixels; defaults to ``input_dims``.

        Returns:
            ``(patches_with_pos, pos_enc)``, always a 2-tuple because
            :meth:`forward` unpacks it. ``pos_enc`` is the undropped embedding,
            or ``None`` when ``abs_pos`` is disabled.
        """
        if not self.abs_pos:
            # Must still be a pair: forward() unpacks two values.
            return patches, None

        pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)

        if self.training and self.pos_dropout > 0:
            # Per-sample dropout: zero the whole position embedding for a
            # random subset of the batch.
            keeps = (
                torch.rand(
                    patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device
                )
                > self.pos_dropout
            )
            pos_enc_drop = torch.where(keeps, pos_enc, 0)
        else:
            pos_enc_drop = pos_enc

        return patches + pos_enc_drop, pos_enc

    def get_pos_enc(
        self,
        batch_size: int,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> torch.Tensor:
        """Return position embeddings for an input of ``input_size`` pixels.

        When ``patch_idxs`` is given, only those grid positions are gathered
        (per sample) from the resampled grid.
        """
        if input_size is None:
            input_dims = self.input_dims
        else:
            input_dims = tuple(d // self.patch_size for d in input_size)

        pos_embed = self._get_pos_embeddings(batch_size, input_dims)

        if patch_idxs is None:
            return pos_embed

        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])

        pos_embed = torch.gather(
            pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
        )
        return pos_embed

    def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]):
        """Return the position grid resampled to ``input_dims`` (in patches).

        The result is ``(B', h * w, embed_dim)``; ``B'`` equals ``batch_size``
        for the training-time CPE augmentation and 1 otherwise.
        """
        if (self.num_rows, self.num_cols) == input_dims:
            return self.pos_embed

        # (1, N, C) -> (1, C, rows, cols) so it can be resampled spatially.
        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(
            0, 3, 1, 2
        )

        def window_select(pos_embed):
            # Crop the top-left window when the grid is larger than requested.
            if input_dims[0] < pos_embed.shape[-2]:
                pos_embed = pos_embed[..., : input_dims[0], :]
            if input_dims[1] < pos_embed.shape[-1]:
                pos_embed = pos_embed[..., :, : input_dims[1]]
            return pos_embed

        if self.cpe_mode:
            if self.training:
                # Augmentation: sample a random sub-window of the grid per
                # sample, with scale in [sqrt(0.1), 1) and aspect ratio in
                # [3/4, 4/3], then bilinearly sample it at the requested size.
                min_scale = math.sqrt(0.1)
                scale = (
                    torch.rand(batch_size, 1, 1, device=pos_embed.device)
                    * (1 - min_scale)
                    + min_scale
                )
                aspect_min = math.log(3 / 4)
                aspect_max = -aspect_min
                aspect = torch.exp(
                    torch.rand(batch_size, 1, 1, device=pos_embed.device)
                    * (aspect_max - aspect_min)
                    + aspect_min
                )

                scale_x = scale * aspect
                scale_y = scale * (1 / aspect)
                scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)

                # Random window offset that keeps the window inside [0, 1].
                pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (
                    1 - scale_xy
                )

                lin_x = torch.linspace(
                    0, 1, steps=input_dims[1], device=pos_embed.device
                )[None, None].expand(batch_size, input_dims[0], -1)
                lin_y = torch.linspace(
                    0, 1, steps=input_dims[0], device=pos_embed.device
                )[None, :, None].expand(batch_size, -1, input_dims[1])

                lin_xy = torch.stack([lin_x, lin_y], dim=-1)

                grid_xy = lin_xy * scale_xy + pos_xy

                # Convert to [-1, 1] range
                grid_xy.mul_(2).sub_(1)

                pos_embed = F.grid_sample(
                    pos_embed.float().expand(batch_size, -1, -1, -1),
                    grid=grid_xy,
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=True,
                ).to(pos_embed.dtype)
            else:
                # Inference: resize to a square of the longer side, then crop
                # the top-left window matching the requested grid.
                max_dim = max(input_dims)
                pos_embed = F.interpolate(
                    pos_embed.float(),
                    size=(max_dim, max_dim),
                    align_corners=True,
                    mode="bilinear",
                ).to(pos_embed.dtype)

                pos_embed = window_select(pos_embed)
        else:
            pos_embed = window_select(pos_embed)

        # Grid still smaller than requested: upsample to the exact size.
        if pos_embed.shape[-2:] != input_dims:
            pos_embed = F.interpolate(
                pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear"
            ).to(pos_embed.dtype)

        # (B', C, h, w) -> (B', h * w, C)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)

        return pos_embed


class Im2Patches(nn.Module):
    """Split an image batch into flattened, non-overlapping square patches.

    Input ``(B, C, H, W)`` becomes ``(B, (H // p) * (W // p), C * p * p)``
    with patches in row-major order. Each patch vector is laid out
    channel-first, then patch row, then patch column. ``H`` and ``W`` are
    presumably required to be multiples of ``p`` (einops rejects shapes it
    cannot decompose) -- verify for your einops version.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        p = self.patch_size
        if p == 1:
            # Each pixel is its own patch: (B, C, H, W) -> (B, H*W, C).
            return x.flatten(2).transpose(1, 2)

        grid_rows = x.shape[-2] // p
        grid_cols = x.shape[-1] // p
        return rearrange(
            x,
            "b c (py yy) (px xx) -> b (py px) (c yy xx)",
            py=grid_rows,
            yy=p,
            px=grid_cols,
            xx=p,
        )


class ViTPatchLinear(nn.Linear):
    """Linear patch embedding from a flattened 3-channel patch to ``embed_dim``.

    The input width is ``3 * patch_size * patch_size``; ``factory`` kwargs
    (e.g. ``device``, ``dtype``) are forwarded to :class:`torch.nn.Linear`.
    """

    def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
        flat_patch_width = 3 * (patch_size**2)
        super().__init__(flat_patch_width, embed_dim, bias=bias, **factory)
        self.patch_size = patch_size


480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
class RadioParallelAttention(InternParallelAttention):
    """InternViT attention that additionally accepts an attention mask.

    Without a mask it defers to the parent implementation. With a mask,
    attention is computed here through
    :func:`torch.nn.functional.scaled_dot_product_attention` so the mask
    (for example the per-image mask built by ``RadioInternVisionModel``) is
    honored.
    """

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        if attn_mask is None:
            return super().forward(x)

        batch, seq_len, _ = x.shape
        fused, _ = self.qkv(x)
        query, key, value = fused.chunk(3, dim=-1)

        if self.qk_normalization:
            query, key = self._apply_qk_norm(query, key)

        # (B, N, H*D) -> (B, H, N, D): heads become a batch dim for SDPA.
        per_head = (batch, seq_len, self.num_heads_per_partition, self.head_dim)
        query, key, value = [
            t.view(per_head).transpose(1, 2) for t in (query, key, value)
        ]

        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, scale=self.scale
        )
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        projected, _ = self.proj(merged)
        return projected


class RadioVisionEncoderLayer(InternVisionEncoderLayer):
    """InternViT encoder layer whose attention accepts an optional mask."""

    def __init__(self, *args, **kwargs) -> None:
        # Swap in the mask-aware attention implementation.
        super().__init__(*args, attn_cls=RadioParallelAttention, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Pre-norm residual block with layer scales ``ls1`` / ``ls2``."""
        attn_branch = self.attn(self.norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attn_branch * self.ls1

        mlp_branch = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_branch * self.ls2


class RadioVisionEncoder(InternVisionEncoder):
    """Stack of :class:`RadioVisionEncoderLayer` sharing one attention mask."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, layer_cls=RadioVisionEncoderLayer, **kwargs)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Run every layer in order, passing the same ``attn_mask`` to each."""
        state = inputs_embeds
        for layer in self.layers:
            state = layer(state, attn_mask=attn_mask)
        return state


540
541
542
543
544
545
546
547
class RadioInternVisionModel(nn.Module):
    """RADIO vision trunk: :class:`ViTPatchGenerator` + :class:`RadioVisionEncoder`.

    Args:
        config: HF config. This class reads ``patch_size``, ``image_size``,
            ``cpe_max_size``, ``hidden_size``, ``teachers``,
            ``cls_token_per_teacher`` and ``register_multiple`` from it and
            passes it on to the encoder. It defaults to ``None`` but is
            required in practice.
        quant_config: Optional quantization config forwarded to the encoder.
        num_hidden_layers_override: Forwarded to the encoder.
        num_dummy_heads: Forwarded to the encoder.
        prefix: Parameter-name prefix; the encoder gets ``f"{prefix}.encoder"``.
    """

    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.img_size, self.grid_size, self.num_patches = self._init_img_size(
            to_2tuple(config.patch_size), config.image_size
        )
        # Snap the CPE maximum size to the nearest multiple of the patch size.
        max_img_size = int(
            round(config.cpe_max_size / config.patch_size) * config.patch_size
        )
        unique_teachers = {t["name"] for t in config.teachers}
        self.patch_generator = ViTPatchGenerator(
            config.patch_size,
            config.hidden_size,
            input_dims=self.img_size,
            max_input_dims=max_img_size,
            cls_token=True,
            # One class token per distinct teacher, or a single shared one.
            num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
            register_multiple=config.register_multiple,
        )

        self.encoder = RadioVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def _init_img_size(self, patch_size, img_size: int | tuple[int, int]):
        """Return ``(img_size, grid_size, num_patches)`` for a nominal size.

        ``img_size`` is normalized to an ``(h, w)`` tuple and divided
        element-wise by ``patch_size`` (an ``(ph, pw)`` tuple). All three are
        ``None`` when ``img_size`` is ``None``.
        """
        if img_size is None:
            return None, None, None
        img_size = to_2tuple(img_size)
        grid_size = tuple(s // p for s, p in zip(img_size, patch_size))
        num_patches = grid_size[0] * grid_size[1]
        return img_size, grid_size, num_patches

    def get_input_embeddings(self):
        """Return the module that turns pixels into input tokens.

        This previously returned ``self.embeddings``, an attribute this class
        never sets, so every call raised ``AttributeError``.
        """
        return self.patch_generator

    def create_inter_image_attention_mask(
        self, imgs_sizes: list[tuple[int, int]], device: torch.device
    ) -> torch.Tensor:
        """Build a block-diagonal boolean mask for packed multi-image input.

        Each image occupies ``num_skip`` prefix tokens plus its patch tokens.
        ``True`` entries mark allowed attention (the boolean-mask convention of
        ``scaled_dot_product_attention``), so tokens only attend within their
        own image.

        Returns:
            A ``(total, total)`` bool tensor on ``device``.
        """
        patch_size = self.patch_generator.patch_size
        num_skip = self.patch_generator.num_skip

        seq_lens = calc_seq_lens(imgs_sizes, patch_size)
        patch_counts = [seq_len + num_skip for seq_len in seq_lens]
        total_patches = sum(patch_counts)

        # Start fully masked, then open one diagonal block per image.
        mask = torch.zeros(
            total_patches, total_patches, dtype=torch.bool, device=device
        )

        start_idx = 0
        for patch_count in patch_counts:
            end_idx = start_idx + patch_count
            mask[start_idx:end_idx, start_idx:end_idx] = True
            start_idx = end_idx

        return mask

    def forward(
        self,
        x: torch.Tensor,
        imgs_sizes: list[tuple[int, int]] | None = None,
    ) -> torch.Tensor:
        """Embed ``x`` and run the encoder.

        Args:
            x: Pixels, in the layout expected by :meth:`ViTPatchGenerator.forward`.
            imgs_sizes: Per-image ``(height, width)`` for dynamic-resolution
                input packed into one sequence; ``None`` for fixed-size input.
        """
        hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
        attn_mask = None
        if imgs_sizes is not None and len(imgs_sizes) > 1:
            # Several images share one sequence: keep them from attending
            # to each other.
            attn_mask = self.create_inter_image_attention_mask(
                imgs_sizes, device=x.device
            )
        return self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)


class RadioModel(nn.Module):
    """Top-level RADIO wrapper returning summary and patch features.

    Holds a :class:`RadioInternVisionModel` as ``self.model`` and loads
    checkpoint weights whose names start with ``radio_model.``.
    """

    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.model = RadioInternVisionModel(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=prefix,
        )

        # Indices of teachers whose summary tokens are exposed; a teacher is
        # included unless its entry sets ``use_summary`` to False.
        summary_idxs = None
        if config.teachers:
            summary_idxs = torch.tensor(
                [i for i, t in enumerate(config.teachers) if t.get("use_summary", True)]
            )
            if summary_idxs.numel() > 0:
                self.register_buffer("summary_idxs", summary_idxs)
        self.summary_idxs = summary_idxs

    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        pixel_embeds: torch.Tensor | None = None,
        *,
        imgs_sizes: list[tuple[int, int]] | None = None,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run the backbone and split its output.

        ``pixel_embeds`` is accepted for interface compatibility but unused.

        Returns:
            ``(summary, features)`` as produced by :meth:`_extract_final`.
        """
        y = self.model(pixel_values, imgs_sizes=imgs_sizes)
        return self._extract_final(y, imgs_sizes=imgs_sizes)

    @staticmethod
    def _checkpoint_key_to_param_name(sub: str) -> str | None:
        """Map a checkpoint key (``radio_model.`` already stripped) to a param name.

        Returns ``None`` for keys that have no counterpart here.
        """
        if sub.startswith("model.patch_generator."):
            # Patch-generator parameters keep the same name.
            return sub
        if sub.startswith("model.blocks."):
            # Encoder blocks: 'model.blocks.{i}.X' -> 'model.encoder.layers.{i}.X'
            parts = sub.split(".")
            if len(parts) < 4:
                return None
            layer_idx, suffix = parts[2], ".".join(parts[3:])
            # Layer-scale entries are skipped.
            # NOTE(review): RadioVisionEncoderLayer does multiply by
            # self.ls1/self.ls2, so they keep their initial values; confirm
            # that checkpoints carry no meaningful layer scales.
            if suffix in {"ls1", "ls2"} or suffix.startswith(("ls1.", "ls2.")):
                return None
            return f"model.encoder.layers.{layer_idx}.{suffix}"
        return None

    def load_weights(self, weights) -> set[str]:
        """Load ``radio_model.``-prefixed weights into this module.

        Args:
            weights: A ``{name: tensor}`` dict or an iterable of
                ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded. Unknown or unused
            checkpoint entries are skipped silently.
        """
        loaded_params: set[str] = set()
        params_dict = dict(self.named_parameters())
        prefix = "radio_model."

        items = weights.items() if isinstance(weights, dict) else weights
        for name, weight in items:
            if not name.startswith(prefix):
                # Skip non-radio weights
                continue
            sub = name.removeprefix(prefix)

            # Skip buffers not used in vLLM
            if sub == "summary_idxs":
                continue
            if sub.startswith("input_conditioner."):
                # we normalize in the input processor,
                # based on norm and std values from the config
                continue

            vllm_key = self._checkpoint_key_to_param_name(sub)
            if vllm_key and vllm_key in params_dict:
                param = params_dict[vllm_key]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(vllm_key)

        return loaded_params

    def _extract_final(
        self, y: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Split encoder output into summary tokens and patch features.

        Every image's tokens start with ``num_skip`` prefix tokens (class
        tokens then registers). The first ``num_cls_tokens`` of them are the
        summary; registers are dropped.

        Returns:
            ``(summary, features)``: the summary (optionally filtered by
            ``summary_idxs``) flattened from dim 1, and the patch tokens of
            all images concatenated along dim 1.
        """
        generator = self.model.patch_generator
        num_skip = generator.num_skip
        patch_size = generator.patch_size
        num_cls_tokens = generator.num_cls_tokens
        if imgs_sizes is None:
            all_summary = y[:, :num_cls_tokens]
            all_feat = y[:, num_skip:]
        else:
            all_patches = []
            summaries = []
            current_pos = 0
            for num_patches in calc_seq_lens(imgs_sizes, patch_size):
                patches = y[
                    :, current_pos + num_skip : current_pos + num_skip + num_patches, :
                ]
                all_patches.append(patches)
                summary = y[:, current_pos : current_pos + num_cls_tokens, :]
                summaries.append(summary)
                current_pos += num_skip + num_patches
            all_summary = torch.cat(summaries, dim=1)
            all_feat = torch.cat(all_patches, dim=1)

        # NOTE(review): with several images, all_summary holds
        # num_images * num_cls_tokens tokens but summary_idxs only indexes the
        # first num_cls_tokens, i.e. the first image's summaries -- confirm
        # this is intended.
        if self.summary_idxs is not None:
            bb_summary = all_summary[:, self.summary_idxs]
        else:
            bb_summary = all_summary
        return bb_summary.flatten(1), all_feat