dots_ocr.py 26.4 KB
Newer Older
Roger Wang's avatar
Roger Wang committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
4
from typing import Annotated, Literal, TypeAlias
Roger Wang's avatar
Roger Wang committed
5
6
7
8
9
10

import torch
import torch.nn as nn
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor

11
from vllm.config import VllmConfig
12
from vllm.config.multimodal import BaseDummyOptions
13
14
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
15
16
17
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
Roger Wang's avatar
Roger Wang committed
18
from vllm.model_executor.layers.activation import SiluAndMul
19
20
21
from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)
22
from vllm.model_executor.layers.conv import Conv2dLayer
Roger Wang's avatar
Roger Wang committed
23
from vllm.model_executor.layers.layernorm import RMSNorm
24
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Roger Wang's avatar
Roger Wang committed
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
32
33
from vllm.model_executor.layers.rotary_embedding.common import (
    ApplyRotaryEmb,
)
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
36
37
38
39
40
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
41
from vllm.model_executor.models.module_mapping import MultiModelKeys
Roger Wang's avatar
Roger Wang committed
42
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
43
from vllm.model_executor.models.qwen2_vl import (
44
    Qwen2VisionAttention,
45
46
47
48
49
50
51
52
53
54
    Qwen2VLDummyInputsBuilder,
    Qwen2VLMultiModalProcessor,
    Qwen2VLProcessingInfo,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
Roger Wang's avatar
Roger Wang committed
55
56
57
58
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.sequence import IntermediateTensors
59
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionConfig
60
from vllm.utils.tensor_schema import TensorSchema, TensorShape
61
from vllm.v1.attention.backends.registry import AttentionBackendEnum
Roger Wang's avatar
Roger Wang committed
62

63
from .vision import is_vit_use_data_parallel, run_dp_sharded_mrope_vision_model
64

Roger Wang's avatar
Roger Wang committed
65
66
67
# Image placeholder string for DotsOCR prompts. It is assigned as
# ``image_token`` on both the tokenizer and the HF processor, and repeated
# once per image when building dummy prompt text.
IMAGE_TOKEN = "<|imgpad|>"


class DotsOCRImagePixelInputs(TensorSchema):
    """Raw image patches that are encoded by the DotsOCR vision tower.

    Dimensions:
        - np: Total number of patches across every image of every prompt
              in the batch
        - ni: Number of images
        - cps: Per-patch feature size, documented as
               channels * patch_size * patch_size (NOTE(review):
               DotsPatchEmbed views each row as channels x
               temporal_patch_size x patch_size x patch_size)
    """

    # Discriminator that selects this variant of ``DotsOCRImageInputs``.
    type: Literal["pixel_values"]

    # One flattened row per patch.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    # One (t, h, w) patch-grid row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class DotsOCRImageEmbeddingInputs(TensorSchema):
    """Precomputed image embeddings; the vision tower is not run on these.

    Dimensions:
        - nf: Number of image features (rows of ``image_embeds``)
        - hs: Hidden size
        - ni: Number of images
    """

    # Discriminator that selects this variant of ``DotsOCRImageInputs``.
    type: Literal["image_embeds"]

    # Embeddings for all images, concatenated along the first axis.
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    # One (t, h, w) patch-grid row per image; used to split the embeddings.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


# Either raw pixel patches or precomputed embeddings; the ``type`` field
# tells the two variants apart.
DotsOCRImageInputs: TypeAlias = (
    DotsOCRImagePixelInputs | DotsOCRImageEmbeddingInputs
)


class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
    """Dummy prompt text and images for DotsOCR, keyed on ``IMAGE_TOKEN``."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        image_count = mm_counts.get("image", 0)
        return IMAGE_TOKEN * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Build the requested number of dummy images.

        Images use the size that ``get_image_size_with_most_features``
        reports; per-modality overrides come from ``mm_options["image"]``.
        """
        image_count = mm_counts.get("image", 0)

        width, height = self.info.get_image_size_with_most_features()

        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
    """Qwen2-VL processing info adapted to DotsOCR's config and tokens."""

    def get_hf_config(self) -> DotsOCRConfig:
        """Return the HF config with ``vision_config`` as a config object.

        Raises:
            TypeError: if the config class is not named ``DotsOCRConfig``.
        """
        hf_config = self.ctx.get_hf_config()
        # NOTE(review): the check is by class *name*, presumably so that a
        # remote-code copy of DotsOCRConfig is also accepted — confirm.
        if hf_config.__class__.__name__ != "DotsOCRConfig":
            raise TypeError(f"Expected DotsOCRConfig, got {type(hf_config)}")

        raw_vision = getattr(hf_config, "vision_config", None)
        if isinstance(raw_vision, dict):
            # Materialize a dict-valued vision config in place.
            hf_config.vision_config = DotsVisionConfig(**raw_vision)

        return hf_config

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only images are supported; ``None`` means no per-prompt cap."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Upper bound on tokens a single image item can produce."""
        return {"image": self.get_max_image_tokens()}

    def get_hf_processor(
        self,
        **kwargs: object,
    ) -> Qwen2VLProcessor:
        """Return a Qwen2-VL processor configured with DotsOCR tokens."""
        # Ensure the tokenizer exposes the DotsOCR image placeholder.
        tokenizer = self.get_tokenizer()
        tokenizer.image_token = IMAGE_TOKEN

        processor = self.ctx.get_hf_processor(Qwen2VLProcessor, **kwargs)
        processor.image_token = IMAGE_TOKEN
        processor.video_token = "<|video_pad|>"
        return processor


class VisionRotaryEmbedding(nn.Module):
    """Rotary angle table ``position * inv_freq`` for vision patches.

    Args:
        dim: rotary dimension; ``ceil(dim / 2)`` frequencies are kept.
        theta: base of the geometric frequency progression.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        inverse_frequencies = 1.0 / (theta**exponents)
        # Non-persistent: rebuilt from dim/theta, excluded from state_dict.
        self.register_buffer("inv_freq", inverse_frequencies, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a ``(seqlen, len(inv_freq))`` tensor of rotation angles."""
        positions = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        return torch.outer(positions, self.inv_freq)


class PatchMerger(nn.Module):
    """Merge each window of ``spatial_merge_size**2`` patches into one token.

    Input rows of size ``context_dim`` are (optionally) normalized, grouped
    into ``context_dim * spatial_merge_size**2`` wide rows, and projected
    to ``dim`` by a Linear -> GELU -> Linear MLP.

    Args:
        dim: output feature size.
        context_dim: per-patch input feature size.
        spatial_merge_size: side length of the merge window.
        pre_norm: ``"layernorm"``, ``"rmsnorm"``, or a falsy value for no
            normalization before the MLP.
        prefix: module path of this merger, used for the linear layers.

    Raises:
        ValueError: if ``pre_norm`` is truthy but not a supported kind.
    """

    _SUPPORTED_PRE_NORMS = ("layernorm", "rmsnorm")

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        pre_norm: str | None = "layernorm",
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Fail fast: an unknown truthy value would leave ``ln_q`` undefined
        # and only crash later inside forward().
        if pre_norm and pre_norm not in self._SUPPORTED_PRE_NORMS:
            raise ValueError(
                f"Unsupported pre_norm {pre_norm!r}; expected one of "
                f"{self._SUPPORTED_PRE_NORMS} or a falsy value"
            )
        use_data_parallel = is_vit_use_data_parallel()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.pre_norm = pre_norm
        if self.pre_norm == "layernorm":
            self.ln_q = LayerNorm(context_dim, eps=1e-6)
        elif self.pre_norm == "rmsnorm":
            self.ln_q = RMSNorm(context_dim, eps=1e-6)

        # Linear prefixes match the real module paths ("mlp.0" / "mlp.2"
        # inside the Sequential), as in Qwen2-VL's PatchMerger.
        self.mlp = nn.Sequential(
            ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.mlp.0",
                disable_tp=use_data_parallel,
            ),
            nn.GELU(),
            RowParallelLinear(
                self.hidden_size,
                dim,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.mlp.2",
                disable_tp=use_data_parallel,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize (if configured), group windows, and project to ``dim``."""
        if self.pre_norm:
            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        else:
            x = self.mlp(x.view(-1, self.hidden_size))
        return x


class DotsVisionAttention(nn.Module):
    """Multi-head self-attention over packed vision patches with rotary.

    The qkv/proj layers and the ``tp_size``/``tp_rank``/per-head attributes
    follow the Qwen2-VL layout, so ``Qwen2VisionAttention.split_qkv`` can be
    called on this module directly.
    """

    def __init__(
        self,
        config,
        dim: int,
        num_heads: int = 16,
        bias: bool = True,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        data_parallel = is_vit_use_data_parallel()

        self.embed_dim = dim
        if data_parallel:
            # Layers below are built with disable_tp, so act as one partition.
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )
        head_dim = self.hidden_size_per_attention_head

        # qkv/proj follow Qwen2-VL style; ``bias`` toggles biases on both.
        self.qkv = QKVParallelLinear(
            hidden_size=dim,
            head_size=head_dim,
            total_num_heads=num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=data_parallel,
        )
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=head_dim,
            scale=head_dim**-0.5,
            prefix=f"{prefix}.attn",
        )
        self.apply_rotary_emb = ApplyRotaryEmb(
            enforce_enable=True,
            enable_fp32_compute=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor | None = None,
        *,
        max_seqlen: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend within each packed sequence delimited by ``cu_seqlens``.

        ``hidden_states`` is ``[S, C]``; the result has the same shape.
        """
        # Add a singleton batch axis: [S, C] -> [S, B=1, C].
        packed_qkv, _ = self.qkv(hidden_states.unsqueeze(1))
        q, k, v = Qwen2VisionAttention.split_qkv(self, packed_qkv)
        batch = q.shape[1]

        # [S, B, H, D] -> [B, S, H, D]
        q, k, v = (t.permute(1, 0, 2, 3).contiguous() for t in (q, k, v))

        if rotary_pos_emb is not None:
            # Rotate q and k in one call by stacking them on the batch axis.
            rotated = self.apply_rotary_emb(
                torch.cat([q, k], dim=0),
                rotary_pos_emb.cos(),
                rotary_pos_emb.sin(),
            )
            q, k = torch.chunk(rotated, 2, dim=0)

        attended = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        # [B, S, H, D] -> [S, B, H, D] -> [S, B, H*D]
        attended = attended.permute(1, 0, 2, 3).contiguous()
        attended = attended.view(attended.shape[0], batch, -1)
        projected, _ = self.proj(attended)
        # Drop the singleton batch axis again: [S, 1, C] -> [S, C].
        return projected.squeeze(1)


class DotsSwiGLUFFN(nn.Module):
    """Gated MLP: fused fc1/fc3 projection, ``SiluAndMul``, then fc2.

    Modeled on ``AIMv2SwiGLUFFN`` (aimv2.py).
    """

    def __init__(
        self,
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        ffn_dim = config.intermediate_size
        model_dim = config.embed_dim
        use_bias = config.use_bias

        data_parallel = is_vit_use_data_parallel()
        # fc1 and fc3 are packed into one layer as shards 0 and 1.
        self.fc13 = MergedColumnParallelLinear(
            model_dim,
            [ffn_dim, ffn_dim],
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc13",
            disable_tp=data_parallel,
        )
        self.fc2 = RowParallelLinear(
            ffn_dim,
            model_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc13 -> SiluAndMul -> fc2."""
        gate_and_up, _ = self.fc13(x)
        activated = self.act_fn(gate_and_up)
        out, _ = self.fc2(activated)
        return out

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, routing fc1/fc3 into shards of fc13.

        Returns:
            The (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            ("fc13", "fc1", 0),
            ("fc13", "fc3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models. This `continue`
                # only advances the inner loop; the for-else below re-checks
                # the renamed name and skips the weight.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                fused_param = params_dict[name]
                fused_param.weight_loader(fused_param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class DotsPatchEmbed(nn.Module):
    """Embed flattened image patches with a conv projection plus RMSNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_channels = config.num_channels
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.embed_dim = config.embed_dim

        # Kernel == stride == patch size: exactly one output per patch.
        patch_window = (config.patch_size, config.patch_size)
        self.proj = Conv2dLayer(
            config.num_channels,
            config.embed_dim,
            kernel_size=patch_window,
            stride=patch_window,
        )
        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        """Return one ``embed_dim`` row per patch; ``grid_thw`` is unused."""
        patches = x.view(
            -1,
            self.num_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        # Keep only the first temporal slice of every patch.
        first_slice = patches[:, :, 0]
        embedded = self.proj(first_slice).view(-1, self.embed_dim)
        return self.norm(embedded)


class DotsViTPreprocessor(nn.Module):
    """Thin wrapper that delegates patch embedding to ``DotsPatchEmbed``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.embed_dim
        self.patch_h = self.patch_w = config.patch_size
        self.patchifier = DotsPatchEmbed(config)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        """Return the patch embeddings produced by ``self.patchifier``."""
        return self.patchifier(x, grid_thw)


class DotsVisionBlock(nn.Module):
    """Pre-norm ViT block: RMSNorm+attention and RMSNorm+MLP, each residual."""

    def __init__(
        self,
        config,
        *,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.attn = DotsVisionAttention(
            config,
            config.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        self.mlp = DotsSwiGLUFFN(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: int | None = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers with residual adds."""
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out


class DotsVisionTransformer(nn.Module):
    """DotsOCR vision tower: patch embedding, rotary ViT blocks, patch merger.

    ``forward`` takes the flattened patches of all images together with
    their ``(t, h, w)`` grids and returns merged tokens whose width is
    ``config.hidden_size`` (the merger's output ``dim``).
    """

    def __init__(
        self,
        config: DotsVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = DotsViTPreprocessor(config)

        head_dim = config.embed_dim // config.num_attention_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
        )
        self.out_hidden_size = config.hidden_size

        # Keep blocks for compatibility with other vision towers
        num_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )
        self.blocks = nn.ModuleList(
            [
                DotsVisionBlock(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{i}",
                )
                for i in range(num_layers)
            ]
        )

        # By default the final norm is only built for a full-depth tower.
        if require_post_norm is None:
            require_post_norm = len(self.blocks) == config.num_hidden_layers
        if require_post_norm and self.config.post_norm:
            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        else:
            self.post_trunk_norm = None

        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
            spatial_merge_size=config.spatial_merge_size,
            # Give the merger its real module path, like the blocks above;
            # previously it defaulted to "" and its layers were unqualified.
            prefix=f"{prefix}.merger",
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding conv weight."""
        return self.patch_embed.patchifier.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding conv weight."""
        return self.patch_embed.patchifier.proj.weight.device

    def _merge_window_order(self, ids: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Flatten an ``(h, w)`` id grid so each merge window is contiguous."""
        m = self.spatial_merge_size
        return ids.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()

    def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
        """Return per-image ``(row, col)`` position ids, one row per patch.

        Ids are laid out in merge-window order and repeated ``t`` times.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = self._merge_window_order(
                torch.arange(h).unsqueeze(1).expand(-1, w), h, w
            )
            wpos_ids = self._merge_window_order(
                torch.arange(w).unsqueeze(0).expand(h, -1), h, w
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        return pos_ids

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Return rotary angles for every patch of every image, concatenated."""
        pos_ids = torch.cat(self.get_pos_ids_by_grid(grid_thw), dim=0)
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        # Look up row and column angles and concatenate them per patch.
        return rotary_pos_emb_full[pos_ids].flatten(1)

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> torch.Tensor | None:
        """Longest packed-sequence length, for flash-attention backends.

        Computed only when the backend is FLASH_ATTN or ROCM_AITER_FA;
        otherwise ``None``. The value is a 0-d tensor (``Tensor.max()``),
        not a Python int, so the annotation says so.
        """
        if self.attn_backend in (
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        ):
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return None

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
    ) -> torch.Tensor:
        """Encode packed image patches and merge them into output tokens."""
        # Rotary angles are built from the Python-list grid.
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Convert grid_thw to tensor (always expecting list format now)
        grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)

        hidden_states = hidden_states.to(self.dtype)
        hidden_states = self.patch_embed(hidden_states, grid_thw)

        # Cumulative h*w token counts (each repeated t times) with a leading
        # zero: the boundaries of every packed sequence.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
            )

        if self.post_trunk_norm is not None:
            hidden_states = self.post_trunk_norm(hidden_states)

        return self.merger(hidden_states)


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=DotsOCRProcessingInfo,
    dummy_inputs=DotsOCRDummyInputsBuilder,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """DotsOCR: a ``DotsVisionTransformer`` image encoder paired with a
    ``Qwen2ForCausalLM`` language model."""

    # Checkpoint-key rewrites applied by ``load_weights``.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".attn.qkv_proj.": ".attn.qkv.",
            ".attn.out_proj.": ".attn.proj.",
        },
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        },
    )

    # Fused module name -> checkpoint module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        ".attn.qkv": [".attn.qkv"],
        "fc13": ["fc1", "fc3"],
    }
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Prompt placeholder for one image; ``None`` for other modalities."""
        if not modality.startswith("image"):
            return None
        return "<|img|><|imgpad|><|endofimg|>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        self.config: DotsOCRConfig = model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.use_data_parallel = (
            model_config.multimodal_config.mm_encoder_tp_mode == "data"
        )

        vision_config = self.config.vision_config
        if isinstance(vision_config, dict):
            # Materialize a dict-valued vision config and store it back.
            vision_config = DotsVisionConfig(**vision_config)
            self.config.vision_config = vision_config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = DotsVisionTransformer(
                vision_config,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=self.config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> DotsOCRImageInputs | None:
        """Pop image kwargs and wrap them in the matching input schema.

        Pixel values take precedence over embeddings; ``None`` when neither
        is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return DotsOCRImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
        if image_embeds is not None:
            return DotsOCRImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )
        return None

    def _process_image_input(
        self, image_input: DotsOCRImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Return one embedding tensor per image."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        tower = self.vision_tower

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(tower.dtype)
        else:
            pixel_values = image_input["pixel_values"].type(tower.dtype)
            if self.use_data_parallel:
                # Data-parallel ViT: the sharded helper's result is returned
                # as-is.
                return run_dp_sharded_mrope_vision_model(
                    tower,
                    pixel_values,
                    grid_thw_list,
                    rope_type="rope_3d",
                )
            image_embeds = tower(pixel_values, grid_thw_list)[
                :, : self.config.hidden_size
            ]

        # Each image yields t*h*w patches, merged in windows of merge_size**2.
        merge_area = tower.spatial_merge_size**2
        tokens_per_image = (
            torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // merge_area
        ).tolist()
        return image_embeds.split(tokens_per_image)

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        """Encoder (pre-merge) patch count for ``num_image_tokens`` tokens."""
        return num_image_tokens * (self.vision_tower.spatial_merge_size**2)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        """Merged token count for ``num_vision_tokens`` encoder patches."""
        return num_vision_tokens // (self.vision_tower.spatial_merge_size**2)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` when no image is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; extra ``kwargs`` are ignored."""
        # When intermediate tensors are supplied, inputs_embeds is dropped.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights through ``AutoWeightsLoader`` with key remapping."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_tower.merger",
            tower_model="vision_tower.",
        )