# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
# This file is meant to be used in kimi_vl.py only
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
#
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
#
# Licensing Information:
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
# - Other parts of the code are licensed under the MIT License.
#
# Apache License, Version 2.0:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License:
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from collections.abc import Sequence
from copy import deepcopy
from functools import cached_property

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import is_flash_attn_2_available

from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.moonvit import MoonViTConfig

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
elif current_platform.is_xpu():
    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None


def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: torch.Tensor | None = None,
    k_cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
    """Packed (variable-length), non-causal multi-head attention via flash attention 2.

    All sequences are packed along the first dimension and delimited by
    cumulative sequence lengths, as consumed by ``flash_attn_varlen_func``.

    Args:
        q: Query tensor of shape (tot_seqlens_q, num_heads, head_dim).
        k: Key tensor of shape (tot_seqlens_k, num_heads, head_dim).
        v: Value tensor of shape (tot_seqlens_k, num_heads, head_dim).
        q_cu_seqlens (torch.Tensor | None): cumulative sequence lengths of q.
            The first element should be 0 and the last element should be
            q.shape[0]. If None, q is treated as one single sequence.
        k_cu_seqlens (torch.Tensor | None): cumulative sequence lengths of k/v.
            The first element should be 0 and the last element should be
            k.shape[0]. If None, k/v are treated as one single sequence.

    Returns:
        output: shape (tot_seqlens_q, dim), where dim = num_heads * head_dim.

    Raises:
        RuntimeError: if no ``flash_attn_varlen_func`` could be imported.
    """
    if flash_attn_varlen_func is None:
        raise RuntimeError(
            "flash_attention_2 attention was requested, but no "
            "flash_attn_varlen_func implementation is available"
        )
    # Only the packed (tot_seqlens, num_heads, head_dim) layout is supported.
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    # The signature advertises None defaults: interpret None as "the whole
    # tensor is one sequence" instead of failing when subscripting None.
    # int32 matches the cu_seqlens dtype built by MoonVitEncoder.
    if q_cu_seqlens is None:
        q_cu_seqlens = torch.tensor(
            [0, q.shape[0]], dtype=torch.int32, device=q.device
        )
    if k_cu_seqlens is None:
        k_cu_seqlens = torch.tensor(
            [0, k.shape[0]], dtype=torch.int32, device=k.device
        )
    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
    assert k_cu_seqlens[-1] == k.shape[0] == v.shape[0], (
        "k_cu_seqlens must sum to k.shape[0]"
    )
    assert q.dtype in [
        torch.bfloat16,
        torch.float16,
    ], f"unsupported dtype {q.dtype} for multihead attn"

    # Longest individual sequence on each side (needed by the varlen kernel).
    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=q_cu_seqlens,
        cu_seqlens_k=k_cu_seqlens,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        causal=False,
    )
    # (tot_seqlens_q, num_heads, head_dim) -> (tot_seqlens_q, num_heads * head_dim)
    attn_out = attn_out.flatten(start_dim=-2)

    return attn_out


def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
127
128
    q_cu_seqlens: torch.Tensor | None = None,
    k_cu_seqlens: torch.Tensor | None = None,
129
130
131
132
) -> torch.Tensor:
    """SDPA attention.

    Args:
133
134
135
136
137
        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
138
            or (tot_seqlens, num_heads, head_dim) if packing.
139
140
        q_cu_seqlens: Optional cumulative sequence lengths of q.
        k_cu_seqlens: Optional cumulative sequence lengths of k.
141
142
    """
    seq_length = q.shape[0]
143
144
145
    attention_mask = torch.zeros(
        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
    )
146
147
148
    for i in range(1, len(q_cu_seqlens)):
        attention_mask[
            ...,
149
150
            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
151
152
153
154
        ] = True
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
155
    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    return attn_output


# Dispatch table from an attention-implementation name (the value stored in
# ``MoonVitEncoderLayer.attn_implementation``) to the packed attention kernel.
VL_VISION_ATTENTION_FUNCTIONS = dict(
    flash_attention_2=multihead_attention,
    sdpa=sdpa_attention,
)


def _apply_rope_input_validation(x, freqs_cis):
    """Check that ``x`` and ``freqs_cis`` are compatible for ``apply_rope``.

    ``x`` must have exactly one more dimension than ``freqs_cis`` (the heads
    axis). It must share ``freqs_cis``'s leading dimensions and have a last
    dimension twice as large (real/imaginary pairs). ``freqs_cis`` must be
    complex64.

    Raises:
        AssertionError: if any check fails; the message carries both shapes
            (or the offending dtype).
    """
    shapes = (x.shape, freqs_cis.shape)
    assert freqs_cis.ndim + 1 == x.ndim, shapes
    assert freqs_cis.shape[:-1] == x.shape[:-2], shapes
    assert 2 * freqs_cis.shape[-1] == x.shape[-1], shapes
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype


def apply_rope(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by precomputed 2D RoPE factors.

    Adjacent pairs along the last dimension are viewed as complex numbers and
    multiplied by ``freqs_cis``. The same rotation is broadcast over every
    head. The math runs in float32, and results are cast back to the input
    dtypes.

    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads_q, head_dim)
        xk: key, tensor of shape (..., num_heads_k, head_dim); num_heads_k may
            differ from num_heads_q.
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
    Returns:
        xq_out, xk_out: tensors with the same shapes and dtypes as xq and xk.
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2 (broadcast over heads)
    # ..., num_heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    # Each tensor must be reshaped with its own shape: reusing xq.shape here
    # breaks as soon as the key head count differs from the query's.
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))

    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Learnable2DInterpPosEmb(nn.Module):
    """Learnable 2D absolute position embedding, resized per image on demand.

    Holds a learnable (height, width, dim) table that is added to packed patch
    features. When an image's patch grid differs from the table size, the
    table is resampled with ``F.interpolate`` using ``interpolation_mode``.
    """

    def __init__(
        self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
    ) -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.interpolation_mode = interpolation_mode
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the table from a standard normal distribution."""
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
        """Add position embeddings to packed patch features.

        Args:
            x: packed features, expected shape (sum(h * w), dim), with images
                concatenated in the order of ``grid_hws``.
            grid_hws: (N, 2) tensor of per-image (h, w) patch-grid sizes.

        Returns:
            ``x`` plus the concatenated per-image position embeddings.
        """
        table_hw = tuple(self.weight.shape[:-1])
        # (1, dim, height, width) layout for F.interpolate; built at most once.
        table_chw = None
        pos_embs = []
        for h, w in grid_hws.tolist():
            # ``tolist()`` yields lists, which never compare equal to a
            # torch.Size; compare tuples so this fast path can trigger.
            if (h, w) == table_hw:
                pos_embs.append(self.weight.flatten(end_dim=1))
                continue
            if table_chw is None:
                table_chw = self.weight.permute((2, 0, 1)).unsqueeze(0)
            resized = F.interpolate(
                table_chw,
                size=(h, w),
                mode=self.interpolation_mode,
            )
            pos_embs.append(resized.squeeze(0).permute((1, 2, 0)).flatten(end_dim=1))
        out = x + torch.cat(pos_embs)
        return out


class MoonVisionPatchEmbed(nn.Module):
    """Embed image patches with a strided convolution plus 2D position embeddings."""

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: int | tuple[int, int] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
    ):
        super().__init__()
        assert isinstance(patch_size, (int, Sequence)), (
            f"Invalid patch_size type: {type(patch_size)}"
        )
        # A scalar patch size means a square (p, p) kernel.
        kernel = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        assert len(kernel) == 2, f"Expected patch_size to be a tuple of 2, got {kernel}"
        self.patch_size = kernel

        # kernel_size == stride, so patches do not overlap.
        self.proj = Conv2dLayer(in_dim, out_dim, kernel_size=kernel, stride=kernel)
        self.pos_emb = Learnable2DInterpPosEmb(
            height=pos_emb_height, width=pos_emb_width, dim=out_dim
        )

    def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
        """Project patches and add position embeddings.

        Args:
            x: one entry per patch along dim 0. Presumably (L, in_dim,
                patch_h, patch_w) pixel patches, since ``proj`` is a 2D
                convolution whose output is flattened to (L, -1). TODO confirm
                against the caller.
            grid_hw (N, 2): grid height and width of each image.

        Returns:
            (L, Cout) tensor
        """
        num_patches = x.size(0)
        patch_feats = self.proj(x).view(num_patches, -1)
        return self.pos_emb(patch_feats, grid_hw)


class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    This class is intended to be used in the following way:
    1. Before training/inference, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
    2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
    3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply_rope` just before each attention operation.
        The rope is shared across all attention layers and all heads.

    Rotary channels are interleaved pairwise: even complex channels encode the
    column (width) position and odd complex channels the row (height) position.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the per-head attention dimension; must be divisible by 4 (TODO: relax this constraint if needed)
        max_height (int): the maximum height of the 2D grid
        max_width (int): the maximum width of the 2D grid
        theta_base (float): the base of the theta
        device (str): the device to store the precomputed cis
    """

    def __init__(
        self,
        dim: int,
        max_height: int,
        max_width: int,
        theta_base=10000,
        device=current_platform.device_type,
    ):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.device = device

    def extra_repr(self):
        return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

    @cached_property
    def precomputed_freqs_cis(self) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid.

        Computed lazily on first access and cached on the instance.

        Return: complex tensor of shape (max_height, max_width, dim//2) and value:
            width axis:  ret[h, w, 2*i]   = cis(w * theta_base**(-4*i/dim))
            height axis: ret[h, w, 2*i+1] = cis(h * theta_base**(-4*i/dim))   with (i in [0, dim//4))
            note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
        """
        N = self.max_height * self.max_width
        flat_pos = torch.arange(0, N).float().to(self.device)
        x_pos = flat_pos % self.max_width  # column (w) index of each flat position
        y_pos = flat_pos // self.max_width  # row (h) index of each flat position
        dim_range = (
            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device)
        )  # C/4
        freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # N, C/4, 2 -- width factor first, then height factor, per frequency
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
        )
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
        """Gather RoPE factors for a batch of packed images.

        Args:
            grid_hws (torch.Tensor): (N, 2) tensor of per-image (height, width)
                grid sizes; each must satisfy 1 <= h <= max_height and
                1 <= w <= max_width.
        Returns:
            freqs_cis: tensor of shape (sum(height * width), dim//2), images
                concatenated in order, each image's positions row-major.
        """
        shapes = grid_hws.tolist()
        assert all(
            1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
        ), (
            shapes,
            self.max_height,
            self.max_width,
        )
        freqs_cis = torch.cat(
            [
                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
                for h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis

    def get_freqs_cis_by_idx(
        self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
            pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask will be ones.
        Return:
            freqs_cis: complex64 tensor of shape (..., dim//2)
        """
        assert (
            pos_idx.shape[:-1] == pos_idx_mask.shape
            and pos_idx.shape[-1] == 2
            and pos_idx.ndim == pos_idx_mask.ndim + 1
        ), (pos_idx.shape, pos_idx_mask.shape)
        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype

        shp = pos_idx_mask.shape + (self.dim // 2,)  # ..., head_dim/2
        # Ones (the identity rotation) for masked-out tokens.
        freqs_cis = torch.ones(
            shp, dtype=torch.complex64, device=self.device
        )  # ..., head_dim/2
        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
            pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
        ]
        return freqs_cis


class MLP2(nn.Module):
    """Two-layer MLP computing ``fc1(activation(fc0(x)))``.

    Args:
        dims: [in_dim, hidden_dim, out_dim]
        activation: callable applied between the two linear layers.
        bias: whether to use bias in linear layer.
        prefix: weight-name prefix, combined with "fc0"/"fc1" through
            ``maybe_prefix``.
        use_data_parallel: stored on the module; not otherwise read here.
    """

    def __init__(
        self,
        dims: list[int],
        activation,
        bias: bool = True,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        assert len(dims) == 3
        in_dim, hidden_dim, out_dim = dims
        self.use_data_parallel = use_data_parallel
        self.fc0 = ReplicatedLinear(
            in_dim, hidden_dim, bias=bias, prefix=maybe_prefix(prefix, "fc0")
        )
        self.fc1 = ReplicatedLinear(
            hidden_dim, out_dim, bias=bias, prefix=maybe_prefix(prefix, "fc1")
        )
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The linear layers return a pair; only the first element is used.
        hidden, _ = self.fc0(x)
        out, _ = self.fc1(self.activation(hidden))
        return out


class MoonVitEncoderLayer(nn.Module):
    """Pre-norm transformer block over packed MoonViT patch sequences.

    Computes ``x = x + Attn(norm0(x))`` followed by ``x = x + MLP(norm1(x))``.
    Attention runs over a packed (L, hidden_dim) tensor whose images are
    delimited by ``cu_seqlens``, with optional 2D RoPE on queries and keys.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        prefix: str = "",
        use_data_parallel: bool = False,
        *,
        attn_implementation: str = "sdpa",
        activation=F.gelu,
        attn_bias: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        self.attn_implementation = attn_implementation
        # use fa2 in vllm by default: overrides the requested implementation
        # whenever a flash-attention varlen kernel is importable.
        if is_flash_attn_2_available() or current_platform.is_xpu():
            self.attn_implementation = "flash_attention_2"

        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.use_data_parallel = use_data_parallel
        self.mlp = MLP2(
            [hidden_dim, mlp_dim, hidden_dim],
            activation,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )
        # Fused Q/K/V projection followed by the output projection.
        self.wqkv = ReplicatedLinear(
            hidden_dim, hidden_dim * 3, bias=attn_bias, prefix=f"{prefix}.wqkv"
        )
        self.wo = ReplicatedLinear(
            hidden_dim, hidden_dim, bias=attn_bias, prefix=f"{prefix}.wo"
        )

    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        """Self-attention over packed sequences with a fused QKV projection.

        Args:
            x (torch.Tensor): packed hidden states of shape (L, hidden_dim).
                The attention kernels require 3-D q/k/v, i.e. (L, num_heads,
                head_dim) after the split below.
            cu_seqlens (torch.Tensor): cumulative sequence lengths (leading 0)
                delimiting each image inside ``x``.
            rope_freqs_cis (torch.Tensor | None): complex RoPE factors of shape
                (L, head_dim // 2), or None to skip rotary embedding.

        Returns:
            (L, hidden_dim) attention output after the ``wo`` projection.
        """
        xqkv, _ = self.wqkv(x)

        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_heads,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (L, 3, num_heads, head_dim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)

        # The signature allows None; apply_rope cannot handle it, so skip.
        if rope_freqs_cis is not None:
            xq, xk = apply_rope(xq, xk, rope_freqs_cis)

        attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
        attn_out = attn_func(
            xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
        )
        attn_out, _ = self.wo(attn_out)
        return attn_out

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: packed (L, D) hidden states of all images.
            cu_seqlens: cumulative sequence lengths delimiting the images.
            rope_freqs_cis: optional (L, head_dim // 2) RoPE factors.

        Returns:
            output: packed (L, D) tensor, same shape as the input.
        """
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)
        attn_out = self.attention_qkvpacked(
            hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
        )
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.mlp(self.norm1(hidden_states))
        hidden_states = residual + hidden_states
        return hidden_states


class MoonVitEncoder(nn.Module):
    """Stack of MoonVitEncoderLayer blocks sharing one 2D RoPE table.

    Args:
        hidden_dim: model width, used by the final LayerNorm.
        num_layers: number of encoder layers.
        block_cfg: keyword arguments forwarded to every MoonVitEncoderLayer;
            must contain "hidden_dim" and "num_heads".
        prefix: weight-name prefix; layer i gets "{prefix}.blocks.{i}".
        use_data_parallel: forwarded to every layer.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        # RoPE is applied per head; 512 x 512 is the largest patch grid that
        # get_freqs_cis_by_seqlens accepts.
        self.rope_2d = Rope2DPosEmb(
            block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512
        )
        self.blocks = nn.ModuleList(
            [
                MoonVitEncoderLayer(
                    use_data_parallel=use_data_parallel,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                    **block_cfg,
                )
                for layer_idx in range(num_layers)
            ]
        )
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
    ) -> torch.Tensor:
        """Encode packed patch features.

        Args:
            hidden_states: packed (sum(h * w), hidden_dim) patch features.
            grid_hw: (N, 2) per-image patch-grid (height, width).

        Returns:
            Packed (sum(h * w), hidden_dim) features after the final LayerNorm.
        """
        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)

        # Per-image token counts h * w, prefixed with 0 so that the cumulative
        # sum yields int32 sequence offsets [0, n0, n0 + n1, ...].
        lengths = torch.cat(
            (
                torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
                (grid_hw[:, 0] * grid_hw[:, 1]).to(hidden_states.device),
            )
        )
        cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)

        for block in self.blocks:
            hidden_states = block(
                hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
            )

        hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


def patch_merger(
    x: torch.Tensor,
    grid_hw: torch.Tensor,
    merge_kernel_size: Sequence[int] = (2, 2),
) -> list[torch.Tensor]:
    """Group each image's patch tokens into non-overlapping kh x kw windows.

    Args:
        x: packed tokens of shape (sum(h * w), d_model), images concatenated in
            the order of ``grid_hw``, each image stored row-major.
        grid_hw: (N, 2) tensor of per-image (h, w) patch-grid sizes. h and w
            must be divisible by kh and kw respectively, or the ``view`` fails.
        merge_kernel_size: (kh, kw) window size.

    Returns:
        One tensor per image of shape ((h // kh) * (w // kw), kh * kw, d_model).
        Each row holds the kh * kw tokens of one window, row-major within the
        window; windows themselves are ordered row-major.
    """
    d_model = x.size(-1)
    kernel_height, kernel_width = merge_kernel_size

    outputs = []
    offset = 0
    for hw in grid_hw.tolist():
        height, width = hw[0], hw[1]
        num_tokens = height * width
        # Tokens belonging to the current image.
        seq = x[offset : offset + num_tokens]
        new_height, new_width = height // kernel_height, width // kernel_width
        windows = seq.view(new_height, kernel_height, new_width, kernel_width, d_model)
        # (nh, kh, nw, kw, d) -> (nh, nw, kh, kw, d): make each window contiguous.
        windows = windows.permute(0, 2, 1, 3, 4).contiguous()
        outputs.append(
            windows.view(new_height * new_width, kernel_height * kernel_width, -1)
        )
        offset += num_tokens

    return outputs


class MoonVitVLProjector(nn.Module):
    """Project merged vision tokens to ``out_dim``.

    Each token is layer-normed over ``in_channels``. The result is regrouped
    into vectors of ``in_channels * kh * kw`` (one per merge window) and
    passed through ``linear_1 -> act -> linear_2``.
    """

    def __init__(
        self,
        in_channels: int,
        merge_kernel_size: Sequence[int],
        hidden_act: str = "gelu",
        ln_eps: float = 1e-5,
        out_dim: int = 4096,
    ):
        super().__init__()
        self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]

        # torch.nn exposes LayerNorm directly (``nn.nn`` does not exist).
        self.pre_norm = nn.LayerNorm(in_channels, eps=ln_eps)
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.act = ACT2FN[hidden_act]
        self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project merged tokens.

        Args:
            hidden_states: tensor whose last dim is ``in_channels`` and whose
                element count is a multiple of ``hidden_size``. Presumably
                (num_windows, kh * kw, in_channels) as produced by
                ``patch_merger``; TODO confirm against the caller.

        Returns:
            (num_windows, out_dim) tensor.
        """
        hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
        hidden_states = self.linear_1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class MoonVitPretrainedModel(PreTrainedModel):
    """MoonViT vision tower: patch embedding -> encoder -> patch merging.

    ``forward`` returns, per image, the encoded patch tokens grouped into merge
    windows (see ``patch_merger``). No projector is applied here.
    """

    config_class = MoonViTConfig
    model_type = "moonvit"
    # NOTE(review): no module named "PackingTransformer" exists in this file;
    # confirm whether this should list MoonVitEncoderLayer instead.
    _no_split_modules = ["PackingTransformer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config: MoonViTConfig,
        use_data_parallel: bool = False,
        prefix: str = "",
        *inputs,
        **kwargs,
    ):
        super().__init__(config, *inputs, **kwargs)
        config = deepcopy(config)
        self.use_data_parallel = use_data_parallel
        self.merge_kernel_size = config.merge_kernel_size
        self.hidden_size = config.hidden_size
        self.patch_size = config.patch_size
        self.vit_processing_type = "rope_2d"
        self.patch_embed = MoonVisionPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
        )

        self.encoder = MoonVitEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg={
                "num_heads": config.num_attention_heads,
                "hidden_dim": config.hidden_size,
                "mlp_dim": config.intermediate_size,
                "activation": ACT2FN["gelu_pytorch_tanh"],
                "attn_bias": True,
                # May be overridden to flash_attention_2 inside each layer.
                "attn_implementation": config._attn_implementation,
            },
            prefix=f"{prefix}.encoder",
        )

    def forward(
        self, pixel_values: torch.Tensor, grid_hw: torch.Tensor
    ) -> list[torch.Tensor]:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel patches, one entry per
                patch along dim 0, in the layout expected by
                ``MoonVisionPatchEmbed``.
            grid_hw (torch.Tensor): (N, 2) per-image patch-grid height and width.

        Returns:
            list[torch.Tensor]: one tensor per image of shape
                ((h // kh) * (w // kw), kh * kw, hidden_size), where (kh, kw)
                is ``merge_kernel_size``.
        """
        hidden_states = self.patch_embed(pixel_values, grid_hw)
        hidden_states = self.encoder(hidden_states, grid_hw)
        hidden_states = patch_merger(
            hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size
        )
        return hidden_states