dots_ocr.py 1.82 KB
Newer Older
qrskannbara's avatar
qrskannbara committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Optional

from transformers import AutoProcessor, Qwen2_5_VLProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.models.qwen2 import Qwen2Config

from sglang.srt.configs.dots_vlm import DotsVisionConfig


class DotsOCRConfig(Qwen2Config):
    """Configuration for the ``dots_ocr`` model type, built on ``Qwen2Config``.

    In addition to what the parent config stores, this class records the
    image and video placeholder token ids and a nested ``DotsVisionConfig``.
    """

    model_type = "dots_ocr"

    def __init__(
        self,
        image_token_id=151665,
        video_token_id=151656,
        vision_config: Optional[dict] = None,
        *args,
        **kwargs
    ):
        """Build the config.

        Args:
            image_token_id: Token id stored as ``self.image_token_id``.
            video_token_id: Token id stored as ``self.video_token_id``.
            vision_config: Keyword arguments for ``DotsVisionConfig``. ``None``
                (or any falsy value) builds a ``DotsVisionConfig`` with no
                arguments.
            *args: Forwarded unchanged to ``Qwen2Config.__init__``.
            **kwargs: Forwarded unchanged to ``Qwen2Config.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id

        # A falsy ``vision_config`` is treated as "no overrides".
        vision_kwargs = vision_config if vision_config else {}
        self.vision_config = DotsVisionConfig(**vision_kwargs)

    def save_pretrained(self, save_directory, **kwargs):
        """Clear ``_auto_class`` and then delegate to the parent's save.

        Args:
            save_directory: Passed through to the parent ``save_pretrained``.
            **kwargs: Passed through to the parent ``save_pretrained``.
        """
        # NOTE(review): presumably this stops transformers from exporting the
        # config as a custom auto-class object on save -- confirm against
        # PretrainedConfig.save_pretrained.
        self._auto_class = None
        super().save_pretrained(save_directory, **kwargs)


class DummyVideoProcessor(BaseImageProcessor):
    """No-op video processor.

    ``DotsVLProcessor`` uses an instance of this class as its video processor
    when the caller does not supply one.
    """

    model_input_names = ["pixel_values"]

    def __call__(self, *args, **kwargs):
        """Ignore every argument and return ``None``."""
        return


class DotsVLProcessor(Qwen2_5_VLProcessor):
    """``Qwen2_5_VLProcessor`` subclass that resolves the image placeholder token.

    After the parent constructor runs, the instance exposes ``image_token``
    and ``image_token_id``, taken from the tokenizer when it provides them.
    """

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        **kwargs
    ):
        """Initialise the processor.

        Args:
            image_processor: Forwarded positionally to the parent constructor.
            tokenizer: Forwarded positionally to the parent constructor; also
                consulted for ``image_token`` / ``image_token_id``.
            video_processor: Forwarded positionally to the parent constructor.
                ``None`` is replaced by a fresh ``DummyVideoProcessor``.
            chat_template: Forwarded to the parent as a keyword argument.
            **kwargs: Accepted but not used by this constructor.
        """
        if video_processor is None:
            effective_video_processor = DummyVideoProcessor()
        else:
            effective_video_processor = video_processor

        super().__init__(
            image_processor,
            tokenizer,
            effective_video_processor,
            chat_template=chat_template,
        )

        # Use the tokenizer's own image token when it defines the attribute;
        # otherwise fall back to the literal pad token.
        if hasattr(tokenizer, "image_token"):
            self.image_token = tokenizer.image_token
        else:
            self.image_token = "<|imgpad|>"

        # Prefer an explicit id from the tokenizer; otherwise look the token up.
        resolved_id = getattr(tokenizer, "image_token_id", None)
        if resolved_id is None:
            resolved_id = tokenizer.convert_tokens_to_ids(self.image_token)
        self.image_token_id = resolved_id


# Import-time side effect: associate DotsOCRConfig with DotsVLProcessor in
# AutoProcessor's registry.
AutoProcessor.register(DotsOCRConfig, DotsVLProcessor)