base.py 2.49 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass, field
6
from pathlib import Path
7
from typing import Any, Generic, TypeVar
8

9
10
import numpy as np

11
_T = TypeVar("_T")


@dataclass
class MediaWithBytes(Generic[_T]):
    """
    Wrapper that couples a media object with its original encoded bytes.

    This ensures the raw bytes and media object remain synchronized,
    preventing cache corruption from in-place modifications.

    The wrapper delegates attribute access to the underlying media object,
    making it behave transparently like the wrapped type (e.g., PIL.Image).

    NOTE: Currently, this wrapper is used only for the image modality.
    """

    # The wrapped media object; unknown attribute lookups are forwarded to it.
    media: _T
    # The encoded bytes that accompany `media`; excluded from the dataclass
    # repr to keep it short.
    original_bytes: bytes = field(repr=False)

    def __array__(self, *args, **kwargs) -> np.ndarray:
        """Allow np.array(obj) to return np.array(obj.media)."""
        return np.array(self.media, *args, **kwargs)

    def __getstate__(self) -> dict[str, Any]:
        """Return a shallow copy of the instance dict as the pickled state."""
        return self.__dict__.copy()

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the instance dict from ``state``.

        Because this method is defined on the class, pickle/copy find it by
        normal attribute lookup and never consult ``__getattr__`` on the
        not-yet-populated instance.
        """
        self.__dict__.update(state)

    def __getattr__(self, name: str) -> Any:
        """Delegate attribute access to the underlying media object.

        Python only calls this for attributes that normal lookup did not
        find. If ``media`` itself is missing (e.g. an instance created via
        ``__new__`` whose state has not been restored yet), raise
        ``AttributeError`` instead of recursing forever through
        ``self.media``.

        Raises:
            AttributeError: if ``media`` is unset, or if the media object
                has no attribute ``name``.
        """
        if name == "media":
            # Reaching here means `media` is absent from the instance dict;
            # evaluating `self.media` would re-enter this method endlessly.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'media'"
            )
        return getattr(self.media, name)


46
class MediaIO(ABC, Generic[_T]):
    """Abstract loader that produces a media object of type ``_T`` from raw
    bytes, a base64 string, or a file path.

    Configuration values can be user-provided either by --media-io-kwargs or
    by the runtime API field "media_io_kwargs". Ensure proper validation and
    error handling.
    """

    @classmethod
    def merge_kwargs(
        cls,
        default_kwargs: dict[str, Any] | None,
        runtime_kwargs: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Merge config-level kwargs and request-level kwargs.

        By default this performs a shallow merge where runtime kwargs override
        keys in default kwargs. Subclasses may override to apply modality-
        specific behavior.

        Args:
            default_kwargs: Config-level kwargs, or ``None`` if there are none.
            runtime_kwargs: Request-level kwargs, or ``None`` if there are
                none.

        Returns:
            A new dict. Neither input mapping is mutated; nested values are
            shared, not copied.
        """
        # Copy first so that the update below never mutates the caller's
        # default mapping.
        merged = dict(default_kwargs or {})
        if runtime_kwargs:
            merged.update(runtime_kwargs)
        return merged

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Load a media object from its raw encoded ``data``."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """
        Load a media object from the base64 string ``data``, whose media
        (MIME) type is given by ``media_type``.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Load a media object from the file at ``filepath``."""
        raise NotImplementedError