radio.py 18.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import math
from collections.abc import Iterable
from itertools import repeat
14
from typing import TypeAlias
15
16
17
18
19
20
21
22
23
24
25

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionEncoder

26
27
# Spatial size given either as one int or as a 2-tuple; ViTPatchGenerator
# expands a bare int ``d`` to ``(d, d)`` and uses index 0 for rows, index 1
# for columns.
input_dim_t: TypeAlias = int | tuple[int, int]
# Three floats or a tensor. NOTE(review): not referenced in the visible part of
# this file -- presumably per-channel normalisation constants; confirm.
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


def _ntuple(n):
    """Return a converter that normalises its argument into a tuple.

    The returned callable turns any non-string iterable into ``tuple(x)``
    (its length is NOT checked against ``n``) and repeats every other value,
    including strings, ``n`` times.
    """

    def parse(x):
        # Strings are iterable but are treated as scalars here.
        if isinstance(x, str) or not isinstance(x, Iterable):
            return tuple(repeat(x, n))
        return tuple(x)

    return parse


# Fixed-arity converters built from the factory above.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
# Public alias of the factory itself: ``to_ntuple(n)(x)``.
to_ntuple = _ntuple


class ClsToken(nn.Module):
    """Learned prefix tokens (class tokens plus optional registers).

    When enabled, a single parameter of shape
    ``(num_tokens + num_registers, ndim)`` is created and prepended along
    dim 1 to every sequence passed to ``forward``.

    Args:
        ndim: Size of the last dimension of each learned token.
        num_tokens: Number of class tokens.
        enabled: If False, no parameter is created and ``forward`` returns
            its input unchanged.
        register_multiple: Used only when ``num_registers`` is falsy. The
            register count becomes
            ``register_multiple - (num_tokens % register_multiple)``; note
            that this adds a full ``register_multiple`` registers when
            ``num_tokens`` is already a multiple.
        num_registers: Explicit register count; takes precedence over
            ``register_multiple``.

    Attributes set here: ``ndim``, ``enabled``, ``num_tokens``,
    ``num_registers`` (0 when disabled), ``token`` (parameter or None) and
    ``num_patches`` (``num_tokens + num_registers``, computed even when
    disabled).
    """

    def __init__(
        self,
        ndim: int,
        num_tokens: int = 1,
        enabled: bool = True,
        register_multiple: int | None = None,
        num_registers: int | None = None,
    ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if num_registers:
                self.num_registers = num_registers
            elif register_multiple:
                self.num_registers = register_multiple - (
                    num_tokens % register_multiple
                )

            # Random normal init scaled by 1/sqrt(ndim).
            scale = ndim**-0.5
            self.token = nn.Parameter(
                torch.randn(num_tokens + self.num_registers, ndim) * scale
            )
        else:
            self.token = None

        self.num_patches = self.num_tokens + self.num_registers

    def forward(self, x: torch.Tensor):
        """Prepend the learned tokens to ``x`` along dim 1.

        ``x`` is indexed as ``(batch, sequence, ...)``; the tokens are
        broadcast over ``x.shape[0]``. Returns ``x`` itself when disabled.
        """
        if self.token is None:
            return x

        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
        return torch.cat([token, x], dim=1)


class ViTPatchGenerator(nn.Module):
    """Convert an image batch into a sequence of ViT input tokens.

    ``forward`` cuts the image into patches, projects each patch linearly,
    adds absolute position embeddings (when ``abs_pos`` is set), prepends
    class/register tokens and finally applies an optional LayerNorm.

    Args:
        patch_size: Side length of a square patch, in pixels.
        embed_dim: Output feature size of the patch projection.
        input_dims: Nominal input size; an int ``d`` means ``(d, d)``.
        abs_pos: Create and apply a learned absolute position embedding.
        normalize_patches: Apply ``nn.LayerNorm(embed_dim)`` to the final
            tokens instead of an identity.
        cls_token: Enable the ``ClsToken`` prefix.
        max_input_dims: Size of the stored position-embedding grid (rounded
            up to a multiple of ``patch_size``). Defaults to ``input_dims``.
        pos_dropout: In training, probability threshold for dropping the
            position encoding of a whole sample.
        return_pos_enc: Make ``forward`` return ``(tokens, pos_enc)``.
        num_cls_tokens: Forwarded to ``ClsToken`` as ``num_tokens``.
        register_multiple: Forwarded to ``ClsToken``.
        num_registers: Forwarded to ``ClsToken``.
        patch_bias: Whether the patch projection has a bias.
        device: Factory kwarg for the projection and position embedding.
        dtype: Factory kwarg for the projection and position embedding.
    """

    def __init__(
        self,
        patch_size: int,
        embed_dim: int,
        input_dims: input_dim_t,
        abs_pos: bool = True,
        normalize_patches: bool = False,
        cls_token: bool = False,
        max_input_dims: input_dim_t | None = None,
        pos_dropout: float = 0.0,
        return_pos_enc: bool = False,
        num_cls_tokens: int = 1,
        register_multiple: int | None = None,
        num_registers: int | None = None,
        patch_bias: bool = False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        if isinstance(input_dims, int):
            input_dims = (input_dims, input_dims)

        if max_input_dims is None:
            max_input_dims = input_dims
        if isinstance(max_input_dims, int):
            max_input_dims = (max_input_dims, max_input_dims)

        # Round the maximum size up to a whole number of patches.
        max_input_dims = tuple(
            int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims
        )

        # "CPE mode" is active whenever the stored grid differs from the
        # nominal input size; it changes how embeddings are resampled.
        self.cpe_mode = max_input_dims != input_dims
        self.pos_dropout = pos_dropout
        self.return_pos_enc = return_pos_enc

        factory = dict(device=device, dtype=dtype)

        self.patch_size = patch_size
        self.abs_pos = abs_pos
        self.embed_dim = embed_dim

        # Grid of the stored position embedding, measured in patches.
        self.num_rows = max_input_dims[0] // patch_size
        self.num_cols = max_input_dims[1] // patch_size
        # Nominal input size, also measured in patches.
        self.input_dims = tuple(d // patch_size for d in input_dims)
        self.num_patches = self.num_rows * self.num_cols
        self.max_input_dims = max_input_dims

        self.im_to_patches = Im2Patches(patch_size)
        self.embedder = ViTPatchLinear(
            patch_size, embed_dim, bias=patch_bias, **factory
        )

        if abs_pos:
            # Random normal init scaled by 1/sqrt(embed_dim).
            scale = embed_dim**-0.5
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.num_patches, embed_dim, **factory) * scale
            )

        self.cls_token = ClsToken(
            embed_dim,
            num_tokens=num_cls_tokens,
            enabled=cls_token,
            register_multiple=register_multiple,
            num_registers=num_registers,
        )

        self.patch_normalizer = (
            nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
        )

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Embed ``x`` into tokens.

        ``x.shape[2:]`` is taken as the spatial input size. Returns the
        tokens, or ``(tokens, pos_enc)`` when ``return_pos_enc`` is set
        (``pos_enc`` is None if ``abs_pos`` is disabled).
        """
        patches = self.embed_patches(x)
        patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
        patches = self.cls_token(patches)
        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    @property
    def apply_cls_token(self):
        """Whether the class-token prefix is enabled."""
        return self.cls_token.enabled

    @property
    def num_cls_tokens(self):
        """Number of class tokens configured on the prefix module."""
        return self.cls_token.num_tokens

    @property
    def num_cls_patches(self):
        """Class tokens plus registers, as reported by the prefix module."""
        return self.cls_token.num_patches

    @property
    def num_registers(self):
        """Number of register tokens configured on the prefix module."""
        return self.cls_token.num_registers

    @property
    def num_skip(self):
        """Number of leading tokens that are not image patches."""
        return self.num_cls_tokens + self.num_registers

    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
        """Copy ``src_embed`` into ``targ_embed``, resizing if shapes differ.

        A mismatching source must hold a square grid in dim 1; it is
        bicubically resampled to ``(num_rows, num_cols)`` before the copy.

        Raises:
            AssertionError: If the source grid is not square.
        """
        if src_embed.shape != targ_embed.shape:
            src_size = int(math.sqrt(src_embed.shape[1]))

            assert src_size**2 == src_embed.shape[1], (
                "Unable to interpolate non-square embedding"
            )

            src_embed = rearrange(
                src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size
            )
            src_embed = F.interpolate(
                src_embed,
                size=(self.num_rows, self.num_cols),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_embed = rearrange(src_embed, "b c h w -> b (h w) c")
        targ_embed.data.copy_(src_embed)

    def _load_projection(
        self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor
    ):
        """Copy a patch-projection weight, resizing the patch if needed.

        A mismatching source is interpreted as ``(out, 3 * p * p)`` with a
        square patch of side ``p`` and bicubically resampled to
        ``patch_size`` before the copy.

        Raises:
            AssertionError: If dim 1 is not ``3 * p * p`` for an integer p.
        """
        if src_proj_weight.shape != targ_proj_weight.shape:
            src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))

            assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], (
                "Unable to interpolate non-square patch size"
            )

            src_proj_weight = rearrange(
                src_proj_weight,
                "b (c h w) -> b c h w",
                c=3,
                h=src_patch_size,
                w=src_patch_size,
            )
            src_proj_weight = F.interpolate(
                src_proj_weight,
                size=(self.patch_size, self.patch_size),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)")
        targ_proj_weight.data.copy_(src_proj_weight)

    def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``x`` into patches and project them linearly."""
        patches = self.im_to_patches(x)
        patches = self.embedder(patches)
        return patches

    def apply_pos_enc(
        self,
        patches: torch.Tensor,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Add the position encoding to ``patches``.

        Always returns a ``(patches, pos_enc)`` pair so that ``forward`` can
        unpack it unconditionally; ``pos_enc`` is None when ``abs_pos`` is
        disabled (previously a bare tensor was returned in that case, which
        broke the unpacking in ``forward``).

        In training with ``pos_dropout > 0`` the encoding is zeroed for a
        random subset of samples; the returned ``pos_enc`` is the undropped
        encoding.
        """
        if not self.abs_pos:
            return patches, None

        pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)

        if self.training and self.pos_dropout > 0:
            # One keep/drop decision per sample, broadcast over the tokens.
            keeps = (
                torch.rand(
                    patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device
                )
                > self.pos_dropout
            )
            pos_enc_drop = torch.where(keeps, pos_enc, 0)
        else:
            pos_enc_drop = pos_enc

        return patches + pos_enc_drop, pos_enc

    def get_pos_enc(
        self,
        batch_size: int,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> torch.Tensor:
        """Return position embeddings for the given pixel ``input_size``.

        ``input_size`` is converted to a patch grid by floor division with
        ``patch_size``; None selects the nominal ``input_dims``. When
        ``patch_idxs`` is given, the embeddings at those indices along dim 1
        are gathered per row of ``patch_idxs``.
        """
        if input_size is None:
            input_dims = self.input_dims
        else:
            input_dims = tuple(d // self.patch_size for d in input_size)

        pos_embed = self._get_pos_embeddings(batch_size, input_dims)

        if patch_idxs is None:
            return pos_embed

        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])

        pos_embed = torch.gather(
            pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
        )
        return pos_embed

    def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]):
        """Resample the stored embedding grid to ``input_dims`` patches.

        Returns the stored parameter itself when the grid already matches.
        Otherwise the result has shape ``(B, rows * cols, embed_dim)`` where
        ``B`` is ``batch_size`` in CPE training mode and 1 in all other
        cases.
        """
        if (self.num_rows, self.num_cols) == input_dims:
            return self.pos_embed

        # (1, N, C) -> (1, C, rows, cols) so spatial ops can be applied.
        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(
            0, 3, 1, 2
        )

        def window_select(pos_embed):
            # Crop the top-left window when the target is smaller.
            if input_dims[0] < pos_embed.shape[-2]:
                pos_embed = pos_embed[..., : input_dims[0], :]
            if input_dims[1] < pos_embed.shape[-1]:
                pos_embed = pos_embed[..., :, : input_dims[1]]
            return pos_embed

        if self.cpe_mode:
            if self.training:
                # Sample a random sub-window (scale, aspect ratio, offset) of
                # the stored grid per sample and read it with grid_sample.
                min_scale = math.sqrt(0.1)
                scale = (
                    torch.rand(batch_size, 1, 1, device=pos_embed.device)
                    * (1 - min_scale)
                    + min_scale
                )
                # Log-uniform aspect ratio in [3/4, 4/3].
                aspect_min = math.log(3 / 4)
                aspect_max = -aspect_min
                aspect = torch.exp(
                    torch.rand(batch_size, 1, 1, device=pos_embed.device)
                    * (aspect_max - aspect_min)
                    + aspect_min
                )

                scale_x = scale * aspect
                scale_y = scale * (1 / aspect)
                scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)

                # Random offset that keeps the window inside [0, 1].
                pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (
                    1 - scale_xy
                )

                lin_x = torch.linspace(
                    0, 1, steps=input_dims[1], device=pos_embed.device
                )[None, None].expand(batch_size, input_dims[0], -1)
                lin_y = torch.linspace(
                    0, 1, steps=input_dims[0], device=pos_embed.device
                )[None, :, None].expand(batch_size, -1, input_dims[1])

                lin_xy = torch.stack([lin_x, lin_y], dim=-1)

                grid_xy = lin_xy * scale_xy + pos_xy

                # Convert to [-1, 1] range
                grid_xy.mul_(2).sub_(1)

                pos_embed = F.grid_sample(
                    pos_embed.float().expand(batch_size, -1, -1, -1),
                    grid=grid_xy,
                    mode="bilinear",
                    padding_mode="zeros",
                    align_corners=True,
                ).to(pos_embed.dtype)
            else:
                # Resize to a square of the larger target side, then crop.
                max_dim = max(input_dims)
                pos_embed = F.interpolate(
                    pos_embed.float(),
                    size=(max_dim, max_dim),
                    align_corners=True,
                    mode="bilinear",
                ).to(pos_embed.dtype)

                pos_embed = window_select(pos_embed)
        else:
            pos_embed = window_select(pos_embed)

        # Cropping only shrinks; upsample if the target is still larger.
        if pos_embed.shape[-2:] != input_dims:
            pos_embed = F.interpolate(
                pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear"
            ).to(pos_embed.dtype)

        # (B, C, rows, cols) -> (B, rows * cols, C)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)

        return pos_embed


class Im2Patches(nn.Module):
    """Rearrange an image batch into a sequence of flattened patches.

    Args:
        patch_size: Side length of a square, non-overlapping patch.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(b, c, H, W)`` to ``(b, py * px, c * patch_size**2)``.

        ``py`` and ``px`` are ``H`` and ``W`` floor-divided by
        ``patch_size``. Patches are ordered row-major and each patch is
        flattened channel-first.
        """
        if self.patch_size == 1:
            # Each pixel is its own patch: (b, c, H, W) -> (b, H*W, c).
            return x.flatten(2).permute(0, 2, 1)

        py = x.shape[-2] // self.patch_size
        px = x.shape[-1] // self.patch_size
        patches = rearrange(
            x,
            "b c (py yy) (px xx) -> b (py px) (c yy xx)",
            py=py,
            yy=self.patch_size,
            px=px,
            xx=self.patch_size,
        )
        return patches


class ViTPatchLinear(nn.Linear):
    """Linear projection from a flattened 3-channel patch to ``embed_dim``.

    The input feature size is fixed to ``3 * patch_size**2``. Extra keyword
    arguments (``factory``) are passed through to ``nn.Linear``.
    """

    def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
        super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory)
        self.patch_size = patch_size


class RadioInternVisionModel(nn.Module):
    """RADIO vision tower: a ``ViTPatchGenerator`` feeding an InternViT encoder.

    Args:
        config: Model configuration. The fields read here are ``patch_size``,
            ``image_size``, ``cpe_max_size``, ``teachers``,
            ``cls_token_per_teacher``, ``register_multiple`` and
            ``hidden_size``. Although the default is None, a config is
            required: its attributes are accessed unconditionally.
        quant_config: Forwarded to ``InternVisionEncoder``.
        num_hidden_layers_override: Forwarded to ``InternVisionEncoder``.
        num_dummy_heads: Forwarded to ``InternVisionEncoder``.
        prefix: Parameter-name prefix; the encoder gets ``f"{prefix}.encoder"``.
    """

    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig = None,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.img_size, self.grid_size, self.num_patches = self._init_img_size(
            to_2tuple(config.patch_size), config.image_size
        )
        # Largest supported side, rounded to the nearest patch multiple.
        max_img_size = int(
            round(config.cpe_max_size / config.patch_size) * config.patch_size
        )
        unique_teachers = {t["name"] for t in config.teachers}
        self.patch_generator = ViTPatchGenerator(
            config.patch_size,
            config.hidden_size,
            input_dims=self.img_size,
            max_input_dims=max_img_size,
            cls_token=True,
            # One class token per distinct teacher name, or a single one.
            num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
            register_multiple=config.register_multiple,
        )

        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def _init_img_size(self, patch_size, img_size: int | tuple[int, int]):
        """Derive ``(img_size, grid_size, num_patches)`` from an image size.

        ``patch_size`` must be a 2-sequence. Returns three Nones when
        ``img_size`` is None; otherwise ``img_size`` is normalised to a
        2-tuple and the grid is its per-axis floor division by ``patch_size``.
        """
        if img_size is None:
            return None, None, None
        img_size = to_2tuple(img_size)
        grid_size = tuple(s // p for s, p in zip(img_size, patch_size))
        num_patches = grid_size[0] * grid_size[1]
        return img_size, grid_size, num_patches

    def get_input_embeddings(self):
        """Return the module that produces the encoder's input embeddings.

        This class never defines ``self.embeddings`` (the previous return
        value), so the patch generator is returned instead.
        """
        return self.patch_generator

    def forward(self, x: torch.Tensor) -> torch.FloatTensor:
        """Embed ``x`` with the patch generator and run the encoder on it."""
        assert self.patch_generator is not None
        hidden_states = self.patch_generator(x)
        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
        return encoder_outputs


class RadioModel(nn.Module):
    """Top-level RADIO wrapper returning ``(summary, features)``.

    Wraps a ``RadioInternVisionModel`` and splits its output token sequence
    into the leading class ("summary") tokens and the remaining patch
    features.

    Args:
        config: Model configuration; ``config.teachers`` is read here (a
            sequence of dicts with an optional ``"use_summary"`` key), the
            rest is consumed by ``RadioInternVisionModel``.
        quant_config: Forwarded to ``RadioInternVisionModel``.
        num_hidden_layers_override: Forwarded to ``RadioInternVisionModel``.
        num_dummy_heads: Forwarded to ``RadioInternVisionModel``.
        prefix: Forwarded to ``RadioInternVisionModel``.
    """

    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.model = RadioInternVisionModel(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=prefix,
        )

        # Positions (in teacher order) of the teachers whose summary token
        # is kept. None means "keep every class token".
        summary_idxs = None
        if config.teachers:
            summary_idxs = torch.tensor(
                [i for i, t in enumerate(config.teachers) if t.get("use_summary", True)],
                # An empty list would otherwise yield a float32 tensor,
                # which cannot be used as an index.
                dtype=torch.long,
            )
            if summary_idxs.numel() > 0:
                self.register_buffer("summary_idxs", summary_idxs)
        self.summary_idxs = summary_idxs

    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        pixel_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run the vision tower on ``pixel_values``.

        ``pixel_embeds`` is accepted but not used. Returns the
        ``(summary, features)`` pair produced by ``_extract_final``.
        """
        y = self.model(pixel_values)
        return self._extract_final(y)

    def load_weights(self, weights) -> set[str]:
        """Load checkpoint tensors whose names start with ``radio_model.``.

        Args:
            weights: A dict or an iterable of ``(name, tensor)`` pairs. It is
                consumed lazily, one entry at a time.

        Returns:
            The set of parameter names (in this module's naming) that were
            loaded. Entries that map to no known parameter are ignored.
        """
        loaded_params: set[str] = set()
        params_dict = dict(self.named_parameters())

        named_weights = weights.items() if isinstance(weights, dict) else weights
        radio_prefix = "radio_model."

        for name, weight in named_weights:
            if not name.startswith(radio_prefix):
                # Skip non-radio weights
                continue

            sub = name[len(radio_prefix) :]  # drop "radio_model." prefix

            # Skip buffers not used in vLLM
            if sub == "summary_idxs":
                continue
            if sub.startswith("input_conditioner."):
                # we normalize in the input processor,
                # based on norm and std values from the config
                continue

            vllm_key = None
            if sub.startswith("model.patch_generator."):
                # Patch-generator names are identical on both sides.
                vllm_key = sub
            elif sub.startswith("model.blocks."):
                # Encoder blocks: HF 'model.blocks.{i}.' ->
                # vLLM 'model.encoder.layers.{i}.'
                parts = sub.split(".")
                if len(parts) >= 4:
                    layer_idx = parts[2]
                    suffix = ".".join(parts[3:])
                    # Skip layer-scale entries that vLLM doesn't use
                    if suffix in {"ls1", "ls2"} or suffix.startswith(("ls1.", "ls2.")):
                        continue
                    vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}"

            if vllm_key and vllm_key in params_dict:
                param = params_dict[vllm_key]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(vllm_key)

        return loaded_params

    def _extract_final(
        self, y: torch.Tensor
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Split encoder output ``y`` into summary and patch features.

        The first ``num_cls_tokens`` tokens along dim 1 are the summary
        candidates (optionally sub-selected by ``summary_idxs``) and are
        flattened from dim 1 on; everything after the first ``num_skip``
        tokens (class tokens + registers) is returned as features.

        Raises:
            RuntimeError: If the wrapped model has no ``patch_generator``.
        """
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is None:
            raise RuntimeError(
                "RadioModel needs `model.patch_generator` to separate summary "
                "tokens from patch features"
            )

        all_summary = y[:, : patch_gen.num_cls_tokens]
        if self.summary_idxs is not None:
            bb_summary = all_summary[:, self.summary_idxs]
        else:
            bb_summary = all_summary
        # Remove CLS + REGISTERS tokens
        all_feat = y[:, patch_gen.num_skip :]

        return bb_summary.flatten(1), all_feat