radio.py 27.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2023-2024, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import math
from collections.abc import Iterable
13
14
from dataclasses import dataclass
from itertools import accumulate, repeat
15
from typing import TypeAlias
16
17
18
19
20
21
22
23
24

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
25
26
27
28
29
from vllm.model_executor.models.intern_vit import (
    InternParallelAttention,
    InternVisionEncoder,
    InternVisionEncoderLayer,
)
30

31
32
input_dim_t: TypeAlias = int | tuple[int, int]
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50


def _ntuple(n):
    """Return a converter that normalizes a value to a tuple.

    Non-string iterables are passed through ``tuple()`` as-is (their length is
    NOT checked against ``n``); any other value, strings included, is repeated
    ``n`` times.
    """

    def parse(x):
        passthrough = isinstance(x, Iterable) and not isinstance(x, str)
        return tuple(x) if passthrough else tuple(repeat(x, n))

    return parse


to_1tuple, to_2tuple, to_3tuple, to_4tuple = map(_ntuple, range(1, 5))
to_ntuple = _ntuple


def calc_seq_len(size: tuple[int, int], patch_size: int) -> int:
    """Count the whole ``patch_size`` x ``patch_size`` patches in an (h, w) size.

    Partial patches along the bottom/right edges are dropped (floor division).
    """
    height, width = size
    rows, cols = height // patch_size, width // patch_size
    return rows * cols


def calc_seq_lens(sizes: list[tuple[int, int]], patch_size: int) -> list[int]:
    """Return the per-image patch count for each (h, w) size, in input order."""
    return [(height // patch_size) * (width // patch_size) for height, width in sizes]


60
61
62
63
64
65
class ClsToken(nn.Module):
    def __init__(
        self,
        ndim: int,
        num_tokens: int = 1,
        enabled: bool = True,
66
67
        register_multiple: int | None = None,
        num_registers: int | None = None,
68
69
70
71
72
73
74
75
76
77
78
    ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if num_registers:
                self.num_registers = num_registers
            elif register_multiple:
79
80
81
                self.num_registers = register_multiple - (
                    num_tokens % register_multiple
                )
82
83
84

            scale = ndim**-0.5
            self.token = nn.Parameter(
85
86
                torch.randn(num_tokens + self.num_registers, ndim) * scale
            )
87
88
89
90
91
92
93
94
95
96
97

        else:
            self.token = None

        self.num_patches = self.num_tokens + self.num_registers

    def forward(self, x: torch.Tensor):
        if self.token is None:
            return x

        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
98
99
100
101
102
103
104
        x = torch.cat(
            [
                token,
                x,
            ],
            dim=1,
        )
105
106
107
108
109
110
111
112
113
114
115
116
117
118

        return x


class ViTPatchGenerator(nn.Module):
    """Convert images (or video tubelets) into token sequences for the ViT.

    Pipeline: patchify -> linear patch embedding -> optional learned absolute
    position embedding -> class/register tokens (``ClsToken``) -> optional
    LayerNorm.

    The learned position grid spans ``max_input_dims``, rounded up to a
    multiple of ``patch_size``. When that differs from ``input_dims`` the
    generator is in ``cpe_mode``, and the grid is resampled for each requested
    input size (see ``_get_pos_embeddings``).
    """

    def __init__(
        self,
        patch_size: int,
        embed_dim: int,
        input_dims: input_dim_t,
        abs_pos: bool = True,
        normalize_patches: bool = False,
        cls_token: bool = False,
        max_input_dims: input_dim_t | None = None,
        pos_dropout: float = 0.0,
        return_pos_enc: bool = False,
        num_cls_tokens: int = 1,
        register_multiple: int | None = None,
        num_registers: int | None = None,
        patch_bias: bool = False,
        temporal_patch_size: int = 1,
        separate_video_embedder: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        # Normalize both size arguments to tuples so that the cpe_mode
        # comparison below is tuple-vs-tuple (a list never equals a tuple).
        if isinstance(input_dims, int):
            input_dims = (input_dims, input_dims)
        else:
            input_dims = tuple(input_dims)

        if max_input_dims is None:
            max_input_dims = input_dims
        if isinstance(max_input_dims, int):
            max_input_dims = (max_input_dims, max_input_dims)

        # Round the maximum size up to a whole number of patches.
        max_input_dims = tuple(
            int(math.ceil(d / patch_size) * patch_size) for d in max_input_dims
        )

        self.cpe_mode = max_input_dims != input_dims
        self.pos_dropout = pos_dropout
        self.return_pos_enc = return_pos_enc

        factory = dict(device=device, dtype=dtype)

        self.patch_size = patch_size
        self.abs_pos = abs_pos
        self.embed_dim = embed_dim
        self.temporal_patch_size = temporal_patch_size

        # Grid sizes below are measured in patches, not pixels.
        self.num_rows = max_input_dims[0] // patch_size
        self.num_cols = max_input_dims[1] // patch_size
        self.input_dims = tuple(d // patch_size for d in input_dims)
        self.num_patches = self.num_rows * self.num_cols
        self.max_input_dims = max_input_dims

        self.im_to_patches = Im2Patches(patch_size)
        self.embedder = ViTPatchLinear(
            patch_size, embed_dim, bias=patch_bias, **factory
        )

        if temporal_patch_size > 1:
            if not separate_video_embedder:
                raise NotImplementedError(
                    "Only separate_video_embedder=True is supported for"
                    " temporal compression (temporal_patch_size > 1)"
                )
            # Tubelet embedder: its input width grows by temporal_patch_size.
            self.video_embedder = ViTPatchLinear(
                patch_size,
                embed_dim,
                bias=patch_bias,
                temporal_patch_size=temporal_patch_size,
                **factory,
            )

        if abs_pos:
            scale = embed_dim**-0.5
            self.pos_embed = nn.Parameter(
                torch.randn(1, self.num_patches, embed_dim, **factory) * scale
            )

        self.cls_token = ClsToken(
            embed_dim,
            num_tokens=num_cls_tokens,
            enabled=cls_token,
            register_multiple=register_multiple,
            num_registers=num_registers,
        )

        self.patch_normalizer = (
            nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
        )

    def forward(
        self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Embed ``x``, then add position encodings and class/register tokens.

        Args:
            x: Image batch for the fixed-size path. When ``imgs_sizes`` is
                given, ``x`` is fed straight to the patch embedder, so it is
                presumably already patchified and packed (NOTE(review): confirm
                with the caller).
            imgs_sizes: Per-image (h, w) pixel sizes for packed
                dynamic-resolution input.

        Returns:
            The token tensor, or ``(tokens, pos_enc)`` when ``return_pos_enc``
            is set (``pos_enc`` is None without absolute position embeddings).
        """
        if imgs_sizes is not None:
            patches = self.embedder(x)
            patches, pos_enc = self.apply_pos_enc_dynamic(
                patches, imgs_sizes=imgs_sizes
            )
            patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
        else:
            patches = self.embed_patches(x)
            patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
            patches = self.cls_token(patches)

        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    def forward_video(self, x: torch.Tensor) -> torch.Tensor:
        """Process video frames with temporal compression.

        Groups T consecutive frames into tubelets before embedding.

        Args:
            x: [num_frames, 3, H, W] tensor of video frames

        Returns:
            Embedded patches with temporal compression applied.
        """
        assert self.temporal_patch_size > 1
        T = self.temporal_patch_size
        input_size = x.shape[2:]

        patches = self.im_to_patches(x)  # [N, num_patches, 3*P*P]
        num_frames, num_spatial, feat_dim = patches.shape

        # Pad to a multiple of T by repeating the last frame so that
        # all tubelets have exactly T frames.
        num_pad_frames = (-num_frames) % T
        if num_pad_frames > 0:
            last_frame_dup = patches[-1:].expand(num_pad_frames, -1, -1)
            patches = torch.cat([patches, last_frame_dup], dim=0)

        # Group T frames per tubelet: for each spatial position, concatenate
        # features across T consecutive frames; order follows Megatron training.
        num_tubelets = patches.shape[0] // T
        patches = rearrange(
            patches,
            "(tubelets frames) spatial feat -> tubelets spatial (frames feat)",
            tubelets=num_tubelets,
            frames=T,
            spatial=num_spatial,
            feat=feat_dim,
        )

        patches = self.video_embedder(patches)

        patches, pos_enc = self.apply_pos_enc(patches, input_size=input_size)

        patches = self.cls_token(patches)

        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches

    def apply_pos_enc_dynamic(
        self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Add a size-specific position encoding to each image in a packed sequence.

        Images occupy consecutive ``calc_seq_len(size, patch_size)`` token
        spans along dim 1, in ``imgs_sizes`` order. Any tokens past the last
        image are kept unchanged.

        Returns:
            ``(patches_with_pos, pos_enc)``, where ``pos_enc`` is the per-image
            encodings concatenated along dim 1, or ``(patches, None)`` when
            absolute position embeddings are disabled or no sizes are given.
        """
        if not self.abs_pos or not imgs_sizes:
            return patches, None

        pieces: list[torch.Tensor] = []
        pos_enc_list: list[torch.Tensor] = []
        current_length = 0

        for size in imgs_sizes:
            seq_length = calc_seq_len(size, self.patch_size)
            img_patches = patches[:, current_length : current_length + seq_length, :]
            pos_enc = self.get_pos_enc(patches.shape[0], input_size=size)
            pieces.append(img_patches + pos_enc)
            pos_enc_list.append(pos_enc)
            current_length += seq_length

        # Keep any trailing tokens not covered by imgs_sizes. The sequence is
        # concatenated once here instead of being rebuilt for every image.
        pieces.append(patches[:, current_length:, :])
        patches = torch.cat(pieces, dim=1)

        full_pos_enc = torch.cat(pos_enc_list, dim=1)
        return patches, full_pos_enc

    def cls_token_dynamic(
        self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
    ) -> torch.Tensor:
        """Prepend the class/register tokens to each image's span of a packed sequence.

        Only the spans described by ``imgs_sizes`` are emitted; tokens past the
        last image are not carried over.
        """
        if not self.cls_token.enabled:
            return patches

        out = []
        current_length = 0

        for seq_len in calc_seq_lens(imgs_sizes, self.patch_size):
            class_token = self.cls_token.token.unsqueeze(0).expand(
                patches.shape[0], -1, -1
            )
            out.append(class_token)
            out.append(patches[:, current_length : current_length + seq_len, :])
            current_length += seq_len

        return torch.cat(out, dim=1)

    @property
    def apply_cls_token(self):
        return self.cls_token.enabled

    @property
    def num_cls_tokens(self):
        return self.cls_token.num_tokens

    @property
    def num_cls_patches(self):
        return self.cls_token.num_patches

    @property
    def num_registers(self):
        return self.cls_token.num_registers

    @property
    def num_skip(self):
        # Number of prefix tokens (class + register) in front of each image.
        return self.num_cls_tokens + self.num_registers

    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
        """Copy ``src_embed`` into ``targ_embed``.

        When the shapes differ, a square source grid is bicubically resized to
        ``(num_rows, num_cols)`` first.
        """
        if src_embed.shape != targ_embed.shape:
            src_size = int(math.sqrt(src_embed.shape[1]))

            assert src_size**2 == src_embed.shape[1], (
                "Unable to interpolate non-square embedding"
            )

            src_embed = rearrange(
                src_embed, "b (h w) c -> b c h w", h=src_size, w=src_size
            )
            src_embed = F.interpolate(
                src_embed,
                size=(self.num_rows, self.num_cols),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_embed = rearrange(src_embed, "b c h w -> b (h w) c")
        targ_embed.data.copy_(src_embed)

    def _load_projection(
        self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor
    ):
        """Copy patch-projection weights into ``targ_proj_weight``.

        When the shapes differ, square 3-channel source kernels are bicubically
        resized to ``patch_size`` first.
        """
        if src_proj_weight.shape != targ_proj_weight.shape:
            src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))

            assert (src_patch_size**2) * 3 == src_proj_weight.shape[1], (
                "Unable to interpolate non-square patch size"
            )

            src_proj_weight = rearrange(
                src_proj_weight,
                "b (c h w) -> b c h w",
                c=3,
                h=src_patch_size,
                w=src_patch_size,
            )
            src_proj_weight = F.interpolate(
                src_proj_weight,
                size=(self.patch_size, self.patch_size),
                mode="bicubic",
                align_corners=True,
                antialias=False,
            )
            src_proj_weight = rearrange(src_proj_weight, "b c h w -> b (c h w)")
        targ_proj_weight.data.copy_(src_proj_weight)

    def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify ``x`` and project each flattened patch to ``embed_dim``."""
        patches = self.im_to_patches(x)
        patches = self.embedder(patches)
        return patches

    def apply_pos_enc(
        self,
        patches: torch.Tensor,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Add the absolute position encoding to ``patches``.

        Returns:
            ``(patches_with_pos, pos_enc)``, or ``(patches, None)`` when absolute
            position embeddings are disabled. The result is always a 2-tuple,
            so callers can unpack it unconditionally.
        """
        if not self.abs_pos:
            return patches, None

        pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)

        if self.training and self.pos_dropout > 0:
            # Per-sample dropout of the whole position encoding.
            keeps = (
                torch.rand(
                    patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device
                )
                > self.pos_dropout
            )
            pos_enc_drop = torch.where(keeps, pos_enc, 0)
        else:
            pos_enc_drop = pos_enc

        return patches + pos_enc_drop, pos_enc

    def get_pos_enc(
        self,
        batch_size: int,
        patch_idxs: torch.Tensor | None = None,
        input_size: tuple[int, int] | None = None,
    ) -> torch.Tensor:
        """Return position encodings for ``input_size`` (pixels).

        ``input_size`` defaults to the configured input size. If ``patch_idxs``
        is given, only those token positions are gathered.
        """
        if input_size is None:
            input_dims = self.input_dims
        else:
            input_dims = tuple(d // self.patch_size for d in input_size)

        pos_embed = self._get_pos_embeddings(batch_size, input_dims)

        if patch_idxs is None:
            return pos_embed

        exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])

        pos_embed = torch.gather(
            pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs
        )
        return pos_embed

    def _get_pos_embeddings(self, batch_size: int, input_dims: tuple[int, int]):
        """Resample the learned position grid to ``input_dims`` (in patches).

        Returns a ``(1, rows * cols, embed_dim)`` tensor. ``batch_size`` is
        currently unused.
        """
        if (self.num_rows, self.num_cols) == input_dims:
            return self.pos_embed

        pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(
            0, 3, 1, 2
        )

        def window_select(pos_embed):
            # Crop the top-left window when the grid is larger than needed.
            if input_dims[0] < pos_embed.shape[-2]:
                pos_embed = pos_embed[..., : input_dims[0], :]
            if input_dims[1] < pos_embed.shape[-1]:
                pos_embed = pos_embed[..., :, : input_dims[1]]
            return pos_embed

        if self.cpe_mode:
            # Resize the full grid to a square covering the larger side, then crop.
            max_dim = max(input_dims)
            pos_embed = F.interpolate(
                pos_embed.float(),
                size=(max_dim, max_dim),
                align_corners=False,
                mode="bilinear",
            ).to(pos_embed.dtype)

            pos_embed = window_select(pos_embed)
        else:
            pos_embed = window_select(pos_embed)

        if pos_embed.shape[-2:] != input_dims:
            # Grid still smaller than requested: stretch it bilinearly.
            pos_embed = F.interpolate(
                pos_embed.float(), size=input_dims, align_corners=False, mode="bilinear"
            ).to(pos_embed.dtype)

        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)

        return pos_embed


class Im2Patches(nn.Module):
    """Split an image batch into flattened, non-overlapping square patches.

    For ``patch_size > 1`` the input is read as ``(b, c, h, w)`` (per the
    rearrange pattern). The output is ``(b, rows * cols, c * p * p)`` with
    rows/cols = h // p, w // p. For ``patch_size == 1`` every pixel is a patch
    and channels simply move last.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        size = self.patch_size
        if size != 1:
            rows = x.shape[-2] // size
            cols = x.shape[-1] // size
            return rearrange(
                x,
                "b c (py yy) (px xx) -> b (py px) (c yy xx)",
                py=rows,
                yy=size,
                px=cols,
                xx=size,
            )
        return x.flatten(2).permute(0, 2, 1)


class ViTPatchLinear(nn.Linear):
    """Linear patch embedding for flattened 3-channel patches or tubelets.

    Input width is ``3 * temporal_patch_size * patch_size**2`` and output
    width is ``embed_dim``. Extra keyword arguments (e.g. device/dtype) are
    forwarded to ``nn.Linear``.
    """

    def __init__(
        self,
        patch_size: int,
        embed_dim: int,
        bias: bool = False,
        temporal_patch_size: int = 1,
        **factory,
    ):
        in_features = 3 * temporal_patch_size * (patch_size**2)
        super().__init__(in_features, embed_dim, bias=bias, **factory)
        # Kept so weight loaders/callers can recover the patch geometry.
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
511


512
513
514
515
516
517
@dataclass(frozen=True, kw_only=True)
class MaskMetadata:
    cu_seqlens: torch.Tensor
    max_seqlen: torch.Tensor


class RadioParallelAttention(InternParallelAttention):
    """InternParallelAttention whose forward can take packed-sequence metadata.

    When ``mask_meta`` is given, its ``cu_seqlens`` and ``max_seqlen`` are
    passed to ``self.attn``. Presumably this confines attention to each packed
    sequence; confirm against the attention implementation.
    """

    def forward(
        self, x: torch.Tensor, mask_meta: MaskMetadata | None = None
    ) -> torch.Tensor:
        fused, _ = self.qkv(x)
        query, key, value = fused.chunk(3, dim=-1)

        if self.qk_normalization:
            query, key = self._apply_qk_norm(query, key)

        if mask_meta is None:
            cu_seqlens = max_seqlen = None
        else:
            cu_seqlens, max_seqlen = mask_meta.cu_seqlens, mask_meta.max_seqlen

        attended = self.attn(
            query, key, value, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
        )
        projected, _ = self.proj(attended)
        return projected


class RadioVisionEncoderLayer(InternVisionEncoderLayer):
    """Intern encoder layer using RadioParallelAttention and threading mask metadata.

    Each sub-block normalizes its input, scales the result by ``ls1``/``ls2``
    and adds it back to the residual stream.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, attn_cls=RadioParallelAttention, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask_meta: MaskMetadata | None = None,
    ):
        attn_out = self.attn(self.norm1(hidden_states), mask_meta=mask_meta)
        residual = hidden_states + attn_out * self.ls1

        mlp_out = self.mlp(self.norm2(residual))
        return residual + mlp_out * self.ls2


class RadioVisionEncoder(InternVisionEncoder):
    """Stack of RadioVisionEncoderLayer that passes mask metadata to every layer."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, layer_cls=RadioVisionEncoderLayer, **kwargs)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        mask_meta: MaskMetadata | None = None,
    ):
        out = inputs_embeds
        for layer in self.layers:
            out = layer(out, mask_meta=mask_meta)
        return out


class RadioInternVisionModel(nn.Module):
    """RADIO vision trunk: a ViTPatchGenerator followed by a RadioVisionEncoder.

    Handles three input modes:
      * fixed-size image batches,
      * packed dynamic-resolution images (``imgs_sizes``),
      * video frames grouped into tubelets, when ``num_frames`` is given and
        ``config.video_temporal_patch_size > 1``.
    """

    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig = None,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.img_size, self.grid_size, self.num_patches = self._init_img_size(
            to_2tuple(config.patch_size), config.image_size
        )
        # Max size for position-embedding resampling, snapped to the patch grid.
        max_img_size = int(
            round(config.cpe_max_size / config.patch_size) * config.patch_size
        )
        self.temporal_patch_size = config.video_temporal_patch_size
        unique_teachers = {t["name"] for t in config.teachers}
        self.patch_generator = ViTPatchGenerator(
            config.patch_size,
            config.hidden_size,
            input_dims=self.img_size,
            max_input_dims=max_img_size,
            cls_token=True,
            # One class token per distinct teacher when configured that way.
            num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
            register_multiple=config.register_multiple,
            temporal_patch_size=self.temporal_patch_size,
            separate_video_embedder=config.separate_video_embedder,
        )

        self.encoder = RadioVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def _init_img_size(self, patch_size, img_size: int | tuple[int, int]):
        """Return ``(img_size, grid_size, num_patches)``, or all None without a size."""
        if img_size is None:
            return None, None, None
        img_size = to_2tuple(img_size)
        grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, grid_size, num_patches

    def get_input_embeddings(self):
        """Return the module that turns pixels into input tokens.

        This model has no ``embeddings`` attribute; the patch generator fills
        that role, so it is returned here instead of raising AttributeError.
        """
        return self.patch_generator

    def inter_image_mask_metadata(
        self, imgs_sizes: list[tuple[int, int]], device: torch.device
    ) -> MaskMetadata:
        """Build mask metadata from image pixel sizes. Adds num_skip to each
        sequence length (cls/register tokens) to match patch generator output."""
        patch_size = self.patch_generator.patch_size
        num_skip = self.patch_generator.num_skip

        seq_lens = calc_seq_lens(imgs_sizes, patch_size)
        adjusted = [s + num_skip for s in seq_lens]
        return self._inter_image_mask_metadata_from_seq_lens(adjusted, device=device)

    def _inter_image_mask_metadata_from_seq_lens(
        self, seq_lens: list[int], device: torch.device
    ) -> MaskMetadata:
        """Build mask metadata from actual sequence lengths (already including
        cls/register tokens, i.e. patch_count + num_skip per item).
        Use inter_image_mask_metadata() when you only have imgs_sizes.

        Raises:
            ValueError: if ``seq_lens`` is empty.
        """
        if not seq_lens:
            raise ValueError("seq_lens must contain at least one sequence length")
        cu_seqlens = torch.tensor(
            list(accumulate(seq_lens, initial=0)), dtype=torch.int32, device=device
        )
        # Keep max_seqlen on CPU to avoid .item() sync
        # See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
        max_seqlen = torch.tensor(max(seq_lens), dtype=torch.int32)
        return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    def forward(
        self,
        x: torch.Tensor,
        imgs_sizes: list[tuple[int, int]] | None = None,
        num_frames: int | None = None,
    ) -> torch.FloatTensor:
        """Run patch generation and the encoder.

        Args:
            x: Pixel input (video frames when ``num_frames`` is given).
            imgs_sizes: Per-image (h, w) sizes for packed dynamic resolution.
            num_frames: Only checked for None; selects the tubelet path when
                the temporal patch size is > 1.

        Returns:
            Encoder outputs. On the tubelet path they are reshaped back to
            ``(num_tubelets, seq_per_tubelet, hidden)``.
        """
        T = self.temporal_patch_size

        # Build packed-sequence metadata for MMEncoderAttention when needed.
        mask_meta = None
        # (num_tubelets, seq_per_tubelet) when the video path packed the batch.
        packed_shape: tuple[int, int] | None = None

        if num_frames is not None and T > 1:
            # Conv3d video: all tubelets have the same sequence length.
            # Pack [num_tubelets, seq_per_tubelet, hidden] -> [1, total, hidden]
            hidden_states = self.patch_generator.forward_video(x)
            num_tubelets, seq_per_tubelet, hidden_dim = hidden_states.shape
            packed_shape = (num_tubelets, seq_per_tubelet)
            hidden_states = hidden_states.reshape(1, -1, hidden_dim)
            mask_meta = self._inter_image_mask_metadata_from_seq_lens(
                [seq_per_tubelet] * num_tubelets, device=hidden_states.device
            )
        else:
            # Images for any model, or video for non-conv3d model
            hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
            if imgs_sizes is not None and len(imgs_sizes) > 1:
                # Dynamic resolution w/ > 1 image, create attn mask
                mask_meta = self.inter_image_mask_metadata(
                    imgs_sizes, device=hidden_states.device
                )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta)

        # Unpack back to original batch shape if we packed for video
        if packed_shape is not None:
            encoder_outputs = encoder_outputs.reshape(*packed_shape, -1)

        return encoder_outputs


class RadioModel(nn.Module):
    """Top-level RADIO wrapper: runs the vision trunk and splits its output.

    ``forward`` returns ``(summary, features)``. See ``_extract_final`` for how
    the class tokens and patch tokens are separated.
    """

    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.model = RadioInternVisionModel(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=prefix,
        )

        # Indices of teachers whose summary (class) token is exposed.
        summary_idxs = None
        if config.teachers:
            use_summary = [
                i for i, t in enumerate(config.teachers) if t.get("use_summary", True)
            ]
            # Explicit long dtype: torch.tensor([]) would otherwise be float32,
            # which is not a valid index tensor in _extract_final.
            summary_idxs = torch.tensor(use_summary, dtype=torch.long)
            if summary_idxs.numel() > 0:
                self.register_buffer("summary_idxs", summary_idxs)
        self.summary_idxs = summary_idxs

    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        pixel_embeds: torch.Tensor | None = None,
        *,
        imgs_sizes: list[tuple[int, int]] | None = None,
        num_frames: int | None = None,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Encode ``pixel_values`` and return ``(summary, features)``.

        ``pixel_embeds`` is accepted for interface compatibility but is not
        used here.
        """
        y = self.model(
            pixel_values,
            imgs_sizes=imgs_sizes,
            num_frames=num_frames,
        )
        return self._extract_final(y, imgs_sizes=imgs_sizes)

    def load_weights(self, weights) -> set[str]:
        """Load ``radio_model.*`` checkpoint tensors into this module.

        Checkpoint ``model.blocks.{i}.*`` keys map to
        ``model.encoder.layers.{i}.*``, and patch-generator keys map 1:1. Keys
        without the ``radio_model.`` prefix, the ``summary_idxs`` buffer,
        ``input_conditioner.*`` and layer-scale entries are skipped, as are keys
        with no matching parameter.

        Args:
            weights: A mapping or an iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of vLLM parameter names that were loaded.
        """
        loaded_params: set[str] = set()
        params_dict = dict(self.named_parameters())

        items = weights.items() if isinstance(weights, dict) else weights

        for name, weight in items:
            if not name.startswith("radio_model."):
                # Skip non-radio weights
                continue

            sub = name.removeprefix("radio_model.")

            # Skip buffers not used in vLLM
            if sub == "summary_idxs":
                continue
            if sub.startswith("input_conditioner."):
                # we normalize in the input processor,
                # based on norm and std values from the config
                continue

            vllm_key = None
            if sub.startswith("model.patch_generator."):
                # Patch-generator parameters share the same path in vLLM.
                vllm_key = sub
            elif sub.startswith("model.blocks."):
                # Encoder blocks: HF 'model.blocks.{i}.' ->
                # vLLM 'model.encoder.layers.{i}.'
                parts = sub.split(".")
                if len(parts) >= 4:
                    layer_idx = parts[2]
                    suffix = ".".join(parts[3:])
                    # Layer-scale entries are not loaded, so the encoder
                    # layer's own ls1/ls2 keep their initial values.
                    # NOTE(review): the layers do multiply by ls1/ls2; confirm
                    # the checkpoint's layer scale is identity.
                    if suffix in {"ls1", "ls2"} or suffix.startswith(("ls1.", "ls2.")):
                        continue
                    vllm_key = f"model.encoder.layers.{layer_idx}.{suffix}"

            if vllm_key and vllm_key in params_dict:
                param = params_dict[vllm_key]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, weight)
                loaded_params.add(vllm_key)

        return loaded_params

    def _extract_final(
        self, y: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Split encoder output into summary tokens and patch features.

        Each image's span starts with ``num_skip`` prefix tokens (class +
        register). The first ``num_cls_tokens`` of them are summaries; register
        tokens are dropped.

        Returns:
            ``(summary, features)``: the selected summaries flattened from dim 1
            onward, and the patch tokens concatenated along dim 1.
        """
        # Remove CLS + REGISTERS tokens
        num_skip = self.model.patch_generator.num_skip
        patch_size = self.model.patch_generator.patch_size
        num_cls_tokens = self.model.patch_generator.num_cls_tokens
        if imgs_sizes is None:
            all_summary = y[:, :num_cls_tokens]
            all_feat = y[:, num_skip:]
        else:
            all_patches = []
            summaries = []
            current_pos = 0
            for num_patches in calc_seq_lens(imgs_sizes, patch_size):
                patches = y[
                    :, current_pos + num_skip : current_pos + num_skip + num_patches, :
                ]
                all_patches.append(patches)
                summary = y[:, current_pos : current_pos + num_cls_tokens, :]
                summaries.append(summary)
                current_pos += num_skip + num_patches
            all_summary = torch.cat(summaries, dim=1)
            all_feat = torch.cat(all_patches, dim=1)

        if self.summary_idxs is not None:
            # NOTE(review): with several packed images the summaries are
            # concatenated along dim 1, but summary_idxs holds teacher indices,
            # so only the first image's class tokens are selected. Confirm this
            # is intended.
            bb_summary = all_summary[:, self.summary_idxs]
        else:
            bb_summary = all_summary
        return bb_summary.flatten(1), all_feat