# vision.py (2.48 KB)
1
from abc import ABC, abstractmethod
2
from typing import Final, Generic, Optional, Protocol, TypeVar
3
4
5

from transformers import PretrainedConfig

6
7
8
9
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        InputProcessingContext,
                                        ProcessingCache)

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Type variable for the vision config object wrapped by a VisionEncoderInfo
# subclass; bounded to PretrainedConfig (enforced by static type checkers).
_C = TypeVar("_C", bound=PretrainedConfig)


class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface describing a vision encoder's image geometry.

    Holds the encoder's vision config (type ``_C``) and declares the size
    and token-count queries that each concrete encoder family must
    implement. All query methods are abstract and raise
    ``NotImplementedError`` if a subclass fails to override them.
    """

    def __init__(self, vision_config: _C) -> None:
        super().__init__()

        # Kept on the instance so subclass implementations can query it.
        self.vision_config = vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the image token count for an image of the given size.

        Args:
            image_width: Width of the input image (keyword-only).
            image_height: Height of the input image (keyword-only).
        """
        raise NotImplementedError

    @abstractmethod
    def get_max_image_tokens(self) -> int:
        """Return the maximum number of image tokens this encoder yields."""
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's image size as a single int.

        NOTE(review): presumably the square input side length in pixels;
        confirm against concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the encoder's patch size as a single int."""
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the length of the encoder's patch grid.

        NOTE(review): presumably the number of patches along one side of
        the image; confirm against concrete implementations.
        """
        raise NotImplementedError


def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
    """Build the VisionEncoderInfo implementation matching ``vision_config``.

    Supported config types are CLIP, Pixtral and SigLIP vision configs; they
    are tested in that order and the first ``isinstance`` match wins.

    Raises:
        NotImplementedError: if ``vision_config`` is not one of the
            supported config types.
    """
    # Imported inside the function to avoid circular imports.
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    # Ordered (config type -> info wrapper) pairs; order is significant.
    known_encoders = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )
    for config_cls, info_cls in known_encoders:
        if isinstance(vision_config, config_cls):
            return info_cls(vision_config)

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83


class VisionLanguageConfig(Protocol):
    """Structural type for configs that expose a nested ``vision_config``.

    Used as the return type of ``_get_hf_config``; for static type checking,
    any object with a ``vision_config`` attribute of type PretrainedConfig
    satisfies it. ``Final`` marks the attribute as not to be reassigned.
    """

    vision_config: Final[PretrainedConfig]


class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
    """Multimodal processor base for models whose config has a vision part.

    During construction it asks the subclass for the model config via
    ``_get_hf_config`` and resolves a VisionEncoderInfo from that config's
    ``vision_config``, storing it as ``self._vision_encoder_info``.
    """

    def __init__(self,
                 ctx: InputProcessingContext,
                 *,
                 cache: Optional[ProcessingCache] = None,
                 enable_sanity_checks: bool = True) -> None:
        super().__init__(ctx,
                         cache=cache,
                         enable_sanity_checks=enable_sanity_checks)

        # The base initializer has run; now resolve the encoder helper.
        # Note this calls the subclass hook from inside __init__.
        hf_config = self._get_hf_config()
        self._vision_encoder_info = vision_encoder_info(
            hf_config.vision_config)

    @abstractmethod
    def _get_hf_config(self) -> VisionLanguageConfig:
        """Return the model config carrying a ``vision_config`` attribute.

        Must be implemented by subclasses.
        """
        raise NotImplementedError