# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pickle
import uuid
from collections.abc import Iterable
from typing import Union

import numpy as np
import torch
from blake3 import blake3
from PIL import Image

from vllm.logger import init_logger
from vllm.multimodal.image import convert_image_mode

# Module-level logger, used to warn when an object has no dedicated
# serialization path and pickle is used instead.
logger = init_logger(__name__)


class MultiModalHasher:
    """Content hashing for (possibly nested) multimodal inputs.

    Values are flattened into ``(key, value)`` byte pairs, where the key
    encodes the path to the value (``"name.0.field"``) and the value is the
    serialized leaf. ``hash_kwargs`` feeds those pairs into a BLAKE3 hasher.
    """

    @classmethod
    def serialize_item(cls, obj: object) -> Union[bytes, memoryview]:
        """Serialize a single leaf value into a bytes-like object.

        Supported types, in the order they are checked:

        * ``str`` -- UTF-8 encoded.
        * ``bytes`` / ``memoryview`` -- returned unchanged.
        * ``int`` / ``float`` -- raw bytes of the equivalent NumPy scalar;
          ints too large for a 64-bit NumPy integer use a little-endian
          two's-complement encoding instead.
        * ``torch.Tensor`` -- detached, moved to CPU and serialized through
          its NumPy view (bfloat16 is reinterpreted as raw ``uint8``).
        * ``np.ndarray`` -- dtype string, shape and raw data.
        * ``PIL.Image.Image`` -- the EXIF ``ImageID`` if it holds a
          ``uuid.UUID``, otherwise the pixel data converted to RGBA.

        Anything else is pickled, and a warning is logged.

        Args:
            obj: The value to serialize.

        Returns:
            The serialized representation of ``obj``.
        """
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, (bytes, memoryview)):
            return obj
        if isinstance(obj, (int, float)):
            scalar = np.array(obj)
            if scalar.dtype.kind != "O":
                return scalar.tobytes()
            # Python ints outside the 64-bit range become object-dtype
            # arrays; their raw bytes would be a pointer, not the value,
            # which makes the hash non-deterministic. Encode the value
            # itself instead.
            big = int(obj)
            return big.to_bytes((big.bit_length() + 8) // 8,
                                "little",
                                signed=True)

        # Tensors/arrays are checked before PIL images; the types are
        # disjoint, so the order does not affect the result.
        if isinstance(obj, torch.Tensor):
            # detach() first: .numpy() raises on tensors that require grad.
            tensor_obj: torch.Tensor = obj.detach().cpu()
            tensor_dtype = tensor_obj.dtype
            tensor_shape = tensor_obj.shape

            # NumPy does not support bfloat16.
            # Workaround: View the tensor as a contiguous 1D array of bytes
            if tensor_dtype == torch.bfloat16:
                tensor_obj = tensor_obj.contiguous()
                tensor_obj = tensor_obj.view(
                    (tensor_obj.numel(), )).view(torch.uint8)

                # The byte view loses dtype and shape, so record them
                # explicitly alongside the data.
                return cls.item_to_bytes(
                    "tensor", {
                        "original_dtype": str(tensor_dtype),
                        "original_shape": tuple(tensor_shape),
                        "data": tensor_obj.numpy(),
                    })

            return cls.item_to_bytes("tensor", tensor_obj.numpy())
        if isinstance(obj, np.ndarray):
            # If the array is non-contiguous, we need to copy it first
            arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
            return cls.item_to_bytes("ndarray", {
                "dtype": obj.dtype.str,
                "shape": obj.shape,
                "data": arr_data,
            })

        if isinstance(obj, Image.Image):
            exif = obj.getexif()
            if Image.ExifTags.Base.ImageID in exif and isinstance(
                    exif[Image.ExifTags.Base.ImageID], uuid.UUID):
                # If the image has exif ImageID tag, use that
                return exif[Image.ExifTags.Base.ImageID].bytes

            return cls.item_to_bytes(
                "image", np.asarray(convert_image_mode(obj, "RGBA")))

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    @classmethod
    def item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> bytes:
        """Serialize ``obj`` under ``key`` into one contiguous byte string.

        Args:
            key: Path prefix used for the keys of all leaves of ``obj``.
            obj: A leaf value or a nested list/tuple/dict of leaf values.

        Returns:
            The concatenation of ``key_bytes + value_bytes`` for every pair
            produced by ``iter_item_to_bytes``.
        """
        return b''.join(kb + vb for kb, vb in cls.iter_item_to_bytes(key, obj))

    @classmethod
    def iter_item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]:
        """Flatten ``obj`` into ``(key_bytes, value_bytes)`` pairs.

        Lists and tuples are expanded with the element index appended to
        the key (``"key.0"``); dicts are expanded with the dict key appended
        (``"key.name"``). Any other value is a leaf and is passed to
        ``serialize_item``. Empty containers yield nothing.

        Args:
            key: Path of ``obj`` within the structure being hashed.
            obj: The value to flatten.

        Yields:
            One ``(utf-8 encoded path, serialized value)`` pair per leaf.
        """
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = key.encode("utf-8")
            value_bytes = cls.serialize_item(obj)
            yield key_bytes, value_bytes

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        """Hash the given keyword arguments with BLAKE3.

        Keyword arguments are processed in the order they were passed, so
        the same values passed in a different order give a different hash.

        NOTE(review): keys and values are fed to the hasher back to back
        with no delimiter or length prefix, so distinct inputs such as
        ``a="bc"`` and ``ab="c"`` produce the same byte stream. Changing
        this would change every existing hash value, so it is left as is.

        Args:
            **kwargs: Values to hash; each may be a leaf value or a nested
                list/tuple/dict of leaf values.

        Returns:
            The hex digest of the hash.
        """
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in cls.iter_item_to_bytes(k, v):
                hasher.update(k_bytes)
                hasher.update(v_bytes)

        return hasher.hexdigest()