dots_ocr.py 26.5 KB
Newer Older
Roger Wang's avatar
Roger Wang committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
4
from typing import Annotated, Literal, TypeAlias
Roger Wang's avatar
Roger Wang committed
5
6
7
8
9
10

import torch
import torch.nn as nn
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor

11
from vllm.config import VllmConfig
12
from vllm.config.multimodal import BaseDummyOptions
13
14
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
15
16
17
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
Roger Wang's avatar
Roger Wang committed
18
from vllm.model_executor.layers.activation import SiluAndMul
19
from vllm.model_executor.layers.attention import (
20
21
    MMEncoderAttention,
)
22
from vllm.model_executor.layers.conv import Conv2dLayer
Roger Wang's avatar
Roger Wang committed
23
from vllm.model_executor.layers.layernorm import RMSNorm
24
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Roger Wang's avatar
Roger Wang committed
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
32
33
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
36
37
38
39
40
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
41
from vllm.model_executor.models.module_mapping import MultiModelKeys
Roger Wang's avatar
Roger Wang committed
42
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
43
from vllm.model_executor.models.qwen2_vl import (
44
    Qwen2VisionAttention,
45
46
47
48
49
50
51
52
53
54
    Qwen2VLDummyInputsBuilder,
    Qwen2VLMultiModalProcessor,
    Qwen2VLProcessingInfo,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
Roger Wang's avatar
Roger Wang committed
55
56
57
58
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.sequence import IntermediateTensors
59
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionConfig
60
from vllm.utils.tensor_schema import TensorSchema, TensorShape
61
from vllm.v1.attention.backends.registry import AttentionBackendEnum
Roger Wang's avatar
Roger Wang committed
62

63
from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
64

# Placeholder string standing in for one image in a prompt. It is repeated once
# per image in dummy text and installed as ``image_token`` on the HF tokenizer
# and processor by DotsOCRProcessingInfo.get_hf_processor.
IMAGE_TOKEN: str = "<|imgpad|>"


class DotsOCRImagePixelInputs(TensorSchema):
    """
    Raw image pixels to be encoded by the vision tower.

    Dimensions:
        - np: Total patch count, summed over every image of every prompt in
              the batch
        - ni: Number of images
        - cps: Flattened length of one patch row (DotsPatchEmbed views each
               row as num_channels x temporal_patch_size x patch_size x
               patch_size)
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class DotsOCRImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings; these skip the vision tower and are only
    cast and split per image.

    Dimensions:
        - nf: Number of image feature rows across all images
        - hs: Hidden size of each feature row
        - ni: Number of images
    """

    type: Literal["image_embeds"]

    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


# Either raw pixels or precomputed embeddings; the ``type`` field tells the two
# variants apart at runtime.
DotsOCRImageInputs: TypeAlias = (
    DotsOCRImagePixelInputs | DotsOCRImageEmbeddingInputs
)


class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
    """Builds placeholder prompt text and images for DotsOCR dummy inputs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        return IMAGE_TOKEN * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized for the largest image-feature count.

        ``max_pixels`` from ``mm_processor_kwargs`` (when present) is forwarded
        to ``get_image_size_with_most_features``. Per-image overrides come
        from ``mm_options["image"]``. ``seq_len`` is not used.
        """
        num_images = mm_counts.get("image", 0)

        processor_kwargs = mm_processor_kwargs or {}
        max_pixels = processor_kwargs.get("max_pixels", None)
        width, height = self.info.get_image_size_with_most_features(max_pixels)

        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=num_images,
            overrides=overrides,
        )
        return {"image": dummy_images}


class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
    """Qwen2-VL processing info adapted to the DotsOCR config and image token."""

    def get_hf_config(self) -> DotsOCRConfig:
        """Return the HF config, converting a dict ``vision_config`` in place.

        Raises:
            TypeError: if the config class is not named ``DotsOCRConfig``.
        """
        config = self.ctx.get_hf_config()
        # The check is by class name rather than isinstance; presumably so a
        # remote-code DotsOCRConfig class is also accepted. TODO confirm.
        if config.__class__.__name__ != "DotsOCRConfig":
            raise TypeError(f"Expected DotsOCRConfig, got {type(config)}")

        raw_vision_config = getattr(config, "vision_config", None)
        if isinstance(raw_vision_config, dict):
            config.vision_config = DotsVisionConfig(**raw_vision_config)
        return config

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Images are supported with no fixed per-prompt limit."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Report the maximum number of tokens a single image can occupy."""
        return {"image": self.get_max_image_tokens()}

    def get_hf_processor(
        self,
        **kwargs: object,
    ) -> Qwen2VLProcessor:
        """Return a Qwen2-VL processor configured with DotsOCR's special tokens."""
        tokenizer = self.get_tokenizer()
        tokenizer.image_token = IMAGE_TOKEN  # Ensure image token is set

        processor = self.ctx.get_hf_processor(Qwen2VLProcessor, **kwargs)
        processor.image_token = IMAGE_TOKEN
        processor.video_token = "<|video_pad|>"
        return processor


class VisionRotaryEmbedding(nn.Module):
    """Rotary-embedding angle table for integer positions.

    ``forward(seqlen)`` returns ``outer(arange(seqlen), inv_freq)``, a
    ``(seqlen, ceil(dim / 2))`` tensor. ``inv_freq`` is a non-persistent
    buffer, so it is not saved in ``state_dict``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq", 1.0 / (theta**exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        inv_freq = self.inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)


class PatchMerger(nn.Module):
    """Merge groups of vision tokens and project them to ``dim`` features.

    ``forward`` optionally normalizes each ``context_dim``-wide token, then
    views the result as rows of ``context_dim * spatial_merge_size**2``
    features, so ``spatial_merge_size**2`` consecutive tokens form one row.
    It then applies Linear -> GELU -> Linear to produce ``dim`` features per
    row.

    Args:
        dim: Output feature size per merged row.
        context_dim: Feature size of each incoming token.
        spatial_merge_size: ``spatial_merge_size**2`` tokens are merged per row.
        pre_norm: ``"layernorm"``, ``"rmsnorm"``, or a falsy value (e.g.
            ``None``) for no normalization.
        prefix: Weight-name prefix forwarded to the linear layers.

    Raises:
        ValueError: if ``pre_norm`` is truthy but names no supported norm.
            Previously such a value was accepted here and then crashed in
            ``forward`` with an AttributeError on the missing ``ln_q``.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        pre_norm="layernorm",
        prefix: str = "",
    ) -> None:
        super().__init__()
        use_data_parallel = is_vit_use_data_parallel()

        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.pre_norm = pre_norm
        if pre_norm == "layernorm":
            self.ln_q = LayerNorm(context_dim, eps=1e-6)
        elif pre_norm == "rmsnorm":
            self.ln_q = RMSNorm(context_dim, eps=1e-6)
        elif not pre_norm:
            # Explicitly no normalization; forward() checks ln_q for None.
            self.ln_q = None
        else:
            raise ValueError(
                f"Unsupported pre_norm {pre_norm!r}; expected 'layernorm', "
                "'rmsnorm' or a falsy value for no normalization"
            )

        # Tensor-parallel sharding is disabled when the vision encoder runs
        # in data-parallel mode (every rank then holds the full weights).
        self.mlp = nn.Sequential(
            ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.0",
                disable_tp=use_data_parallel,
            ),
            nn.GELU(),
            RowParallelLinear(
                self.hidden_size,
                dim,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.2",
                disable_tp=use_data_parallel,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.ln_q is not None:
            x = self.ln_q(x)
        return self.mlp(x.view(-1, self.hidden_size))


class DotsVisionAttention(nn.Module):
    """Multi-head self-attention used in each DotsOCR vision block.

    A fused QKV projection and an output projection are sharded across the
    tensor-parallel group. When the vision encoder runs data-parallel
    (``is_vit_use_data_parallel()``), sharding is disabled and tp size/rank
    are treated as 1/0. ``config`` is accepted but not read.
    """

    def __init__(
        self,
        config,
        dim: int,
        num_heads: int = 16,
        bias: bool = True,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        use_data_parallel = is_vit_use_data_parallel()

        self.embed_dim = dim
        if use_data_parallel:
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        head_dim = dist_utils.divide(dim, num_heads)
        self.hidden_size_per_attention_head = head_dim
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )

        # qkv/proj follow Qwen2-VL style; bias controlled by arg
        self.qkv = QKVParallelLinear(
            hidden_size=dim,
            head_size=head_dim,
            total_num_heads=num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=head_dim,
            scale=head_dim**-0.5,
            prefix=f"{prefix}.attn",
        )
        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        *,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return a tensor of the same shape.

        ``cu_seqlens`` and ``max_seqlen`` are passed straight to the encoder
        attention. When ``rotary_pos_emb`` is given, q and k are rotated by
        its cos/sin before attention.
        """
        # Add a batch axis for the projection: [S, C] -> [S, 1, C].
        fused_qkv, _ = self.qkv(hidden_states.unsqueeze(1))
        q, k, v = Qwen2VisionAttention.split_qkv(self, fused_qkv)
        batch = q.shape[1]

        # [S, B, H, D] -> [B, S, H, D]
        q, k, v = (t.permute(1, 0, 2, 3).contiguous() for t in (q, k, v))

        if rotary_pos_emb is not None:
            # Rotate q and k together in a single call, then split them back.
            rotated = self.apply_rotary_emb(
                torch.cat([q, k], dim=0),
                rotary_pos_emb.cos(),
                rotary_pos_emb.sin(),
            )
            q, k = torch.chunk(rotated, 2, dim=0)

        attended = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        # [B, S, H, D] -> [S, B, H*D] -> project -> [S, C]
        attended = attended.permute(1, 0, 2, 3).contiguous()
        attended = attended.view(attended.shape[0], batch, -1)
        projected, _ = self.proj(attended)
        return projected.squeeze(1)


class DotsSwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network for a DotsOCR vision block.

    Checkpoint projections ``fc1`` and ``fc3`` are fused into one ``fc13``
    layer (shards 0 and 1). ``SiluAndMul`` combines the two halves of its
    output, and ``fc2`` projects back to ``config.embed_dim``.
    Adapted from aimv2.py's AIMv2SwiGLUFFN.
    """

    def __init__(
        self,
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_features = config.intermediate_size
        in_features = config.embed_dim
        bias = config.use_bias

        # TP sharding is disabled when the vision encoder runs data-parallel.
        use_data_parallel = is_vit_use_data_parallel()
        self.fc13 = MergedColumnParallelLinear(
            in_features,
            [hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc13",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fused, _ = self.fc13(x)
        out, _ = self.fc2(self.act_fn(fused))
        return out

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, routing ``fc1``/``fc3`` into fused ``fc13``.

        Returns the set of parameter names that were loaded.
        """
        # (fused parameter name, checkpoint sub-name, shard index in fc13)
        shard_specs = (
            ("fc13", "fc1", 0),
            ("fc13", "fc3", 1),
        )
        params = dict(self.named_parameters())
        loaded: set[str] = set()

        for ckpt_name, tensor in weights:
            for fused_name, shard_name, shard_idx in shard_specs:
                if shard_name not in ckpt_name:
                    continue
                ckpt_name = ckpt_name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if ckpt_name.endswith(".bias") and ckpt_name not in params:
                    continue
                target = params[ckpt_name]
                target.weight_loader(target, tensor, shard_idx)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if ckpt_name.endswith(".bias") and ckpt_name not in params:
                    continue
                target = params[ckpt_name]
                loader = getattr(target, "weight_loader", default_weight_loader)
                loader(target, tensor)
            loaded.add(ckpt_name)

        return loaded


class DotsPatchEmbed(nn.Module):
    """Embed flattened image patches with a strided conv followed by RMSNorm.

    Each input row is viewed as (num_channels, temporal_patch_size,
    patch_size, patch_size), and only temporal slice 0 is projected.
    """

    def __init__(self, config):
        super().__init__()
        self.num_channels = config.num_channels
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.embed_dim = config.embed_dim
        self.config = config

        patch_hw = (config.patch_size, config.patch_size)
        self.proj = Conv2dLayer(
            config.num_channels,
            config.embed_dim,
            kernel_size=patch_hw,
            stride=patch_hw,
        )
        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        # grid_thw is accepted for interface compatibility but not used here.
        patches = x.view(
            -1,
            self.num_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        first_frame = patches[:, :, 0]
        embedded = self.proj(first_frame).view(-1, self.embed_dim)
        return self.norm(embedded)


class DotsViTPreprocessor(nn.Module):
    """Wrapper that turns raw patch rows into token embeddings via DotsPatchEmbed."""

    def __init__(self, config):
        super().__init__()
        self.patch_h = self.patch_w = config.patch_size
        self.embed_dim = config.embed_dim
        self.config = config
        self.patchifier = DotsPatchEmbed(config)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        return self.patchifier(x, grid_thw)


class DotsVisionBlock(nn.Module):
    """Pre-norm vision transformer layer.

    Computes ``x = x + attn(norm1(x))``, then ``x = x + mlp(norm2(x))``,
    where both norms are RMSNorm over ``config.embed_dim``.
    """

    def __init__(
        self,
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.attn = DotsVisionAttention(
            config,
            config.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        self.mlp = DotsSwiGLUFFN(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        # The vision transformer passes the 0-dim tensor produced by
        # ``(cu_seqlens[1:] - cu_seqlens[:-1]).max()`` (or None). The
        # annotation matches DotsVisionAttention.forward instead of claiming
        # an int.
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out


class DotsVisionTransformer(nn.Module):
    """DotsOCR vision encoder.

    Pipeline: patch embedding -> ``DotsVisionBlock`` stack with 2-D rotary
    position embeddings -> optional RMSNorm post-norm -> ``PatchMerger``
    projecting to ``config.hidden_size``.
    """

    def __init__(
        self,
        config: DotsVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = DotsViTPreprocessor(config)

        head_dim = config.embed_dim // config.num_attention_heads
        # One frequency table is indexed by row and one by column position;
        # rot_pos_emb concatenates them per token.
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
        )
        self.out_hidden_size = config.hidden_size

        # Keep blocks for compatibility with other vision towers
        num_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )
        self.blocks = nn.ModuleList(
            [
                DotsVisionBlock(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{i}",
                )
                for i in range(num_layers)
            ]
        )

        # By default the post-norm is only built when the full layer stack is
        # present (i.e. no truncating override).
        if require_post_norm is None:
            require_post_norm = len(self.blocks) == config.num_hidden_layers
        if require_post_norm and self.config.post_norm:
            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        else:
            self.post_trunk_norm = None

        # Pass a prefix like every other submodule, so the merger's linear
        # layers are named "<prefix>.merger.0"/".2" rather than ".0"/".2".
        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
            spatial_merge_size=config.spatial_merge_size,
            prefix=f"{prefix}.merger",
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.patchifier.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.patchifier.proj.weight.device

    def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
        """Return per-image ``(t*h*w, 2)`` tensors of (row, col) positions.

        Positions are reordered so that each
        ``spatial_merge_size x spatial_merge_size`` window is contiguous,
        presumably to match the grouping PatchMerger applies. The grid is
        then repeated for each of the ``t`` frames.
        """
        m = self.spatial_merge_size
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(h // m, m, w // m, m)
            hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(h // m, m, w // m, m)
            wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()

            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        return pos_ids

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Look up row/column rotary angles for every token and concatenate them."""
        pos_ids = torch.cat(self.get_pos_ids_by_grid(grid_thw), dim=0)
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        return rotary_pos_emb_full[pos_ids].flatten(1)

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
        """Return the longest sequence length for varlen backends, else None.

        The result is the 0-dim tensor produced by ``.max()``; it is not an int.
        """
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TRITON_ATTN,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
    ) -> torch.Tensor:
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Convert grid_thw to tensor (always expecting list format now)
        grid = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)

        hidden_states = hidden_states.to(self.dtype)
        hidden_states = self.patch_embed(hidden_states, grid)

        # Cumulative token offsets: each of an image's t frames contributes
        # h*w tokens; a leading 0 marks the start of the first sequence.
        cu_seqlens = torch.repeat_interleave(grid[:, 1] * grid[:, 2], grid[:, 0]).cumsum(
            dim=0,
            dtype=grid.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
            )

        if self.post_trunk_norm is not None:
            hidden_states = self.post_trunk_norm(hidden_states)

        return self.merger(hidden_states)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=DotsOCRProcessingInfo,
    dummy_inputs=DotsOCRDummyInputsBuilder,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """DotsOCR model: a ``DotsVisionTransformer`` image encoder plus a Qwen2 LM.

    Image embeddings come from ``vision_tower``, which can be sharded across
    ranks when ``mm_encoder_tp_mode == "data"``. Text generation is delegated
    to ``language_model``.
    """

    # Checkpoint-name rewrites: attention projections named qkv_proj/out_proj
    # map to this file's qkv/proj, and LM weights move under language_model.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".attn.qkv_proj.": ".attn.qkv.",
            ".attn.out_proj.": ".attn.proj.",
        },
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        },
    )

    # Fused module name -> the checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        ".attn.qkv": [".attn.qkv"],
        "fc13": ["fc1", "fc3"],
    }
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for image modalities, else None."""
        if not modality.startswith("image"):
            return None
        return "<|img|><|imgpad|><|endofimg|>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        self.config: DotsOCRConfig = model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.use_data_parallel = (
            model_config.multimodal_config.mm_encoder_tp_mode == "data"
        )

        # Normalize a dict vision_config into DotsVisionConfig (also on config).
        vision_config = self.config.vision_config
        if isinstance(vision_config, dict):
            vision_config = DotsVisionConfig(**vision_config)
            self.config.vision_config = vision_config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = DotsVisionTransformer(
                vision_config,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=self.config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> DotsOCRImageInputs | None:
        """Build the typed image input from kwargs; pixels win over embeddings."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return DotsOCRImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
        if image_embeds is not None:
            return DotsOCRImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )
        return None

    def _process_image_input(
        self, image_input: DotsOCRImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image.

        Pixel inputs are run through the vision tower. In data-parallel mode
        the sharded runner's result is returned directly.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        tower_dtype = self.vision_tower.dtype

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(tower_dtype)
        else:
            pixel_values = image_input["pixel_values"].type(tower_dtype)
            if self.use_data_parallel:
                return run_dp_sharded_mrope_vision_model(
                    self.vision_tower,
                    pixel_values,
                    grid_thw_list,
                    rope_type="rope_3d",
                )
            image_embeds = self.vision_tower(pixel_values, grid_thw_list)[
                :, : self.config.hidden_size
            ]

        # Each image yields t*h*w / merge_size**2 embeddings after merging.
        merge_area = self.vision_tower.spatial_merge_size**2
        tokens_per_image = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // merge_area
        )
        return image_embeds.split(tokens_per_image.tolist())

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Encoder tokens = merged image tokens * spatial_merge_size**2."""
        return num_image_tokens * (self.vision_tower.spatial_merge_size**2)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Connector tokens = vision tokens // spatial_merge_size**2."""
        return num_vision_tokens // (self.vision_tower.spatial_merge_size**2)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        # When intermediate_tensors is supplied (presumably a later pipeline
        # stage), inputs_embeds is ignored.
        embeds = None if intermediate_tensors is not None else inputs_embeds
        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights after applying ``hf_to_vllm_mapper``."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_tower.merger",
            tower_model="vision_tower.",
        )