dots_ocr.py 30.5 KB
Newer Older
Roger Wang's avatar
Roger Wang committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
4
from typing import Annotated, Literal, TypeAlias
Roger Wang's avatar
Roger Wang committed
5
6
7
8
9
10
11

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor

12
from vllm.attention.backends.registry import AttentionBackendEnum
13
14
15
16
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
Roger Wang's avatar
Roger Wang committed
17
from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
20
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
21
22
23
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
Roger Wang's avatar
Roger Wang committed
24
25
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
26
27
28
29
30
31
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Roger Wang's avatar
Roger Wang committed
32
from vllm.model_executor.layers.quantization import QuantizationConfig
33
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
34
35
36
37
38
39
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
40
from vllm.model_executor.models.module_mapping import MultiModelKeys
Roger Wang's avatar
Roger Wang committed
41
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
42
from vllm.model_executor.models.qwen2_vl import (
43
    Qwen2VisionAttention,
44
45
46
47
48
49
50
51
52
53
    Qwen2VLDummyInputsBuilder,
    Qwen2VLMultiModalProcessor,
    Qwen2VLProcessingInfo,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
Roger Wang's avatar
Roger Wang committed
54
55
56
57
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.sequence import IntermediateTensors
58
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionConfig
59
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Roger Wang's avatar
Roger Wang committed
60

61
62
from .vision import run_dp_sharded_mrope_vision_model

Roger Wang's avatar
Roger Wang committed
63
64
65
# Image placeholder token used in prompts. DotsOCRProcessingInfo.get_hf_processor
# assigns it as both the tokenizer's and the HF processor's ``image_token``, and
# the dummy-input builder emits one copy per image.
IMAGE_TOKEN = "<|imgpad|>"


66
67
68
69
70
71
72
73
class DotsOCRImagePixelInputs(TensorSchema):
    """
    Image inputs given as raw, flattened pixel patches.

    Dimensions:
        - np: Total number of patches, summed over every image of every
              prompt in the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size
    """

    # Tag that distinguishes pixel inputs from precomputed embeddings.
    type: Literal["pixel_values"]

    # One flattened patch per row.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    # Three grid values per image (t, h, w by name; NOTE(review): presumably
    # patch counts along time/height/width -- confirm).
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
Roger Wang's avatar
Roger Wang committed
79
80


81
class DotsOCRImageEmbeddingInputs(TensorSchema):
    """
    Image inputs given as already-computed image feature embeddings.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images
    """

    # Tag that distinguishes embedding inputs from raw pixel inputs.
    type: Literal["image_embeds"]

    # One embedding vector per row.
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    # Three grid values per image (t, h, w by name).
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
Roger Wang's avatar
Roger Wang committed
93
94


95
# Either accepted form of image input: raw pixel patches or precomputed embeddings.
DotsOCRImageInputs: TypeAlias = DotsOCRImagePixelInputs | DotsOCRImageEmbeddingInputs
Roger Wang's avatar
Roger Wang committed
96
97
98
99
100
101
102
103
104
105
106


class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
    """Builds dummy prompt text and image data for dots.ocr.

    The text is just one ``IMAGE_TOKEN`` per requested image; the images are
    generated at the size reported by ``get_image_size_with_most_features``.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return ``IMAGE_TOKEN`` repeated once per requested image."""
        return IMAGE_TOKEN * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized for the largest feature count.

        ``seq_len`` is not used. Image-specific overrides are taken from
        ``mm_options["image"]`` when ``mm_options`` is given and non-empty.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()
        overrides = mm_options.get("image") if mm_options else None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for dots.ocr, built on the Qwen2-VL implementation.

    Checks that the HF config is a ``DotsOCRConfig``, exposes only the image
    modality, and returns a ``Qwen2VLProcessor`` configured with the dots.ocr
    placeholder tokens.
    """

    def get_hf_config(self) -> DotsOCRConfig:
        """Return the HF config after checking it is a ``DotsOCRConfig``.

        A ``vision_config`` stored as a plain dict is replaced in place by a
        ``DotsVisionConfig`` built from it.

        Raises:
            TypeError: If the config's class is not named ``DotsOCRConfig``.
        """
        config = self.ctx.get_hf_config()
        # Matched by class *name* rather than isinstance. NOTE(review):
        # presumably so a DotsOCRConfig class defined in another module is also
        # accepted; confirm before tightening this to isinstance.
        if config.__class__.__name__ != "DotsOCRConfig":
            raise TypeError(f"Expected DotsOCRConfig, got {type(config)}")

        # Promote a raw-dict vision sub-config so callers get attribute access.
        if isinstance(getattr(config, "vision_config", None), dict):
            config.vision_config = DotsVisionConfig(**config.vision_config)

        return config

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Report images as the only supported modality.

        The limit value is ``None`` (NOTE(review): presumably "no explicit
        per-prompt cap" -- confirm against the base class contract).
        """
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return ``get_max_image_tokens()`` as the per-image token budget.

        ``seq_len`` and ``mm_counts`` are not used.
        """
        return {"image": self.get_max_image_tokens()}

    def get_hf_processor(
        self,
        **kwargs: object,
    ) -> Qwen2VLProcessor:
        """Return a ``Qwen2VLProcessor`` using the dots.ocr placeholder tokens.

        Side effect: sets ``image_token`` on the shared tokenizer to
        ``IMAGE_TOKEN`` before the processor is fetched.
        """
        self.get_tokenizer().image_token = IMAGE_TOKEN  # Ensure image token is set
        processor = self.ctx.get_hf_processor(
            Qwen2VLProcessor,
            **kwargs,
        )
        processor.image_token = IMAGE_TOKEN
        processor.video_token = "<|video_pad|>"
        return processor


def rotate_half(x):
    """Swap the two halves of the last dimension and negate the new first half.

    With the last dimension split as ``[a, b]`` (``a`` holding the first
    ``d // 2`` entries), the result is ``[-b, a]``.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)


169
170
171
def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """Apply rotary position embedding with angles ``freqs`` to ``tensor``.

    The math runs in float32 and the result is cast back to ``tensor``'s
    dtype. Assuming ``freqs`` is 2-D ``[S, F]``, its cos/sin tables are
    duplicated along the last dim and broadcast as ``[1, S, 1, 2 * F]``
    (NOTE(review): this expects ``tensor`` laid out as ``[B, S, H, 2 * F]``).
    """

    def _broadcast(table: torch.Tensor) -> torch.Tensor:
        # [S, F] -> [S, 1, 2F] -> [1, S, 1, 2F]
        return table.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()

    input_dtype = tensor.dtype
    x = tensor.float()
    cos_table = _broadcast(freqs.cos())
    sin_table = _broadcast(freqs.sin())
    rotated = x * cos_table + rotate_half(x) * sin_table
    return rotated.to(input_dtype)


class VisionRotaryEmbedding(nn.Module):
    """Table of rotary-embedding angles for the vision tower.

    Holds inverse frequencies ``1 / theta ** (i / dim)`` for even ``i`` in
    ``[0, dim)``. Calling the module with ``seqlen`` returns the
    ``[seqlen, ceil(dim / 2)]`` outer product ``position * inv_freq``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        inv_freq = 1.0 / (theta**exponents)
        # Non-persistent: derived purely from dim/theta, so it is neither
        # written to nor expected in state dicts.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        positions = torch.arange(
            seqlen, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )
        return torch.outer(positions, self.inv_freq)


class PatchMerger(nn.Module):
    """Fuse each window of ``spatial_merge_size ** 2`` patch features into one token.

    Rows of width ``context_dim`` are optionally normalized, regrouped into
    rows of ``context_dim * spatial_merge_size ** 2`` and projected to ``dim``
    by a Linear -> GELU -> Linear MLP.

    Args:
        dim: Output feature size.
        context_dim: Per-patch input feature size.
        spatial_merge_size: Side length of the merged window.
        pre_norm: ``"layernorm"``, ``"rmsnorm"``, or a falsy value for no
            normalization before the MLP.
        prefix: Name prefix for the MLP's linear layers.
        use_data_parallel: If True, the linear layers are built with
            ``disable_tp=True``.

    Raises:
        ValueError: If ``pre_norm`` is truthy but not a supported norm name.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        pre_norm="layernorm",
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Reject unknown names up front: otherwise ``ln_q`` would never be
        # created and ``forward`` would fail later with an AttributeError.
        if pre_norm and pre_norm not in ("layernorm", "rmsnorm"):
            raise ValueError(
                f"Unsupported pre_norm {pre_norm!r}; expected 'layernorm', "
                "'rmsnorm', or a falsy value for no normalization."
            )
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.pre_norm = pre_norm
        if self.pre_norm == "layernorm":
            self.ln_q = LayerNorm(context_dim, eps=1e-6)
        elif self.pre_norm == "rmsnorm":
            self.ln_q = RMSNorm(context_dim, eps=1e-6)

        self.mlp = nn.Sequential(
            ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.0",
                disable_tp=use_data_parallel,
            ),
            nn.GELU(),
            RowParallelLinear(
                self.hidden_size,
                dim,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.2",
                disable_tp=use_data_parallel,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize per patch first, then fold each merge window into one row.
        if self.pre_norm:
            x = self.ln_q(x)
        return self.mlp(x.view(-1, self.hidden_size))


class DotsVisionAttention(nn.Module):
    """Multi-head self-attention for the dots.ocr vision tower.

    Uses a fused QKV projection and an output projection (tensor-parallel
    unless ``use_data_parallel``), optional rotary position embeddings, and
    variable-length attention over packed sequences delimited by
    ``cu_seqlens``. The kernel is resolved at construction time. The flash
    backends (FLASH_ATTN / ROCM_AITER_FA) use a varlen flash-attention call.
    TORCH_SDPA runs one ``scaled_dot_product_attention`` call per sequence.
    XFORMERS uses a block-diagonal mask.

    ``config`` is accepted but not used by this module.
    """

    def __init__(
        self,
        config,
        dim: int,
        num_heads: int = 16,
        bias: bool = True,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()

        self.embed_dim = dim
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )
        # qkv/proj follow Qwen2-VL style; bias controlled by arg
        self.qkv = QKVParallelLinear(
            hidden_size=dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        # Select attention backend
        self.attn_backend = get_vit_attn_backend(
            self.hidden_size_per_attention_head,
            torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )
        if self.attn_backend not in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.XFORMERS,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Unsupported vision attention backend: {self.attn_backend}"
            )
        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        *,
        max_seqlen: int | None = None,
        seqlens: list[int] | None = None,
    ) -> torch.Tensor:
        """Attend within each packed sequence and project the result.

        Args:
            hidden_states: ``[S, C]`` tokens of all sequences, concatenated.
            cu_seqlens: Sequence boundaries; sequence ``i`` spans
                ``cu_seqlens[i]:cu_seqlens[i + 1]``.
            rotary_pos_emb: Optional rotary angles applied to q and k.
            max_seqlen: Longest sequence length, used by the flash backends.
                Derived from ``cu_seqlens`` when omitted.
            seqlens: Per-sequence lengths, used by XFORMERS. Derived from
                ``cu_seqlens`` when omitted.

        Returns:
            ``[S, C]`` attention output after the output projection.
        """
        # [S, C] -> [S, B=1, C]
        x = hidden_states.unsqueeze(1)
        x, _ = self.qkv(x)
        q, k, v = Qwen2VisionAttention.split_qkv(self, x)
        bs = q.shape[1]
        # [S,B,H,D] -> [B,S,H,D]
        q = q.permute(1, 0, 2, 3).contiguous()
        k = k.permute(1, 0, 2, 3).contiguous()
        v = v.permute(1, 0, 2, 3).contiguous()

        if rotary_pos_emb is not None:
            # Rotate q and k together with a single call.
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            # The varlen kernel needs an int max length; derive it when the
            # caller relied on the default instead of crashing inside it.
            if max_seqlen is None:
                max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
            q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
            k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
            v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
            output = self.flash_attn_varlen_func(
                q_,
                k_,
                v_,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )
            context_layer = output.view(
                bs,
                -1,
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Read all boundaries in one host transfer instead of two
            # int() conversions (each a device sync) per sequence.
            bounds = [int(b) for b in cu_seqlens.tolist()]
            outputs = []
            for s, e in zip(bounds[:-1], bounds[1:]):
                q_i = q[:, s:e].permute(0, 2, 1, 3)
                k_i = k[:, s:e].permute(0, 2, 1, 3)
                v_i = v[:, s:e].permute(0, 2, 1, 3)
                out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                out_i = out_i.permute(0, 2, 1, 3)
                outputs.append(out_i)
            context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # The block-diagonal mask needs explicit lengths; derive them when
            # the caller relied on the default.
            if seqlens is None:
                seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )
            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
        else:
            raise RuntimeError("Unsupported attention backend")

        # [B,S,H,D] -> [S,B,H*D] -> [S, C]
        context_layer = context_layer.permute(1, 0, 2, 3).contiguous()
        context_layer = context_layer.view(context_layer.shape[0], bs, -1)
        out, _ = self.proj(context_layer)
        return out.squeeze(1)


class DotsSwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block of the vision tower (modeled on aimv2.py's
    ``AIMv2SwiGLUFFN``).

    ``fc13`` is one merged projection producing two ``intermediate_size``
    outputs, filled from checkpoint tensors ``fc1`` (shard 0) and ``fc3``
    (shard 1). ``SiluAndMul`` combines them, and ``fc2`` projects back to
    ``config.embed_dim``.
    """

    def __init__(
        self,
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        embed_dim = config.embed_dim
        ffn_dim = config.intermediate_size
        use_bias = config.use_bias

        self.fc13 = MergedColumnParallelLinear(
            embed_dim,
            [ffn_dim, ffn_dim],
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc13",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            ffn_dim,
            embed_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        merged, _ = self.fc13(x)
        out, _ = self.fc2(self.act_fn(merged))
        return out

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Names containing ``fc1``/``fc3`` are rewritten to ``fc13`` and loaded
        as shard 0/1 through the parameter's ``weight_loader``. Other names
        are loaded directly, using ``weight_loader`` if the parameter has one
        and ``default_weight_loader`` otherwise. Bias tensors without a
        matching parameter are skipped.

        Returns:
            Names of the parameters that were loaded.
        """
        shard_routes = (
            ("fc13", "fc1", 0),
            ("fc13", "fc3", 1),
        )
        own_params = dict(self.named_parameters())
        loaded: set[str] = set()
        for ckpt_name, tensor in weights:
            for merged_name, shard_name, shard_idx in shard_routes:
                if shard_name not in ckpt_name:
                    continue
                ckpt_name = ckpt_name.replace(shard_name, merged_name)
                # Skip loading extra bias for GPTQ models.
                if ckpt_name.endswith(".bias") and ckpt_name not in own_params:
                    continue
                target = own_params[ckpt_name]
                target.weight_loader(target, tensor, shard_idx)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if ckpt_name.endswith(".bias") and ckpt_name not in own_params:
                    continue
                target = own_params[ckpt_name]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)
            loaded.add(ckpt_name)
        return loaded
Roger Wang's avatar
Roger Wang committed
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482


class DotsPatchEmbed(nn.Module):
    """Embed flattened pixel patches with a strided Conv2d followed by RMSNorm.

    Each input row is viewed as ``[num_channels, temporal_patch_size,
    patch_size, patch_size]``. Only temporal slice 0 is kept, and it is
    projected to ``embed_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_channels = config.num_channels
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.embed_dim = config.embed_dim

        window = (config.patch_size, config.patch_size)
        self.proj = nn.Conv2d(
            config.num_channels,
            config.embed_dim,
            kernel_size=window,
            stride=window,
        )
        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        # ``grid_thw`` is accepted but not used here.
        patches = x.view(
            -1,
            self.num_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        first_frame = patches[:, :, 0]
        embedded = self.proj(first_frame).view(-1, self.embed_dim)
        return self.norm(embedded)


class DotsViTPreprocessor(nn.Module):
    """Turn raw pixel patches into tokens by delegating to ``DotsPatchEmbed``."""

    def __init__(self, config):
        super().__init__()
        self.patch_h = self.patch_w = config.patch_size
        self.embed_dim = config.embed_dim
        self.config = config
        self.patchifier = DotsPatchEmbed(config)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        return self.patchifier(x, grid_thw)


class DotsVisionBlock(nn.Module):
    """Pre-norm transformer block of the vision tower.

    ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``, where both norms
    are RMSNorm over ``config.embed_dim``.
    """

    def __init__(
        self,
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        width = config.embed_dim
        eps = config.rms_norm_eps

        self.attn = DotsVisionAttention(
            config,
            width,
            num_heads=config.num_attention_heads,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.norm1 = RMSNorm(width, eps=eps)
        self.mlp = DotsSwiGLUFFN(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )
        self.norm2 = RMSNorm(width, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,
        seqlens: list[int] | None = None,
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out


560
class DotsVisionTransformer(nn.Module):
    """dots.ocr vision tower: patch embedding, transformer blocks, patch merger.

    Patches of all images arrive concatenated. Per-frame boundaries (``h * w``
    patches, repeated ``t`` times per image) are turned into ``cu_seqlens``
    and passed to every block. The final ``PatchMerger`` fuses each
    ``spatial_merge_size x spatial_merge_size`` window into one token of width
    ``config.hidden_size``.
    """

    def __init__(
        self,
        config: DotsVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = DotsViTPreprocessor(config)

        head_dim = config.embed_dim // config.num_attention_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        if (
            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
            and check_upstream_fa_availability(torch.get_default_dtype())
        ):
            self.attn_backend = AttentionBackendEnum.FLASH_ATTN
        self.out_hidden_size = config.hidden_size
        # Keep blocks for compatibility with other vision towers
        num_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )
        self.blocks = nn.ModuleList(
            [
                DotsVisionBlock(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{i}",
                    use_data_parallel=use_data_parallel,
                    attn_backend_override=attn_backend_override,
                )
                for i in range(num_layers)
            ]
        )
        # By default the post norm is only applied when the full stack of
        # layers is built (i.e. no truncating override).
        if require_post_norm is None:
            require_post_norm = len(self.blocks) == config.num_hidden_layers
        if require_post_norm and self.config.post_norm:
            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        else:
            self.post_trunk_norm = None

        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
            spatial_merge_size=config.spatial_merge_size,
            # Give the merger's linears proper names like every other
            # submodule; without it they were named ".0" / ".2".
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.patchifier.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.patchifier.proj.weight.device

    def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
        """Return per-image ``[t * h * w, 2]`` tensors of (row, col) patch positions.

        Positions are ordered so that each ``spatial_merge_size`` x
        ``spatial_merge_size`` window is contiguous, and the spatial pattern is
        repeated ``t`` times.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        return pos_ids

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Return rotary angles for every patch: row and column angles concatenated."""
        pos_ids = self.get_pos_ids_by_grid(grid_thw)
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[int | None, list[int] | None]:
        """Precompute the sequence-length metadata the attention kernel needs.

        Returns:
            ``(max_seqlen, seqlens)``. ``max_seqlen`` is set for FLASH_ATTN /
            ROCM_AITER_FA and ``seqlens`` for XFORMERS; the other entry (both
            for TORCH_SDPA) is ``None``.
        """
        # Use the backend the attention layers resolved for themselves
        # (DotsVisionAttention goes through maybe_get_vit_flash_attn_backend),
        # which can differ from ``self.attn_backend`` chosen in __init__.
        # Keying on the wrong one would hand None to a kernel that needs it.
        if len(self.blocks) > 0:
            backend = self.blocks[0].attn.attn_backend
        else:
            backend = self.attn_backend

        max_seqlen, seqlens = None, None
        if backend in (
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        ):
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif backend == AttentionBackendEnum.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
    ) -> torch.Tensor:
        """Encode concatenated pixel patches into merged vision tokens.

        Args:
            hidden_states: Flattened pixel patches of all images, concatenated.
            grid_thw: Per-image ``[t, h, w]`` grid sizes, as a Python list.

        Returns:
            Merged tokens of width ``config.hidden_size``.
        """
        # Rotary angles are computed from the Python list form of the grid.
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        grid_thw_t = torch.tensor(
            grid_thw, device=hidden_states.device, dtype=torch.long
        )
        hidden_states = hidden_states.to(self.dtype)
        hidden_states = self.patch_embed(hidden_states, grid_thw_t)

        # One sequence per frame: h * w patches, repeated t times per image.
        cu_seqlens = torch.repeat_interleave(
            grid_thw_t[:, 1] * grid_thw_t[:, 2], grid_thw_t[:, 0]
        ).cumsum(
            dim=0,
            dtype=grid_thw_t.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Prepend 0 so sequence i spans cu_seqlens[i]:cu_seqlens[i + 1].
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        if self.post_trunk_norm is not None:
            hidden_states = self.post_trunk_norm(hidden_states)

        hidden_states = self.merger(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=DotsOCRProcessingInfo,
    dummy_inputs=DotsOCRDummyInputsBuilder,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """dots.ocr: a ``DotsVisionTransformer`` image encoder feeding a Qwen2 LM.

    Image features from ``vision_tower`` replace the input embeddings at the
    positions of ``config.image_token_id`` tokens; everything else (decoding,
    logits) is delegated to the wrapped ``Qwen2ForCausalLM``.
    """

    # Interface flag consumed by vLLM's multimodal framework; its exact effect
    # is defined outside this file.
    merge_by_field_config = True

    # Renames HF checkpoint weight names to this module's parameter names
    # (applied in ``load_weights``).
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".attn.qkv_proj.": ".attn.qkv.",
            ".attn.out_proj.": ".attn.proj.",
        },
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        },
    )

    # Fused vLLM module name -> checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        ".attn.qkv": [".attn.qkv"],
        "fc13": ["fc1", "fc3"],
    }
    # NOTE(review): presumably advertises support for the data-parallel
    # encoder mode selected via ``mm_encoder_tp_mode == "data"`` — confirm.
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for one multimodal item.

        Only modalities starting with ``"image"`` have a placeholder; any
        other modality yields ``None``. The item index ``i`` is unused.
        """
        if modality.startswith("image"):
            return "<|img|><|imgpad|><|endofimg|>"
        return None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower and the Qwen2 language model.

        Args:
            vllm_config: Global vLLM configuration; the HF config, quant
                config and (optional) multimodal config are read from it.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()

        self.config: DotsOCRConfig = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config

        multimodal_config = vllm_config.model_config.multimodal_config
        # Handle a missing multimodal config consistently for every value
        # derived from it; previously only the attention backend was guarded
        # and the encoder TP mode access would raise on ``None``.
        if multimodal_config is not None:
            self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
        else:
            self.use_data_parallel = False
            attn_backend_override = None

        # The HF config may carry the vision config as a plain dict; normalize
        # it to DotsVisionConfig and write it back so later readers see it.
        if isinstance(self.config.vision_config, dict):
            vision_config = DotsVisionConfig(**self.config.vision_config)
            self.config.vision_config = vision_config
        else:
            vision_config = self.config.vision_config

        self.vision_tower = DotsVisionTransformer(
            vision_config,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
            use_data_parallel=self.use_data_parallel,
            attn_backend_override=attn_backend_override,
        )
        self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> DotsOCRImageInputs | None:
        """Pop image inputs out of ``kwargs`` and wrap them in a typed dict.

        ``pixel_values`` takes precedence over ``image_embeds``. Returns
        ``None`` when neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return DotsOCRImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        # Only reachable with image_embeds set (the both-None case returned).
        return DotsOCRImageEmbeddingInputs(
            type="image_embeds",
            image_embeds=image_embeds,
            image_grid_thw=image_grid_thw,
        )

    def _process_image_input(
        self, image_input: DotsOCRImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Turn parsed image input into one embedding tensor per image.

        Precomputed embeddings are only cast to the vision tower's dtype.
        Pixel inputs run through the vision tower: in data-parallel mode the
        sharded helper's result is returned directly; otherwise the output is
        truncated to ``config.hidden_size`` columns. The concatenated output
        is then split per image into ``t*h*w / merge_size**2`` rows each.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.vision_tower.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.vision_tower.dtype)

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.vision_tower,
                    pixel_values,
                    grid_thw_list,
                    rope_type="rope_3d",
                )
            image_embeds = self.vision_tower(pixel_values, grid_thw_list)[
                :, : self.config.hidden_size
            ]

        # Split concatenated embeddings for each image item: every image's
        # t*h*w patches are reduced by the spatial merge factor squared.
        merge_size = self.vision_tower.spatial_merge_size
        sizes = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
            // (merge_size * merge_size)
        ).tolist()

        return image_embeds.split(sizes)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Qwen2 language model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute per-image embeddings, or ``[]`` when no image input is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model, merging image embeddings when needed.

        If ``intermediate_tensors`` is given (presumably a non-first pipeline
        stage), ``inputs_embeds`` is discarded. Otherwise, when no embeddings
        were supplied, they are built from ``input_ids`` with image features
        placed at ``config.image_token_id`` positions, and ``input_ids`` is
        then dropped so the LM consumes the embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            vision_embeddings = self.embed_multimodal(**kwargs)
            inputs_embeds = self.embed_input_ids(
                input_ids,
                vision_embeddings,
                is_multimodal=input_ids == self.config.image_token_id,
            )
            input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them via ``hf_to_vllm_mapper``.

        Returns the set of names reported as loaded by ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_tower.merger",
            tower_model="vision_tower.",
        )