hasher.py 3.37 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import pickle
5
6
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Optional
7
8
9
10
11
12
13

import numpy as np
import torch
from blake3 import blake3
from PIL import Image

from vllm.logger import init_logger
14
from vllm.multimodal.image import convert_image_mode
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

# TokensPrompt is referenced here only inside string annotations (see
# MultiModalHasher.hash_prompt_mm_data), so it is imported for type checkers
# only and not at runtime.
if TYPE_CHECKING:
    from vllm.inputs import TokensPrompt

logger = init_logger(__name__)

# Maps each modality name (a key of the prompt's "multi_modal_data") to a list
# of hex digest strings, one per item of that modality, in input order.
MultiModalHashDict = Mapping[str, list[str]]
"""
A dictionary containing hashes for items in each modality.
"""


class MultiModalHasher:
    """Compute content hashes for multimodal inputs.

    Values are flattened into a stream of ``(key, value)`` byte pairs by
    :meth:`iter_item_to_bytes` and fed to a BLAKE3 hasher in
    :meth:`hash_kwargs`, so inputs that serialize identically produce the
    same hex digest.
    """

    @classmethod
    def serialize_item(cls, obj: object) -> bytes:
        """Serialize a single leaf value to bytes for hashing.

        Args:
            obj: The value to serialize. ``str`` is UTF-8 encoded, ``bytes``
                passes through unchanged, numbers go through NumPy, and
                ``np.ndarray`` / ``torch.Tensor`` / PIL images are encoded
                together with their dtype and shape. Any other type falls back
                to ``pickle`` (with a warning).

        Returns:
            The serialized bytes.
        """
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, int) and not -(2**63) <= obj < 2**64:
            # Out of int64/uint64 range, np.array() would produce an object
            # array whose tobytes() exposes object pointers, i.e. a hash that
            # changes from run to run. Use the exact decimal form instead.
            return str(obj).encode("utf-8")
        if isinstance(obj, (int, float)):
            return np.array(obj).tobytes()

        # Array-like cases: dtype and shape are hashed alongside the raw data
        # so arrays with the same bytes but a different layout stay distinct.
        if isinstance(obj, np.ndarray):
            return cls.item_to_bytes(
                "ndarray", {
                    "dtype": obj.dtype.str,
                    "shape": obj.shape,
                    "data": obj.tobytes(),
                })
        if isinstance(obj, torch.Tensor):
            return cls._serialize_tensor(obj)
        if isinstance(obj, Image.Image):
            # Hash the image as an RGBA pixel array (via convert_image_mode),
            # which goes through the ndarray encoding above.
            return cls.item_to_bytes(
                "image", np.asarray(convert_image_mode(obj, "RGBA")))

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    @classmethod
    def _serialize_tensor(cls, obj: torch.Tensor) -> bytes:
        """Serialize a tensor independently of its device and autograd state."""
        # Tensor.numpy() refuses tensors that require grad or are not in host
        # memory, so detach and copy to CPU first. This does not change the
        # bytes produced for plain CPU tensors.
        tensor = obj.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype: hash the raw bytes and record the
            # original dtype and shape so distinct tensors stay distinct.
            raw = tensor.contiguous().reshape(-1).view(torch.uint8)
            return cls.item_to_bytes(
                "tensor", {
                    "original_dtype": str(tensor.dtype),
                    "original_shape": tuple(tensor.shape),
                    "data": raw.numpy(),
                })
        # force=True additionally resolves lazy conjugate/negative views,
        # which plain .numpy() also rejects.
        return cls.item_to_bytes("tensor", tensor.numpy(force=True))

    @classmethod
    def item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> bytes:
        """Return the concatenation of all ``(key, value)`` byte pairs of
        ``obj`` as produced by :meth:`iter_item_to_bytes`."""
        return b''.join(kb + vb for kb, vb in cls.iter_item_to_bytes(key, obj))

    @classmethod
    def iter_item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        """Flatten ``obj`` into serialized ``(key, value)`` byte pairs.

        Lists and tuples recurse with keys ``f"{key}.{index}"``, dicts recurse
        with ``f"{key}.{dict_key}"`` in iteration order; every other value is
        a leaf serialized by :meth:`serialize_item`.

        NOTE(review): keys and values are emitted back to back with no length
        framing, so distinct inputs can yield the same byte stream, e.g.
        ``["a", "b"]`` and ``["ak.1b"]`` under key ``"k"`` both give
        ``b"k.0ak.1b"``. Adding framing would change every existing hash.
        """
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = cls.serialize_item(key)
            value_bytes = cls.serialize_item(obj)
            yield key_bytes, value_bytes

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Return the BLAKE3 hex digest of ``kwargs``.

        Each keyword is flattened with :meth:`iter_item_to_bytes` using the
        keyword name as the root key; keywords are hashed in call order.
        """
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in cls.iter_item_to_bytes(k, v):
                hasher.update(k_bytes)
                hasher.update(v_bytes)

        return hasher.hexdigest()

    @classmethod
    def hash_prompt_mm_data(
            cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]:
        """Hash multimodal data in the user input prompt if they exist.

        Args:
            prompt: Prompt that may carry a ``"multi_modal_data"`` entry
                mapping modality names to one item or a list of items.

        Returns:
            ``None`` when the entry is missing or falsy; otherwise a mapping
            from each modality to one hex digest per item, in input order.
        """

        if "multi_modal_data" not in prompt:
            return None

        mm_data = prompt["multi_modal_data"]
        if not mm_data:
            # mm_data can be None or an empty dict.
            return None

        # A single (non-list) item is treated as a one-element list.
        mm_items = {
            modality: items if isinstance(items, list) else [items]
            for modality, items in mm_data.items()
        }

        # The modality name is the root key, so identical items under
        # different modalities hash differently.
        mm_hashes = {
            modality: [cls.hash_kwargs(**{modality: item}) for item in items]
            for modality, items in mm_items.items()
        }

        return mm_hashes