base.py 1.65 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from dataclasses import dataclass
6
from pathlib import Path
7
from typing import Generic, TypeVar
8

9
10
import numpy as np

11
_T = TypeVar("_T")
12
13


14
15
16
17
18
19
20
21
22
23
@dataclass
class MediaWithBytes(Generic[_T]):
    """
    Wrapper that couples a media object with its original encoded bytes.

    This ensures the raw bytes and media object remain synchronized,
    preventing cache corruption from in-place modifications.

    The wrapper delegates attribute access to the underlying media object,
    making it behave transparently like the wrapped type (e.g., PIL.Image).

    NOTE: Currently, this wrapper is used only for the image modality.
    """

    media: _T
    original_bytes: bytes

    def __array__(self, *args, **kwargs) -> np.ndarray:
        """Allow np.array(obj) to return np.array(obj.media)."""
        return np.array(self.media, *args, **kwargs)

    def __getattr__(self, name: str):
        """Delegate attribute access to the underlying media object.

        Python calls this only when normal lookup on the wrapper fails.
        Some names are deliberately *not* delegated and raise
        ``AttributeError`` instead:

        * The wrapper's own fields (``media``, ``original_bytes``). Reaching
          this method for one of them means the instance was created without
          ``__init__`` (``copy`` and ``pickle`` do this) and the field is not
          set yet. Delegating would evaluate ``self.media`` again and recurse
          until ``RecursionError``.
        * Copy/pickle protocol hooks. ``copy`` and ``pickle`` look some of
          these up on the instance. Delegating them would run the media
          object's hooks in place of the wrapper's. For example,
          ``copy.deepcopy`` of a wrapped ``np.ndarray`` would return a bare
          array.

        Raises:
            AttributeError: for the names above, or when the media object
                itself lacks ``name``.
        """
        # A set literal of constants is compiled to a frozenset constant,
        # so this membership test does not rebuild the set on every call.
        if name in {
            "media",
            "original_bytes",
            "__copy__",
            "__deepcopy__",
            "__getstate__",
            "__setstate__",
        }:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.media, name)


41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class MediaIO(ABC, Generic[_T]):
    """Abstract loader interface producing media objects of type ``_T``.

    All three ``load_*`` methods are abstract, so a subclass cannot be
    instantiated until it overrides every one of them.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Build a media object from the raw bytes ``data``."""
        raise NotImplementedError()

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """Build a media object from base64 string ``data`` of ``media_type``.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError()

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Build a media object from the file located at ``filepath``."""
        raise NotImplementedError()