smolvlm.py 1.49 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import Optional
5
6
7
8
9
10
11
12

from transformers import SmolVLMProcessor

from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY

# yapf: disable
from .idefics3 import Idefics3DummyInputsBuilder as SmolVLMDummyInputsBuilder
13
from .idefics3 import Idefics3ForConditionalGeneration, Idefics3ProcessingInfo
14
15
16
17
18
19
from .idefics3 import Idefics3MultiModalProcessor as SmolVLMMultiModalProcessor

# yapf: enable


class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
    """Processing info for SmolVLM.

    Reuses the Idefics3 processing-info implementation, but resolves the
    HuggingFace processor as ``SmolVLMProcessor`` and reads SmolVLM's
    image-related tokens from it.
    """

    def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor:
        """Return the HuggingFace ``SmolVLMProcessor`` for this model.

        All keyword arguments are forwarded unchanged to
        ``self.ctx.get_hf_processor``.
        """
        return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)

    def _get_image_token(
        self, processor: Optional[SmolVLMProcessor]
    ) -> tuple[str, str, str]:
        """Return the image-related token strings used by the processor.

        Args:
            processor: Processor to read the tokens from. If ``None``, the
                default processor from ``get_hf_processor()`` is used.

        Returns:
            A 3-tuple ``(image_token, fake_image_token, global_image_token)``.
        """
        if processor is None:
            processor = self.get_hf_processor()
        # Three values are returned, so the annotation must be a 3-tuple
        # (it was previously declared as tuple[str, str]).
        return (
            processor.image_token,
            processor.fake_image_token,
            processor.global_image_token,
        )


34
35
36
37
38
@MULTIMODAL_REGISTRY.register_processor(
    SmolVLMMultiModalProcessor,
    info=SmolVLMProcessingInfo,
    dummy_inputs=SmolVLMDummyInputsBuilder,
)
class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
    """SmolVLM model for conditional generation.

    The architecture is inherited unchanged from Idefics3. The class is
    registered with the multimodal registry so that it uses the SmolVLM
    processing info (``SmolVLMProcessingInfo``) together with the Idefics3
    multimodal processor and dummy-input builder, which this module imports
    under SmolVLM aliases.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model by delegating to the Idefics3 constructor.

        Args:
            vllm_config: vLLM configuration forwarded to the parent class.
            prefix: Name prefix forwarded to the parent class.
        """
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
        )