dots_ocr.py 29.6 KB
Newer Older
Roger Wang's avatar
Roger Wang committed
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
4
from typing import Annotated, Literal, Optional, Union
Roger Wang's avatar
Roger Wang committed
5
6
7
8
9
10
11

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers.models.qwen2_vl import Qwen2VLProcessor

12
from vllm.attention.backends.registry import _Backend
13
14
15
16
from vllm.attention.layer import (
    check_upstream_fa_availability,
    maybe_get_vit_flash_attn_backend,
)
Roger Wang's avatar
Roger Wang committed
17
from vllm.config import VllmConfig
18
from vllm.config.multimodal import BaseDummyOptions
19
20
from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import (
21
22
23
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
Roger Wang's avatar
Roger Wang committed
24
25
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
26
27
28
29
30
31
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
Roger Wang's avatar
Roger Wang committed
32
from vllm.model_executor.layers.quantization import QuantizationConfig
33
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
34
35
36
37
38
39
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
40
from vllm.model_executor.models.module_mapping import MultiModelKeys
Roger Wang's avatar
Roger Wang committed
41
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
42
from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VisionAttention
43
44
45
46
47
48
49
50
51
52
53
from vllm.model_executor.models.qwen2_vl import (
    Qwen2VLDummyInputsBuilder,
    Qwen2VLMultiModalProcessor,
    Qwen2VLProcessingInfo,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
Roger Wang's avatar
Roger Wang committed
54
55
56
57
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.sequence import IntermediateTensors
58
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionConfig
59
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Roger Wang's avatar
Roger Wang committed
60

61
62
from .vision import run_dp_sharded_mrope_vision_model

Roger Wang's avatar
Roger Wang committed
63
64
65
# Image placeholder token. It is installed as the tokenizer's and processor's
# ``image_token`` and repeated once per image when building dummy text.
IMAGE_TOKEN = "<|imgpad|>"


66
67
68
69
70
71
72
73
class DotsOCRImagePixelInputs(TensorSchema):
    """
    Image input given as flattened pixel patches.

    Dimensions:
        - np: Total number of patches, summed over every image of every
              prompt in the batch
        - ni: Number of images
        - cps: Flattened size of one patch. Documented upstream as
               channels * patch_size * patch_size; NOTE(review): DotsPatchEmbed
               views each row as (channels, temporal_patch_size, patch_size,
               patch_size), so this likely includes the temporal factor.
    """

    type: Literal["pixel_values"]

    # One row per patch.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
    # One (t, h, w) patch-grid row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


class DotsOCRImageEmbeddingInputs(TensorSchema):
    """
    Image input given as precomputed embeddings (presumably used in place of
    running the vision tower — confirm with the caller).

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images
    """

    type: Literal["image_embeds"]

    # One row per image feature.
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
    # One (t, h, w) patch-grid row per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


# Union of the two image-input schemas: raw pixel patches or embeddings.
DotsOCRImageInputs = Union[
    DotsOCRImagePixelInputs,
    DotsOCRImageEmbeddingInputs,
]


class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
    """Build dummy prompt text and images using DotsOCR's image token."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``IMAGE_TOKEN`` per requested image (empty if none)."""
        return IMAGE_TOKEN * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most image features.

        ``mm_options["image"]``, when present, is forwarded as overrides.
        ``seq_len`` is not used.
        """
        width, height = self.info.get_image_size_with_most_features()
        overrides = None if not mm_options else mm_options.get("image")
        images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": images}


class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
    """Qwen2-VL processing info adapted to DotsOCR's config and image token."""

    def get_hf_config(self) -> DotsOCRConfig:
        """Return the HF config with ``vision_config`` as a ``DotsVisionConfig``.

        Raises:
            TypeError: if the config's class is not named ``DotsOCRConfig``.
        """
        config = self.ctx.get_hf_config()
        # Compared by class name rather than isinstance. NOTE(review):
        # presumably so a same-named class loaded from remote code passes.
        if config.__class__.__name__ != "DotsOCRConfig":
            raise TypeError(f"Expected DotsOCRConfig, got {type(config)}")

        # A raw dict vision config is upgraded in place to the typed config.
        raw_vision = getattr(config, "vision_config", None)
        if isinstance(raw_vision, dict):
            config.vision_config = DotsVisionConfig(**raw_vision)

        return config

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Images only, with no per-prompt count limit (``None``)."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-image token budget from ``get_max_image_tokens``."""
        return {"image": self.get_max_image_tokens()}

    def get_hf_processor(
        self,
        **kwargs: object,
    ) -> Qwen2VLProcessor:
        """Return a Qwen2-VL processor configured with DotsOCR's tokens.

        The tokenizer's ``image_token`` is set as a side effect.
        """
        tokenizer = self.get_tokenizer()
        tokenizer.image_token = IMAGE_TOKEN  # Ensure image token is set

        processor = self.ctx.get_hf_processor(Qwen2VLProcessor, **kwargs)
        processor.image_token = IMAGE_TOKEN
        processor.video_token = "<|video_pad|>"
        return processor


def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    ``a`` is the first ``x.shape[-1] // 2`` elements and ``b`` is the rest,
    so for an odd size ``b`` is the longer piece.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """Rotate ``tensor`` by the angles in ``freqs`` (rotary position embedding).

    The arithmetic runs in float32 and the result is cast back to
    ``tensor``'s dtype.

    Args:
        tensor: presumably ``[B, S, H, D]``. It must broadcast against the
            ``[1, S, 1, D]`` cos/sin tables built below.
        freqs: 2-D ``[S, D // 2]`` angles. Each angle is used for both halves
            of the last dimension.
    """
    source_dtype = tensor.dtype
    x = tensor.float()

    # [S, D/2] -> [S, 1, D/2] -> [S, 1, D] (halves duplicated) -> [1, S, 1, D]
    cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()

    rotated = (x * cos) + (rotate_half(x) * sin)
    return rotated.to(source_dtype)


class VisionRotaryEmbedding(nn.Module):
    """Table of rotary angles for integer positions.

    ``forward(seqlen)`` returns a ``[seqlen, (dim + 1) // 2]`` tensor whose
    entry ``(p, i)`` is ``p / theta ** (2 * i / dim)``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        inv_freq = 1.0 / (theta**exponents)
        # Non-persistent: rebuilt at construction, excluded from state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        positions = torch.arange(
            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        # Outer product written as a broadcast: [seqlen, 1] * [1, dim/2].
        return positions[:, None] * self.inv_freq[None, :]


class PatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size ** 2`` vision tokens into one.

    Each incoming token (width ``context_dim``) is optionally normalized. Rows
    of ``spatial_merge_size ** 2`` consecutive tokens are then concatenated and
    projected by Linear -> GELU -> Linear down to width ``dim``.

    Args:
        dim: output width per merged token.
        context_dim: width of each incoming token.
        spatial_merge_size: side length of the square token group to merge.
        pre_norm: ``"layernorm"``, ``"rmsnorm"``, or a falsy value for no norm.
        prefix: weight-name prefix for the MLP layers.
        use_data_parallel: if True, disable tensor parallelism in the MLP.

    Raises:
        ValueError: if ``pre_norm`` is truthy but not a supported norm name.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        pre_norm="layernorm",
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.pre_norm = pre_norm
        # Fail fast. An unknown truthy value used to leave ``ln_q`` undefined,
        # which only surfaced later as an AttributeError in ``forward``.
        if pre_norm and pre_norm not in ("layernorm", "rmsnorm"):
            raise ValueError(
                f"Unsupported pre_norm {pre_norm!r}; expected 'layernorm', "
                "'rmsnorm' or a falsy value for no normalization"
            )
        if self.pre_norm == "layernorm":
            self.ln_q = LayerNorm(context_dim, eps=1e-6)
        elif self.pre_norm == "rmsnorm":
            self.ln_q = RMSNorm(context_dim, eps=1e-6)

        self.mlp = nn.Sequential(
            ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.0",
                disable_tp=use_data_parallel,
            ),
            nn.GELU(),
            RowParallelLinear(
                self.hidden_size,
                dim,
                bias=True,
                return_bias=False,
                prefix=f"{prefix}.2",
                disable_tp=use_data_parallel,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Merge ``x`` (rows of ``context_dim``) into rows of width ``dim``."""
        # Normalization is per token; the view then packs
        # spatial_merge_size**2 consecutive tokens into each row.
        if self.pre_norm:
            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        else:
            x = self.mlp(x.view(-1, self.hidden_size))
        return x


class DotsVisionAttention(nn.Module):
    """Multi-head self-attention over packed vision tokens.

    Tokens of several segments are packed into one sequence. ``cu_seqlens``
    marks segment boundaries, and every backend attends only within a segment
    (flash varlen, per-segment SDPA, or an xFormers block-diagonal mask).
    ``config`` is not read here.
    """

    def __init__(
        self,
        config,
        dim: int,
        num_heads: int = 16,
        bias: bool = True,
        *,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        self.embed_dim = dim
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(dim, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size
        )
        # qkv/proj follow Qwen2-VL style; bias controlled by arg
        self.qkv = QKVParallelLinear(
            hidden_size=dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )
        self.proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )
        # Select attention backend
        self.attn_backend = get_vit_attn_backend(
            self.hidden_size_per_attention_head, torch.get_default_dtype()
        )
        self.use_upstream_fa = False

        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                self.use_upstream_fa,
            )
        )
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Unsupported vision attention backend: {self.attn_backend}"
            )
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        *,
        max_seqlen: Optional[int] = None,
        seqlens: Optional[list[int]] = None,
    ) -> torch.Tensor:
        """Attend within each ``cu_seqlens`` segment and project the result.

        Args:
            hidden_states: ``[S, C]`` packed tokens.
            cu_seqlens: segment boundaries, starting with 0.
            rotary_pos_emb: optional rotary angles for q/k.
            max_seqlen: longest segment length, used by flash backends.
                Derived from ``cu_seqlens`` when omitted.
            seqlens: per-segment lengths, used by xFormers. Derived from
                ``cu_seqlens`` when omitted.

        Returns:
            ``[S, C]`` output of the output projection.

        Raises:
            RuntimeError: if the selected backend has no code path here.
        """
        # [S, C] -> [S, B=1, C]
        x = hidden_states.unsqueeze(1)
        x, _ = self.qkv(x)
        q, k, v = Qwen2_5_VisionAttention.split_qkv(self, x)
        bs = q.shape[1]
        # [S,B,H,D] -> [B,S,H,D]
        q = q.permute(1, 0, 2, 3).contiguous()
        k = k.permute(1, 0, 2, 3).contiguous()
        v = v.permute(1, 0, 2, 3).contiguous()

        if rotary_pos_emb is not None:
            # Rotate q and k in one call by stacking them on the batch dim.
            qk_concat = torch.cat([q, k], dim=0)
            qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
            q, k = torch.chunk(qk_rotated, 2, dim=0)

        if self.is_flash_attn_backend:
            if max_seqlen is None:
                # The caller chooses what to precompute from its own backend
                # choice, which may differ from this layer's; derive it here.
                max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
            q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3])
            k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3])
            v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3])
            output = self.flash_attn_varlen_func(
                q_,
                k_,
                v_,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )
            context_layer = output.view(
                bs,
                -1,
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # One device->host transfer for all boundaries, instead of two
            # per segment via int(cu_seqlens[i]).
            bounds = [int(b) for b in cu_seqlens.tolist()]
            outputs = []
            for s, e in zip(bounds[:-1], bounds[1:]):
                # [B, L, H, D] -> [B, H, L, D] for SDPA, then back.
                q_i = q[:, s:e].permute(0, 2, 1, 3)
                k_i = k[:, s:e].permute(0, 2, 1, 3)
                v_i = v[:, s:e].permute(0, 2, 1, 3)
                out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
                outputs.append(out_i.permute(0, 2, 1, 3))
            context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0]
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            if seqlens is None:
                # See the max_seqlen note above: derive what was not supplied.
                seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens, kv_seqlen=None, device=q.device
            )
            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None
            )
        else:
            raise RuntimeError("Unsupported attention backend")

        # [B,S,H,D] -> [S,B,H*D] -> [S, C]
        context_layer = context_layer.permute(1, 0, 2, 3).contiguous()
        context_layer = context_layer.view(context_layer.shape[0], bs, -1)
        out, _ = self.proj(context_layer)
        return out.squeeze(1)


class DotsSwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block for the DotsOCR vision tower.

    ``fc13`` fuses the checkpoint's ``fc1`` and ``fc3`` projections (shards 0
    and 1). Its output goes through ``SiluAndMul`` and then ``fc2``. Widths,
    bias and the norm epsilon come from ``config``.
    """

    def __init__(
        self,
        config,
        *,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        hidden_features = config.intermediate_size
        in_features = config.embed_dim
        bias = config.use_bias

        # Referenced aimv2.py AIMv2SwiGLUFFN
        self.fc13 = MergedColumnParallelLinear(
            in_features,
            [hidden_features] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc13",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.fc13(x)
        x = self.act_fn(x)
        x, _ = self.fc2(x)
        return x

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, routing ``fc1``/``fc3`` tensors into shards of ``fc13``.

        Names must match this module's ``named_parameters`` keys once the fused
        rename is applied (e.g. ``"fc1.weight"`` -> ``"fc13.weight"``).
        Bias tensors with no matching parameter are skipped.

        Returns:
            The set of parameter names that were loaded.
        """
        # (fused parameter, checkpoint component, shard index in the fused one)
        stacked_params_mapping = [
            ("fc13", "fc1", 0),
            ("fc13", "fc3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            parts = name.split(".")
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Match whole dotted components. A substring test would also
                # hit "fc1" inside an already-fused "fc13" name and rewrite it
                # to the nonexistent "fc133".
                if weight_name not in parts:
                    continue
                name = ".".join(param_name if p == weight_name else p for p in parts)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class DotsPatchEmbed(nn.Module):
    """Embed flattened patch rows with a strided Conv2d followed by RMSNorm.

    Each input row is viewed as ``(num_channels, temporal_patch_size,
    patch_size, patch_size)``. Only temporal slice 0 is embedded, and
    ``grid_thw`` is unused.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_channels = config.num_channels
        self.patch_size = config.patch_size
        self.temporal_patch_size = config.temporal_patch_size
        self.embed_dim = config.embed_dim
        patch = (config.patch_size, config.patch_size)
        # kernel == stride == patch: one output position per patch.
        self.proj = nn.Conv2d(
            config.num_channels,
            config.embed_dim,
            kernel_size=patch,
            stride=patch,
        )
        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        patches = x.view(
            -1,
            self.num_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        first_frame = patches[:, :, 0]
        embedded = self.proj(first_frame).view(-1, self.embed_dim)
        return self.norm(embedded)


class DotsViTPreprocessor(nn.Module):
    """Wrapper that delegates patch embedding to ``DotsPatchEmbed``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.embed_dim
        self.patch_h = self.patch_w = config.patch_size
        self.patchifier = DotsPatchEmbed(config)

    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
        return self.patchifier(x, grid_thw)


class DotsVisionBlock(nn.Module):
    """Pre-norm transformer block with residual connections.

    ``h = h + attn(norm1(h))`` followed by ``h = h + mlp(norm2(h))``.
    """

    def __init__(
        self,
        config,
        *,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        width = config.embed_dim
        eps = config.rms_norm_eps

        self.attn = DotsVisionAttention(
            config,
            width,
            num_heads=config.num_attention_heads,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )
        self.norm1 = RMSNorm(width, eps=eps)
        self.mlp = DotsSwiGLUFFN(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )
        self.norm2 = RMSNorm(width, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,
        seqlens: Optional[list[int]] = None,
    ) -> torch.Tensor:
        attn_out = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out


class DotsVisionTransformer(nn.Module):
    """DotsOCR vision encoder.

    Pipeline: patch embedding -> ``DotsVisionBlock`` stack with 2-D rotary
    positions -> optional post-trunk RMSNorm -> ``PatchMerger``.
    """

    def __init__(
        self,
        config: DotsVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        self.config = config
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = DotsViTPreprocessor(config)

        head_dim = config.embed_dim // config.num_attention_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype()
        )
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN
        self.out_hidden_size = config.hidden_size
        # Keep blocks for compatibility with other vision towers
        num_layers = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )
        self.blocks = nn.ModuleList(
            [
                DotsVisionBlock(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{i}",
                    use_data_parallel=use_data_parallel,
                )
                for i in range(num_layers)
            ]
        )
        # By default the post norm is only built for the full-depth tower.
        if require_post_norm is None:
            require_post_norm = len(self.blocks) == config.num_hidden_layers
        if require_post_norm and self.config.post_norm:
            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
        else:
            self.post_trunk_norm = None

        self.merger = PatchMerger(
            dim=config.hidden_size,
            context_dim=config.embed_dim,
            spatial_merge_size=config.spatial_merge_size,
            use_data_parallel=use_data_parallel,
        )

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.patchifier.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.patchifier.proj.weight.device

    def _merge_window_order(self, coords: torch.Tensor) -> torch.Tensor:
        """Flatten an ``[h, w]`` grid so each merge window is contiguous."""
        h, w = coords.shape
        m = self.spatial_merge_size
        return coords.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()

    def get_pos_ids_by_grid(self, grid_thw: list[list[int]]) -> list[torch.Tensor]:
        """Return per-image ``[t*h*w, 2]`` (row, col) ids in merge-window order."""
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = self._merge_window_order(
                torch.arange(h).unsqueeze(1).expand(-1, w)
            )
            wpos_ids = self._merge_window_order(
                torch.arange(w).unsqueeze(0).expand(h, -1)
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        return pos_ids

    def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor:
        """Rotary angles per token: row angles followed by column angles."""
        pos_ids = self.get_pos_ids_by_grid(grid_thw)
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = max(max(h, w) for _, h, w in grid_thw)
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def _blocks_attn_backend(self) -> "_Backend":
        """Backend the attention layers actually use.

        ``self.attn_backend`` may have been upgraded to FLASH_ATTN above
        without the blocks following suit. The blocks are built identically,
        so the first block is representative.
        """
        if len(self.blocks) > 0:
            return self.blocks[0].attn.attn_backend
        return self.attn_backend

    def compute_attn_mask_seqlen(
        self, cu_seqlens: torch.Tensor
    ) -> tuple[Optional[int], Optional[list[int]]]:
        """Precompute what the blocks' backend needs from ``cu_seqlens``.

        Returns ``(max_seqlen, seqlens)``. ``max_seqlen`` is set for flash
        backends, ``seqlens`` for xFormers, and the other stays ``None``.
        """
        max_seqlen, seqlens = None, None
        backend = self._blocks_attn_backend()
        if backend == _Backend.FLASH_ATTN or backend == _Backend.ROCM_AITER_FA:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self, hidden_states: torch.Tensor, grid_thw: list[list[int]]
    ) -> torch.Tensor:
        """Encode packed patch rows described by ``grid_thw`` (t, h, w per image)."""
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Convert grid_thw to tensor (always expecting list format now)
        grid_thw = torch.tensor(grid_thw, device=hidden_states.device, dtype=torch.long)

        hidden_states = hidden_states.to(self.dtype)
        hidden_states = self.patch_embed(hidden_states, grid_thw)

        # One h*w segment per temporal frame of each image; prepend 0.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])

        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )

        if self.post_trunk_norm is not None:
            hidden_states = self.post_trunk_norm(hidden_states)

        hidden_states = self.merger(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    Qwen2VLMultiModalProcessor,
    info=DotsOCRProcessingInfo,
    dummy_inputs=DotsOCRDummyInputsBuilder,
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
    """DotsOCR: a ``DotsVisionTransformer`` image encoder feeding a Qwen2 LM.

    Image features from ``vision_tower`` are merged into the text embeddings
    at the positions where ``input_ids == config.image_token_id`` before the
    ``language_model`` runs.
    """

    merge_by_field_config = True

    # Checkpoint-name rewrites applied by ``load_weights``: renames inside the
    # vision attention, and nesting of the text weights under language_model.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".attn.qkv_proj.": ".attn.qkv.",
            ".attn.out_proj.": ".attn.proj.",
        },
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        },
    )

    # Fused module name -> the checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        ".attn.qkv": [".attn.qkv"],
        "fc13": ["fc1", "fc3"],
    }
    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for an image item, else None."""
        if not modality.startswith("image"):
            return None
        return "<|img|><|imgpad|><|endofimg|>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        model_config = vllm_config.model_config
        self.config: DotsOCRConfig = model_config.hf_config
        self.quant_config = vllm_config.quant_config
        # "data" mode shards the vision encoder across ranks by image.
        encoder_tp_mode = model_config.multimodal_config.mm_encoder_tp_mode
        self.use_data_parallel = encoder_tp_mode == "data"

        # The HF config may hold the vision config as a plain dict; upgrade it
        # to DotsVisionConfig and store the typed object back on the config.
        vision_config = self.config.vision_config
        if isinstance(vision_config, dict):
            vision_config = DotsVisionConfig(**vision_config)
            self.config.vision_config = vision_config

        self.vision_tower = DotsVisionTransformer(
            vision_config,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
            use_data_parallel=self.use_data_parallel,
        )
        self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=self.config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["Qwen2ForCausalLM"],
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[DotsOCRImageInputs]:
        """Pop the image kwargs and wrap them in the matching input type.

        ``pixel_values`` takes precedence over ``image_embeds``. Returns None
        when neither is supplied.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is not None:
            return DotsOCRImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )
        if image_embeds is not None:
            return DotsOCRImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )
        return None

    def _process_image_input(
        self, image_input: DotsOCRImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Produce (or reuse) vision embeddings, split into one tensor per image."""
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        target_dtype = self.vision_tower.dtype

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(target_dtype)
        else:
            pixel_values = image_input["pixel_values"].type(target_dtype)
            if self.use_data_parallel:
                # Returned as-is, without the split below; presumably the
                # helper already yields per-image embeddings — confirm.
                return run_dp_sharded_mrope_vision_model(
                    self.vision_tower,
                    pixel_values,
                    grid_thw_list,
                    rope_type="rope_3d",
                )
            vision_out = self.vision_tower(pixel_values, grid_thw_list)
            image_embeds = vision_out[:, : self.config.hidden_size]

        # Rows per image: its t*h*w patch count divided by the spatial-merge
        # area (spatial_merge_size squared).
        merge_area = self.vision_tower.spatial_merge_size**2
        per_image_rows = torch.tensor(grid_thw_list, dtype=torch.long).prod(-1)
        return image_embeds.split((per_image_rows // merge_area).tolist())

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Qwen2 language model."""
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list when no images are given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model, building multimodal input embeddings if needed.

        When ``intermediate_tensors`` is given (presumably a non-first pipeline
        stage), any ``inputs_embeds`` is discarded. Otherwise, if no embeddings
        were supplied, image features are merged at ``image_token_id``
        positions and ``input_ids`` is then dropped in favour of the embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            mm_embeddings = self.get_multimodal_embeddings(**kwargs)
            image_positions = input_ids == self.config.image_token_id
            inputs_embeds = self.get_input_embeddings(
                input_ids,
                mm_embeddings,
                is_multimodal=image_positions,
            )
            input_ids = None

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights after renaming them with ``hf_to_vllm_mapper``."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_tower.merger",
            tower_model="vision_tower.",
        )