image.py 1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
from dataclasses import dataclass
from typing import Literal

7
import torch
8
9
from PIL import Image

10
from .base import get_vllm_public_assets
11

12
# Prefix under which the image assets are looked up; handed to
# ``get_vllm_public_assets`` as its ``s3_prefix`` argument.
VLM_IMAGES_DIR = "vision_model_images"

# Closed set of asset names accepted by ``ImageAsset.name``. The name is used
# verbatim as the file stem of the asset (``<name>.jpg`` / ``<name>.pt``).
ImageAssetName = Literal["stop_sign", "cherry_blossom"]

16
17
18

@dataclass(frozen=True)
class ImageAsset:
    """A named image asset resolved through ``get_vllm_public_assets``.

    The instance only stores the asset ``name``; the files themselves are
    looked up on every property access, and nothing is cached on the
    instance (the dataclass is frozen).
    """

    # One of the names listed in ``ImageAssetName``; used as the file stem.
    name: ImageAssetName

    @property
    def pil_image(self) -> Image.Image:
        """Return the asset's ``<name>.jpg`` file opened as a PIL image.

        The path is obtained from ``get_vllm_public_assets`` under the
        ``VLM_IMAGES_DIR`` prefix and passed to ``Image.open``. A new image
        object is created on each access.
        """
        # NOTE(review): ``Image.open`` is documented as lazy (pixel data is
        # read on first use), so the underlying file presumably stays open
        # until the caller loads or closes the image -- confirm if handle
        # lifetime matters for callers.
        image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
                                            s3_prefix=VLM_IMAGES_DIR)
        return Image.open(image_path)

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.

        Loaded from the asset's ``<name>.pt`` file, resolved through
        ``get_vllm_public_assets`` under the ``VLM_IMAGES_DIR`` prefix.
        """
        image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
                                            s3_prefix=VLM_IMAGES_DIR)
        # ``map_location="cpu"`` keeps the load independent of the device the
        # tensor was saved from; ``weights_only=True`` is kept so the file is
        # loaded with torch's restricted unpickler.
        return torch.load(image_path, map_location="cpu", weights_only=True)