hasher.py 4.01 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import pickle
5
import uuid
6
from collections.abc import Iterable
7
8
9
10
11
12
13
14

import numpy as np
import torch
from blake3 import blake3
from PIL import Image

from vllm.logger import init_logger

15
16
from .base import MediaWithBytes

17
18
19
20
21
# Module-level logger; MultiModalHasher uses it to warn when an object has no
# dedicated serializer and is hashed via pickle instead.
logger = init_logger(__name__)


class MultiModalHasher:
    @classmethod
22
    def serialize_item(cls, obj: object) -> Iterable[bytes | memoryview]:
23
        # Simple cases
24
        if isinstance(obj, (bytes, memoryview)):
25
            return (obj,)
26
        if isinstance(obj, str):
27
            return (obj.encode("utf-8"),)
28
        if isinstance(obj, (int, float)):
29
            return (np.array(obj).tobytes(),)
30

31
        if isinstance(obj, Image.Image):
32
33
            exif = obj.getexif()
            if Image.ExifTags.Base.ImageID in exif and isinstance(
34
35
36
                exif[Image.ExifTags.Base.ImageID], uuid.UUID
            ):
                return (exif[Image.ExifTags.Base.ImageID].bytes,)
37

38
            data = {"mode": obj.mode, "data": np.asarray(obj)}
39
40
41
42
43
44
            palette = obj.palette
            if palette is not None:
                data["palette"] = palette.palette
                if palette.rawmode is not None:
                    data["palette_rawmode"] = palette.rawmode

45
            return cls.iter_item_to_bytes("image", data)
46
47
48
49
50
51
52
53
54
55

        if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
            exif = obj.media.getexif()
            if Image.ExifTags.Base.ImageID in exif and isinstance(
                exif[Image.ExifTags.Base.ImageID], uuid.UUID
            ):
                return (exif[Image.ExifTags.Base.ImageID].bytes,)

            return cls.iter_item_to_bytes("image", obj.original_bytes)

56
        if isinstance(obj, torch.Tensor):
57
58
            tensor_obj: torch.Tensor = obj.cpu()
            tensor_dtype = tensor_obj.dtype
59
60
61
62
            tensor_shape = tensor_obj.shape

            # NumPy does not support bfloat16.
            # Workaround: View the tensor as a contiguous 1D array of bytes
63
64
            if tensor_dtype == torch.bfloat16:
                tensor_obj = tensor_obj.contiguous()
65
                tensor_obj = tensor_obj.view((tensor_obj.numel(),)).view(torch.uint8)
66

67
                return cls.iter_item_to_bytes(
68
69
                    "tensor",
                    {
70
                        "original_dtype": str(tensor_dtype),
71
72
                        "original_shape": tuple(tensor_shape),
                        "data": tensor_obj.numpy(),
73
74
                    },
                )
75
            return cls.iter_item_to_bytes("tensor", tensor_obj.numpy())
76
        if isinstance(obj, np.ndarray):
77
            # If the array is non-contiguous, we need to copy it first
78
79
80
81
82
83
84
85
86
87
88
            arr_data = (
                obj.view(np.uint8).data if obj.flags.c_contiguous else obj.tobytes()
            )
            return cls.iter_item_to_bytes(
                "ndarray",
                {
                    "dtype": obj.dtype.str,
                    "shape": obj.shape,
                    "data": arr_data,
                },
            )
89
        logger.warning(
90
91
            "No serialization method found for %s. Falling back to pickle.", type(obj)
        )
92

93
        return (pickle.dumps(obj),)
94
95
96
97
98
99

    @classmethod
    def iter_item_to_bytes(
        cls,
        key: str,
        obj: object,
100
    ) -> Iterable[bytes | memoryview]:
101
102
103
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
104
                yield from cls.iter_item_to_bytes(f"{key}.{i}", elem)
105
106
        elif isinstance(obj, dict):
            for k, v in obj.items():
107
                yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
108
        else:
109
110
            yield key.encode("utf-8")
            yield from cls.serialize_item(obj)
111
112
113
114
115
116

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        hasher = blake3()

        for k, v in kwargs.items():
117
118
            for bytes_ in cls.iter_item_to_bytes(k, v):
                hasher.update(bytes_)
119
120

        return hasher.hexdigest()