dots_ocr.py 27.4 KB
Newer Older
Roger Wang's avatar
Roger Wang committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
4
from typing import Annotated, Literal, TypeAlias
Roger Wang's avatar
Roger Wang committed
5
6
7
8
9
10

import torch
import torch.nn as nn
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor

11
from vllm.config import MultiModalConfig, VllmConfig
12
from vllm.config.multimodal import BaseDummyOptions
13
14
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
15
16
17
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
Roger Wang's avatar
Roger Wang committed
18
from vllm.model_executor.layers.activation import SiluAndMul
19
20
21
from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)
22
from vllm.model_executor.layers.conv import Conv2dLayer
Roger Wang's avatar
Roger Wang committed
23
from vllm.model_executor.layers.layernorm import RMSNorm
24
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Roger Wang's avatar
Roger Wang committed
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
32
33
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
36
37
38
39
40
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
41
from vllm.model_executor.models.module_mapping import MultiModelKeys
Roger Wang's avatar
Roger Wang committed
42
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
43
from vllm.model_executor.models.qwen2_vl import (
44
    Qwen2VisionAttention,
45
46
47
48
49
50
51
52
53
54
    Qwen2VLDummyInputsBuilder,
    Qwen2VLMultiModalProcessor,
    Qwen2VLProcessingInfo,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
Roger Wang's avatar
Roger Wang committed
55
56
57
58
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.sequence import IntermediateTensors
59
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionConfig
60
from vllm.utils.tensor_schema import TensorSchema, TensorShape
61
from vllm.v1.attention.backends.registry import AttentionBackendEnum
Roger Wang's avatar
Roger Wang committed
62

63
64
from .vision import run_dp_sharded_mrope_vision_model

Roger Wang's avatar
Roger Wang committed
65
66
67
# Image placeholder token string; assigned as `image_token` on the tokenizer and
# HF processor, and repeated once per image when building dummy prompt text.
IMAGE_TOKEN = "<|imgpad|>"


68
69
70
71
72
73
74
75
class DotsOCRImagePixelInputs(TensorSchema):
    """
    Image input supplied as flattened pixel patches plus per-image grid sizes.

    Dimensions:
        - np: The total number of patches over each image over each prompt in
              the batch
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size
    """

    # Discriminator used to tell this schema apart from the embedding variant.
    type: Literal["pixel_values"]

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
Roger Wang's avatar
Roger Wang committed
81
82


83
class DotsOCRImageEmbeddingInputs(TensorSchema):
    """
    Image input supplied as precomputed embeddings plus per-image grid sizes.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images
    """

    # Discriminator used to tell this schema apart from the pixel variant.
    type: Literal["image_embeds"]

    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
Roger Wang's avatar
Roger Wang committed
95
96


97
# Either of the two accepted image input schemas: pixel patches or embeddings.
DotsOCRImageInputs: TypeAlias = DotsOCRImagePixelInputs | DotsOCRImageEmbeddingInputs
Roger Wang's avatar
Roger Wang committed
98
99
100
101
102
103
104
105
106
107
108


class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
    """Produces dummy prompt text and dummy image data for DotsOCR."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return ``IMAGE_TOKEN`` repeated once per requested image."""
        image_count = mm_counts.get("image", 0)
        return IMAGE_TOKEN * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Build dummy images sized by ``get_image_size_with_most_features``.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Per-modality item counts; only ``"image"`` is read.
            mm_options: Optional per-modality overrides; the ``"image"`` entry,
                if any, is forwarded to ``_get_dummy_images``.

        Returns:
            A dict with a single ``"image"`` entry holding the dummy images.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
    """Processing info for DotsOCR, layered on the Qwen2-VL implementation."""

    def get_hf_config(self) -> DotsOCRConfig:
        """Return the HF config, checking its class name and vision config.

        The check compares the class *name* rather than using ``isinstance``.
        A ``vision_config`` given as a plain dict is converted in place to a
        ``DotsVisionConfig``.

        Raises:
            TypeError: If the config's class is not named ``DotsOCRConfig``.
        """
        hf_config = self.ctx.get_hf_config()
        if hf_config.__class__.__name__ != "DotsOCRConfig":
            raise TypeError(f"Expected DotsOCRConfig, got {type(hf_config)}")

        vision_config = getattr(hf_config, "vision_config", None)
        if isinstance(vision_config, dict):
            hf_config.vision_config = DotsVisionConfig(**vision_config)

        return hf_config

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare images as the only modality, with a limit of ``None``."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the value of ``get_max_image_tokens()`` for ``"image"``."""
        return {"image": self.get_max_image_tokens()}

    def get_hf_processor(
        self,
        **kwargs: object,
    ) -> Qwen2VLProcessor:
        """Fetch a ``Qwen2VLProcessor`` and set its placeholder tokens.

        The tokenizer's ``image_token`` is set before the processor is
        requested; the processor's image and video tokens are set afterwards.
        """
        tokenizer = self.get_tokenizer()
        tokenizer.image_token = IMAGE_TOKEN  # Ensure image token is set

        hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor, **kwargs)
        hf_processor.image_token = IMAGE_TOKEN
        hf_processor.video_token = "<|video_pad|>"
        return hf_processor


class VisionRotaryEmbedding(nn.Module):
    """Computes a table of rotary frequencies for positions ``0..seqlen-1``."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """Precompute inverse frequencies ``1 / theta ** (2i / dim)``.

        Args:
            dim: Rotary dimension; one frequency is produced per even index
                in ``range(0, dim, 2)``.
            theta: Base of the geometric frequency progression.
        """
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # Non-persistent: kept out of the state dict.
        self.register_buffer("inv_freq", 1.0 / (theta**exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the outer product of positions and inverse frequencies.

        The result has one row per position in ``range(seqlen)`` and one
        column per entry of ``inv_freq``.
        """
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)


class PatchMerger(nn.Module):
    """Optionally normalizes patch features, then merges groups with an MLP.

    Each group of ``spatial_merge_size**2`` consecutive feature rows is
    flattened into one row of width ``context_dim * spatial_merge_size**2``
    and passed through a two-layer MLP (linear, GELU, linear) producing
    ``dim`` features.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        pre_norm="layernorm",
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """
        Args:
            dim: Output feature size of the MLP.
            context_dim: Feature size of each incoming patch row.
            spatial_merge_size: Side length of the merge window; its square is
                the number of rows merged into one.
            pre_norm: ``"layernorm"``, ``"rmsnorm"``, or a falsy value to skip
                normalization before merging.
            prefix: Prefix used to name the two linear layers.
            use_data_parallel: Forwarded as ``disable_tp`` to both linear
                layers.

        Raises:
            ValueError: If ``pre_norm`` is truthy but not a supported name.
        """
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.pre_norm = pre_norm
        if self.pre_norm == "layernorm":
            self.ln_q = LayerNorm(context_dim, eps=1e-6)
        elif self.pre_norm == "rmsnorm":
            self.ln_q = RMSNorm(context_dim, eps=1e-6)
        elif self.pre_norm:
            # Fail at construction: otherwise `ln_q` is never created and the
            # first forward() dies with an opaque AttributeError.
            raise ValueError(
                f"Unsupported pre_norm {self.pre_norm!r}; expected 'layernorm', "
                "'rmsnorm', or a falsy value to disable normalization"
            )

        self.mlp = nn.Sequential(
            ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.0",
                disable_tp=use_data_parallel,
            ),
            nn.GELU(),
            RowParallelLinear(
                self.hidden_size,
                dim,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.2",
                disable_tp=use_data_parallel,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize (if configured), regroup rows to ``hidden_size``, run MLP."""
        if self.pre_norm:
            x = self.ln_q(x)
        return self.mlp(x.view(-1, self.hidden_size))


class DotsVisionAttention(nn.Module):
    """Self-attention layer of the DotsOCR vision tower.

    A fused QKV projection feeds ``MMEncoderAttention``; rotary embeddings are
    applied to queries and keys when supplied, and a row-parallel projection
    produces the output.
    """

    def __init__(
        self,
        config,
        dim: int,
        num_heads: int = 16,
        bias: bool = True,
        *,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Accepted for signature compatibility; not read here.
            dim: Embedding size of the layer.
            num_heads: Total number of attention heads.
            bias: Whether the QKV and output projections use a bias.
            quant_config: Passed to both linear layers.
            multimodal_config: When its ``mm_encoder_tp_mode`` is ``"data"``,
                tensor parallelism is disabled for this layer.
            prefix: Module-name prefix for the sub-layers.
        """
        super().__init__()

        use_data_parallel = False
        if multimodal_config:
            use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.embed_dim = dim
        if use_data_parallel:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        head_dim = dist_utils.divide(dim, num_heads)
        self.hidden_size_per_attention_head = head_dim
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        # qkv/proj follow Qwen2-VL style; bias controlled by arg
        self.qkv = QKVParallelLinear(
            hidden_size=dim,
            head_size=head_dim,
            total_num_heads=num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=head_dim,
            scale=head_dim**-0.5,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )

        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        *,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projection.

        ``cu_seqlens`` and ``max_seqlen`` are forwarded unchanged to the
        attention op. ``rotary_pos_emb``, when given, supplies the angles
        whose cos/sin rotate the queries and keys.
        """
        # [S, C] -> [S, B=1, C]
        qkv_out, _ = self.qkv(hidden_states.unsqueeze(1))
        q, k, v = Qwen2VisionAttention.split_qkv(self, qkv_out)
        batch = q.shape[1]

        # [S,B,H,D] -> [B,S,H,D]
        q, k, v = [t.permute(1, 0, 2, 3).contiguous() for t in (q, k, v)]

        if rotary_pos_emb is not None:
            # Rotate q and k in one call by stacking them along dim 0.
            rotated = self.apply_rotary_emb(
                torch.cat([q, k], dim=0),
                rotary_pos_emb.cos(),
                rotary_pos_emb.sin(),
            )
            q, k = torch.chunk(rotated, 2, dim=0)

        attn_out = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        # [B,S,H,D] -> [S,B,H*D] -> [S, C]
        attn_out = attn_out.permute(1, 0, 2, 3).contiguous()
        attn_out = attn_out.view(attn_out.shape[0], batch, -1)
        projected, _ = self.proj(attn_out)
        return projected.squeeze(1)


class DotsSwiGLUFFN(nn.Module):
    """Feed-forward layer: a merged ``fc13`` projection, ``SiluAndMul``, ``fc2``."""

    def __init__(
        self,
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Provides ``intermediate_size``, ``embed_dim``, ``use_bias``.
            quant_config: Passed to both linear layers.
            multimodal_config: When its ``mm_encoder_tp_mode`` is ``"data"``,
                tensor parallelism is disabled for this layer.
            prefix: Module-name prefix for the sub-layers.
        """
        super().__init__()
        inner_dim = config.intermediate_size
        model_dim = config.embed_dim
        use_bias = config.use_bias

        use_data_parallel = False
        if multimodal_config:
            use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # Referenced aimv2.py AIMv2SwiGLUFFN
        self.fc13 = MergedColumnParallelLinear(
            model_dim,
            [inner_dim] * 2,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc13",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            inner_dim,
            model_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc13``, the gated activation, then ``fc2``."""
        merged, _ = self.fc13(x)
        gated = self.act_fn(merged)
        out, _ = self.fc2(gated)
        return out

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, routing ``fc1``/``fc3`` into shards 0/1 of ``fc13``.

        Returns:
            The set of (post-renaming) parameter names that were loaded.
        """
        # (merged parameter name, checkpoint name, shard index)
        merged_shards = [
            ("fc13", "fc1", 0),
            ("fc13", "fc3", 1),
        ]
        params = dict(self.named_parameters())
        loaded: set[str] = set()

        def is_extra_bias(candidate: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            return candidate.endswith(".bias") and candidate not in params

        for name, tensor in weights:
            for merged_name, ckpt_name, shard_idx in merged_shards:
                if ckpt_name not in name:
                    continue
                name = name.replace(ckpt_name, merged_name)
                if is_extra_bias(name):
                    continue
                target = params[name]
                target.weight_loader(target, tensor, shard_idx)
                break
            else:
                # No merged shard consumed this weight: load it directly.
                if is_extra_bias(name):
                    continue

                target = params[name]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)
            loaded.add(name)
        return loaded
Roger Wang's avatar
Roger Wang committed
399
400
401
402
403
404
405
406
407
408


class DotsPatchEmbed(nn.Module):
    """Embeds flattened pixel patches with a strided convolution and RMSNorm."""

    def __init__(self, config):
        """Read patch geometry from ``config`` and build the conv and norm."""
        super().__init__()
        self.num_channels = config.num_channels
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.embed_dim = config.embed_dim
        self.config = config

        patch_hw = (config.patch_size, config.patch_size)
        self.proj = Conv2dLayer(
            config.num_channels,
            config.embed_dim,
            kernel_size=patch_hw,
            stride=patch_hw,
        )
        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        """Project each patch to ``embed_dim`` features and normalize.

        ``x`` is viewed as (patches, channels, temporal, patch, patch) and
        only temporal index 0 is kept. ``grid_thw`` is accepted but unused.
        """
        patches = x.view(
            -1,
            self.num_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        first_frame = patches[:, :, 0]
        embedded = self.proj(first_frame).view(-1, self.embed_dim)
        return self.norm(embedded)


class DotsViTPreprocessor(nn.Module):
    """Thin wrapper that turns pixel patches into tokens via ``DotsPatchEmbed``."""

    def __init__(self, config):
        """Record patch geometry from ``config`` and build the patchifier."""
        super().__init__()
        patch_size = config.patch_size
        self.patch_h = patch_size
        self.patch_w = patch_size
        self.embed_dim = config.embed_dim
        self.config = config
        self.patchifier = DotsPatchEmbed(config)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        """Delegate to the patchifier and return its output."""
        return self.patchifier(x, grid_thw)


class DotsVisionBlock(nn.Module):
    """Pre-norm transformer block: attention and SwiGLU FFN, each residual."""

    def __init__(
        self,
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Provides ``embed_dim``, ``num_attention_heads``,
                ``use_bias`` and ``rms_norm_eps``.
            quant_config: Forwarded to the attention and MLP sub-layers.
            multimodal_config: Forwarded to the attention and MLP sub-layers.
            prefix: Module-name prefix for the sub-layers.
        """
        super().__init__()

        self.attn = DotsVisionAttention(
            config,
            config.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.use_bias,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attn",
        )
        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        self.mlp = DotsSwiGLUFFN(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )
        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        # Annotated as a tensor to match DotsVisionAttention.forward, which
        # receives this value unchanged.
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return ``x + attn(norm1(x))`` followed by ``+ mlp(norm2(.))``."""
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
        )
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


491
class DotsVisionTransformer(nn.Module):
    """DotsOCR vision encoder.

    Pipeline: patch embedding, a stack of ``DotsVisionBlock`` layers, an
    optional final RMSNorm, then a ``PatchMerger``.
    """

    def __init__(
        self,
        config: DotsVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Vision configuration.
            quant_config: Forwarded to every block.
            multimodal_config: Supplies the attention-backend override and the
                encoder tensor-parallel mode, when given.
            num_hidden_layers_override: If set, build this many blocks instead
                of ``config.num_hidden_layers``.
            require_post_norm: Whether the final norm may be created; when
                ``None`` it defaults to "all configured layers were built".
            prefix: Module-name prefix for the blocks.
        """
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = DotsViTPreprocessor(config)

        head_dim = config.embed_dim // config.num_attention_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
            attn_backend_override=attn_backend_override,
        )
        self.out_hidden_size = config.hidden_size

        # Keep blocks for compatibility with other vision towers
        num_layers = config.num_hidden_layers
        if num_hidden_layers_override is not None:
            num_layers = num_hidden_layers_override
        self.blocks = nn.ModuleList(
            [
                DotsVisionBlock(
                    config,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.blocks.{i}",
                )
                for i in range(num_layers)
            ]
        )

        if require_post_norm is None:
            require_post_norm = len(self.blocks) == config.num_hidden_layers
        if require_post_norm and self.config.post_norm:
            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        else:
            self.post_trunk_norm = None

        use_data_parallel = False
        if multimodal_config:
            use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
            spatial_merge_size=config.spatial_merge_size,
            use_data_parallel=use_data_parallel,
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding convolution weight."""
        return self.patch_embed.patchifier.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding convolution weight."""
        return self.patch_embed.patchifier.proj.weight.device

    def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
        """Return, per grid entry, the (row, col) index of every patch.

        Indices are ordered so that each ``spatial_merge_size`` square window
        is contiguous, and the per-frame table is repeated ``t`` times.
        ``h`` and ``w`` must be divisible by ``spatial_merge_size`` for the
        reshape to succeed.
        """
        merge = self.spatial_merge_size

        def in_merge_order(ids: torch.Tensor, h: int, w: int) -> torch.Tensor:
            # Split the (h, w) grid into merge x merge windows and flatten so
            # that the entries of each window are adjacent.
            windows = ids.reshape(h // merge, merge, w // merge, merge)
            return windows.permute(0, 2, 1, 3).flatten()

        pos_ids = []
        for t, h, w in grid_thw:
            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            hpos_ids = in_merge_order(row_ids, h, w)
            wpos_ids = in_merge_order(col_ids, h, w)
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        return pos_ids

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Look up rotary frequencies for every patch position in ``grid_thw``.

        The frequency table is built for the largest height or width present,
        indexed by the (row, col) ids, and the two lookups are flattened into
        one feature axis.
        """
        pos_ids = torch.cat(self.get_pos_ids_by_grid(grid_thw), dim=0)
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        freq_table = self.rotary_pos_emb(max_grid_size)
        return freq_table[pos_ids].flatten(1)

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
        """Return the longest sequence length for flash-attention backends.

        The value is the result of ``Tensor.max()`` (a tensor, not a Python
        int). ``None`` is returned for every other backend.
        """
        flash_backends = (
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        )
        if self.attn_backend in flash_backends:
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return None

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
    ) -> torch.Tensor:
        """Encode pixel patches into merged vision features.

        Args:
            hidden_states: Flattened pixel patches; cast to the tower dtype.
            grid_thw: One ``[t, h, w]`` triple per image, as Python lists.
        """
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Convert grid_thw to tensor (always expecting list format now)
        grid = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)
        hidden_states = self.patch_embed(hidden_states.to(self.dtype), grid)

        # One sequence of h * w patches per frame, t frames per image.
        frame_lens = torch.repeat_interleave(grid[:, 1] * grid[:, 2], grid[:, 0])
        cu_seqlens = frame_lens.cumsum(
            dim=0,
            dtype=grid.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
            )

        if self.post_trunk_norm is not None:
            hidden_states = self.post_trunk_norm(hidden_states)

        return self.merger(hidden_states)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=DotsOCRProcessingInfo,
    dummy_inputs=DotsOCRDummyInputsBuilder,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """DotsOCR: a ``DotsVisionTransformer`` tower feeding a Qwen2 language model."""

    # Renames applied to checkpoint weight names when loading.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".attn.qkv_proj.": ".attn.qkv.",
            ".attn.out_proj.": ".attn.proj.",
        },
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        },
    )

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
        ".attn.qkv": [".attn.qkv"],
        "fc13": ["fc1", "fc3"],
    }
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the image placeholder string, or ``None`` for other modalities."""
        if not modality.startswith("image"):
            return None
        return "<|img|><|imgpad|><|endofimg|>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower and the Qwen2 language model from ``vllm_config``."""
        super().__init__()

        model_config = vllm_config.model_config
        self.config: DotsOCRConfig = model_config.hf_config
        self.quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # A dict-form vision config is converted once and stored back.
        vision_config = self.config.vision_config
        if isinstance(vision_config, dict):
            vision_config = DotsVisionConfig(**vision_config)
            self.config.vision_config = vision_config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = DotsVisionTransformer(
                vision_config,
                quant_config=self.quant_config,
                multimodal_config=multimodal_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=self.config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> DotsOCRImageInputs | None:
        """Wrap image kwargs in the matching schema; pixels win over embeddings.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return DotsOCRImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return DotsOCRImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

        return None

    def _process_image_input(
        self, image_input: DotsOCRImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Produce per-image embeddings from pixel or embedding inputs.

        Pixel inputs run through the vision tower. In data-parallel mode the
        result of ``run_dp_sharded_mrope_vision_model`` is returned directly.
        Otherwise the concatenated embeddings are split per image.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.vision_tower.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(self.vision_tower.dtype)

            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.vision_tower,
                    pixel_values,
                    grid_thw_list,
                    rope_type="rope_3d",
                )

            tower_out = self.vision_tower(pixel_values, grid_thw_list)
            image_embeds = tower_out[:, : self.config.hidden_size]

        # Split concatenated embeddings for each image item.
        merge_size = self.vision_tower.spatial_merge_size
        patches_per_image = torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
        sizes = (patches_per_image // (merge_size * merge_size)).tolist()

        return image_embeds.split(sizes)

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Scale a token count up by ``spatial_merge_size ** 2``."""
        merge_size = self.vision_tower.spatial_merge_size
        return num_image_tokens * (merge_size**2)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Scale a token count down by ``spatial_merge_size ** 2`` (floor)."""
        merge_size = self.vision_tower.spatial_merge_size
        return num_vision_tokens // (merge_size**2)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for the given kwargs, or ``[]`` if none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; extra ``kwargs`` are ignored.

        ``inputs_embeds`` is dropped whenever ``intermediate_tensors`` is
        provided.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader`` using ``hf_to_vllm_mapper``."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_tower.merger",
            tower_model="vision_tower.",
        )