image.py 1.53 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from dataclasses import dataclass
5
from pathlib import Path
6
7
from typing import Literal

8
import torch
9
10
from PIL import Image

11
from .base import get_vllm_public_assets
12

13
# S3 prefix of the vision-model test images; ImageAsset.get_path passes it
# as ``s3_prefix`` when resolving an asset file.
VLM_IMAGES_DIR = "vision_model_images"

# Accepted values for ImageAsset.name. Each entry is a file stem: the
# extension is supplied separately when the file is resolved
# (``f"{name}.{ext}"``).
ImageAssetName = Literal[
    "stop_sign",
    "cherry_blossom",
    "hato",
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
    "Grayscale_8bits_palette_sample_image",
    "1280px-Venn_diagram_rgb",
    "RGBA_comp",
    "237-400x300",
    "231-200x300",
    "27-500x500",
    "17-150x600",
    "handelsblatt-preview",
    "paper-11",
]
21

22
23
24

@dataclass(frozen=True)
class ImageAsset:
    """A named test image stored under the ``VLM_IMAGES_DIR`` S3 prefix.

    Files are looked up as ``<name>.<ext>``; the extension is chosen per
    accessor (``"jpg"`` for ``pil_image``, ``"pt"`` for ``image_embeds``, or
    any explicit ``ext`` for ``get_path`` / ``get_pil_image`` /
    ``read_bytes``).
    """

    # File stem of the asset; must be one of the ImageAssetName literals.
    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """
        Return the local path of ``<name>.<ext>``.

        The path is obtained from ``get_vllm_public_assets`` with the
        ``VLM_IMAGES_DIR`` S3 prefix. Callers open the result directly as a
        local file (``Image.open``, ``torch.load``, ``Path.read_bytes``), so
        it is a filesystem path, not an S3 URL.
        """
        # NOTE(review): presumably get_vllm_public_assets downloads/caches the
        # S3 object and returns its local location -- confirm in .base.
        return get_vllm_public_assets(filename=f"{self.name}.{ext}",
                                      s3_prefix=VLM_IMAGES_DIR)

    def get_pil_image(self, ext: str = "jpg") -> Image.Image:
        """
        Open this asset's ``<name>.<ext>`` file as a PIL image.

        Unlike the ``pil_image`` property, this accepts the extension, so
        assets stored in formats other than JPEG can be opened as well.

        ``Image.open`` is lazy: pixel data is read on first use and the file
        handle stays open until the image is loaded or closed.
        """
        return Image.open(self.get_path(ext))

    @property
    def pil_image(self, ext="jpg") -> Image.Image:
        """The ``.jpg`` file of this asset, opened as a PIL image."""
        # A property getter is only ever invoked with ``self``, so ``ext``
        # can never be supplied and is always "jpg". It is kept only so the
        # getter's signature stays unchanged; use ``get_pil_image(ext)`` to
        # open other formats.
        return self.get_pil_image(ext)

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.

        Loaded from this asset's ``.pt`` file onto the CPU. ``weights_only``
        restricts unpickling to tensors and primitive containers.
        """
        return torch.load(self.get_path("pt"),
                          map_location="cpu",
                          weights_only=True)

    def read_bytes(self, ext: str) -> bytes:
        """Return the raw contents of this asset's ``<name>.<ext>`` file."""
        # Wrapped in Path so this also works if get_path hands back a str.
        return Path(self.get_path(ext)).read_bytes()