# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import TypeVar

import torch
import torch.nn as nn
from transformers import (
    BatchFeature,
    PixtralVisionConfig,
)

from vllm.config import VllmConfig
from vllm.model_executor.models.mistral3 import (
    Mistral3DummyInputsBuilder,
    Mistral3ForConditionalGeneration,
    Mistral3MultiModalProjector,
    Mistral3ProcessingInfo,
    init_vision_tower_for_mistral3,
)
from vllm.model_executor.models.pixtral import PixtralHFEncoderInfo
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)

# Type variable bounded by Mistral3ProcessingInfo.
# NOTE(review): `_I` is not referenced anywhere in the visible code of this
# module; confirm whether anything imports it before removing.
_I = TypeVar("_I", bound=Mistral3ProcessingInfo)


class LightOnOCRMultiModalProcessor(BaseMultiModalProcessor[Mistral3ProcessingInfo]):
    """Multimodal processor for LightOnOCR.

    Reuses Mistral3's processing info. Each image is expanded to a flat run of
    image tokens, with no image break/end tokens, because LightOnOCR does not
    use them.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then strip break/end tokens and un-pad images.

        Args:
            prompt: Text prompt to tokenize.
            mm_data: Multimodal data forwarded to the HF processor.
            mm_kwargs: Extra multimodal processor kwargs.
            tok_kwargs: Extra tokenizer kwargs.

        Returns:
            The HF processor output, modified as follows:
            - image break/end token ids are removed from ``input_ids``, and the
              same positions are removed from ``attention_mask`` when present;
            - each ``pixel_values`` entry is cropped to its ``image_sizes``
              entry.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # NOTE: LightOnOCR does not use break/end tokens, so we remove them here.
        input_ids = processed_outputs.get("input_ids")
        if input_ids is not None:
            processor = self.info.get_hf_processor()
            vocab = self.info.get_tokenizer().get_vocab()

            # vocab.get() yields None for a token the tokenizer does not know.
            # Skip such ids rather than letting torch.tensor() fail on None.
            drop_ids = [
                token_id
                for token_id in (
                    vocab.get(processor.image_break_token),
                    vocab.get(processor.image_end_token),
                )
                if token_id is not None
            ]

            if drop_ids:
                # Build the lookup tensor alongside input_ids so that isin()
                # compares like with like.
                keep_mask = ~torch.isin(
                    input_ids,
                    torch.tensor(
                        drop_ids, dtype=input_ids.dtype, device=input_ids.device
                    ),
                )

                # Boolean-mask indexing flattens the result, and unsqueeze(0)
                # restores a leading batch dimension.
                # NOTE(review): this assumes a single prompt per call (batch 1).
                processed_outputs["input_ids"] = input_ids[keep_mask].unsqueeze(0)
                if "attention_mask" in processed_outputs:
                    processed_outputs["attention_mask"] = processed_outputs[
                        "attention_mask"
                    ][keep_mask].unsqueeze(0)

        # un-pad pixel_values per-image so caches remain independent.
        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)
            # Crop the last two dims of each entry to that image's (h, w).
            # Presumably the layout is (channels, height, width); TODO confirm.
            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` and ``image_embeds`` as batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image placeholder token into its patch-grid token run.

        Each image token in the prompt is replaced by ``ncols * nrows`` copies
        of the image token. The grid size comes from
        ``PixtralHFEncoderInfo.get_patch_grid_size``. No break/end tokens are
        inserted.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def replace(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width, image_height=size.height
            )
            # break/end tokens are not used in LightOnOCR
            tokens = [image_token_id] * (ncols * nrows)
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image", target=[image_token_id], replacement=replace
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    LightOnOCRMultiModalProcessor,
    info=Mistral3ProcessingInfo,
    dummy_inputs=Mistral3DummyInputsBuilder,
)
class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
    """LightOnOCR model: vision tower + Mistral3 projector + language model.

    The submodules are assembled here directly: the vision tower via
    ``init_vision_tower_for_mistral3``, a ``Mistral3MultiModalProjector``, and
    a language model via ``init_vllm_registered_model``. HF checkpoint
    prefixes are remapped by ``hf_to_vllm_mapper`` when weights are loaded.
    """

    # Maps HF checkpoint weight-name prefixes onto this module's attributes.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_encoder.": "vision_tower.",
            "model.vision_projection.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Global vLLM config. The HF model config, quant config
                and multimodal config are read from it.
            prefix: Parameter-name prefix for the submodules.
        """
        # Initialise nn.Module directly instead of the parent's __init__.
        # Presumably this lets the submodules be built with LightOnOCR-specific
        # options such as require_post_norm=False.
        # NOTE(review): confirm the parent constructor has no other side
        # effects this class depends on.
        nn.Module.__init__(self)

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_mistral3(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = Mistral3MultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                spatial_merge_size=config.spatial_merge_size,
                patch_size=config.vision_config.patch_size,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Delegate intermediate-tensor allocation to the language model.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Names are remapped through ``hf_to_vllm_mapper`` first. The return
        value is whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)