image.py 1.54 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import dataclass
5
from pathlib import Path
6
7
from typing import Literal

8
import torch
9
10
from PIL import Image

11
from .base import get_vllm_public_assets
12

13
# Value passed as ``s3_prefix`` to ``get_vllm_public_assets`` whenever an
# image asset file is resolved (see ``ImageAsset.get_path``).
VLM_IMAGES_DIR = "vision_model_images"
14

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Type alias for the ``name`` field of ``ImageAsset``. The name is joined with
# a file extension (``f"{name}.{ext}"``) to form the asset filename, so each
# entry is a filename stem, without extension.
ImageAssetName = Literal[
    "stop_sign",
    "cherry_blossom",
    "hato",
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
    "Grayscale_8bits_palette_sample_image",
    "1280px-Venn_diagram_rgb",
    "RGBA_comp",
    "237-400x300",
    "231-200x300",
    "27-500x500",
    "17-150x600",
    "handelsblatt-preview",
    "paper-11",
]
30

31
32
33

@dataclass(frozen=True)
class ImageAsset:
    """
    An image asset identified by ``name``.

    Files belonging to the asset are addressed as ``{name}.{ext}`` and
    resolved through ``get_vllm_public_assets`` under ``VLM_IMAGES_DIR``.
    """

    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """
        Resolve the asset file that has the given extension.

        Args:
            ext: File extension without the leading dot (e.g. ``"jpg"``).

        Returns:
            Whatever ``get_vllm_public_assets`` returns for
            ``{name}.{ext}`` under the ``VLM_IMAGES_DIR`` prefix.
        """
        filename = f"{self.name}.{ext}"
        return get_vllm_public_assets(
            filename=filename, s3_prefix=VLM_IMAGES_DIR
        )

    @property
    def pil_image(self) -> Image.Image:
        """
        The ``jpg`` variant of this asset, opened as a PIL image.
        """
        return self.pil_image_ext(ext="jpg")

    def pil_image_ext(self, ext: str) -> Image.Image:
        """
        Open the variant of this asset with extension ``ext`` as a PIL image.
        """
        return Image.open(self.get_path(ext=ext))

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.

        Loaded from the ``pt`` variant of this asset onto the CPU, with
        ``weights_only=True``.
        """
        embeds_file = self.get_path("pt")
        return torch.load(embeds_file, map_location="cpu", weights_only=True)

    def read_bytes(self, ext: str) -> bytes:
        """
        Return the raw bytes of the asset file with extension ``ext``.
        """
        return Path(self.get_path(ext)).read_bytes()