pixtral.py 2.25 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from mistral_common.protocol.instruct.chunk import ImageChunk
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import BatchFeature, ProcessorMixin, TensorType
from transformers.image_utils import ImageInput

from vllm.tokenizers.mistral import MistralTokenizer


class MistralCommonImageProcessor:
    """
    Adapter that exposes a Hugging Face-style image-processor interface on top
    of `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
    """

    def __init__(self, mm_encoder: ImageEncoder) -> None:
        # The wrapped mistral_common encoder; every call below delegates to it.
        self.mm_encoder = mm_encoder

    def __call__(
        self,
        images: ImageInput,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Encode one image or a list of images with the wrapped encoder.

        Args:
            images: A single image, or a `list` of images. Any input that is
                not a `list` (tuples included) is treated as one image.
            return_tensors: Forwarded to `BatchFeature` as `tensor_type`.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A `BatchFeature` whose `"images"` entry holds one `torch.Tensor`
            per input image. Each tensor is built from the encoder output's
            `.image` attribute, in input order.
        """
        if isinstance(images, list):
            batch = images
        else:
            batch = [images]

        encoded: list[torch.Tensor] = [
            torch.tensor(self.mm_encoder(ImageChunk(image=img)).image)
            for img in batch
        ]

        return BatchFeature({"images": encoded}, tensor_type=return_tensors)

    def get_number_of_image_patches(
        self,
        height: int,
        width: int,
    ) -> tuple[int, int, int]:
        """
        Compute the patch grid the encoder would use for a `width` x `height` image.

        Returns:
            `(total, rows, cols)`, where `total == rows * cols`. The rows and
            columns come from the encoder's `_image_to_num_tokens`, which
            yields `(cols, rows)`.
        """
        # The encoder helper takes an image, not a size, so build a blank
        # placeholder of the requested dimensions.
        # NOTE(review): presumably only the placeholder's size matters to
        # `_image_to_num_tokens` (a private mistral_common API) — confirm.
        placeholder = Image.new("RGB", (width, height))
        num_cols, num_rows = self.mm_encoder._image_to_num_tokens(placeholder)
        total = num_cols * num_rows
        return total, num_rows, num_cols


class MistralCommonPixtralProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]

53
54
55
56
57
    def __init__(
        self,
        tokenizer: MistralTokenizer,
        image_processor: MistralCommonImageProcessor,
    ) -> None:
58
        self.tokenizer = tokenizer.transformers_tokenizer
59
60
61
62
63

        # Back-compatibility for Transformers v4
        if not hasattr(self.tokenizer, "init_kwargs"):
            self.tokenizer.init_kwargs = {}

64
        self.image_processor = image_processor
65

66
67
68
69
        image_special_ids = self.image_processor.mm_encoder.special_ids
        self.image_break_id = image_special_ids.img_break
        self.image_token_id = image_special_ids.img
        self.image_end_id = image_special_ids.img_end