# SPDX-License-Identifier: Apache-2.0

import base64
from io import BytesIO
from pathlib import Path

import torch
from PIL import Image

from .base import MediaIO


def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Scale both image dimensions by ``size_factor``, optionally transposing.

    Args:
        image: The source image.
        size_factor: Multiplier applied to the width and the height; each
            scaled dimension is truncated toward zero with ``int``.
        transpose: When non-negative, the integer value of an
            ``Image.Transpose`` member applied after resizing. Any negative
            value (the default ``-1``) skips the transpose step.

    Returns:
        The resized, and possibly transposed, image.
    """
    # NOTE(review): a very small size_factor can truncate a side to 0;
    # Pillow's resize is expected to reject that — confirm if it matters.
    scaled_size = (int(image.width * size_factor),
                   int(image.height * size_factor))
    resized = image.resize(scaled_size)

    if transpose < 0:
        return resized
    return resized.transpose(Image.Transpose(transpose))


class ImageMediaIO(MediaIO[Image.Image]):
    """Load and encode ``PIL.Image.Image`` media in a fixed colour mode.

    Every image this class loads or encodes is passed through
    ``Image.convert(self.image_mode)`` (``"RGB"`` by default).
    """

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()

        # Target mode handed to ``Image.convert`` on load and on encode.
        self.image_mode = image_mode

    def load_bytes(self, data: bytes) -> Image.Image:
        """Decode an encoded image from ``data`` in ``self.image_mode``."""
        # The context manager closes the underlying stream on exit; for
        # multi-frame formats Pillow otherwise keeps it open after load().
        with Image.open(BytesIO(data)) as image:
            # Image.open is lazy; force decoding so bad data fails here.
            image.load()
            # convert() returns a new image, so it stays usable after the
            # source image's stream is closed.
            return image.convert(self.image_mode)

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        """Decode base64 ``data`` and load it via :meth:`load_bytes`.

        ``media_type`` is accepted for interface compatibility and unused.
        """
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> Image.Image:
        """Read the image at ``filepath`` and return it in ``self.image_mode``."""
        # Use a context manager so the file handle is released even for
        # multi-frame images (e.g. GIF/TIFF), which Pillow leaves open.
        with Image.open(filepath) as image:
            image.load()
            return image.convert(self.image_mode)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Serialize ``media`` as ``image_format`` and return it as base64.

        The image is converted to ``self.image_mode`` before saving, so the
        chosen format must be able to store that mode (e.g. JPEG cannot
        store an alpha channel).
        """
        image = media.convert(self.image_mode)

        with BytesIO() as buffer:
            image.save(buffer, image_format)
            data = buffer.getvalue()

        return base64.b64encode(data).decode('utf-8')


class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Load and encode precomputed image embeddings as ``torch.Tensor``.

    Loading goes through ``torch.load(..., weights_only=True)``, which
    restricts unpickling to tensors and other allow-listed types.
    """

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize a tensor from ``torch.save``-formatted ``data``.

        The loaded object is not type-checked here; it is assumed to be a
        tensor.
        """
        buffer = BytesIO(data)
        # weights_only=True keeps torch.load from unpickling arbitrary
        # objects, which matters if ``data`` is untrusted input.
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Decode base64 ``data`` and load it via :meth:`load_bytes`.

        ``media_type`` is accepted for interface compatibility and unused.
        """
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Load a ``torch.save``-formatted tensor from ``filepath``."""
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Return the raw element buffer of ``media`` as a base64 string.

        NOTE(review): this emits the tensor's raw bytes without dtype or
        shape, which is not the ``torch.save`` format that
        :meth:`load_base64` expects, so the two do not round-trip. Confirm
        with consumers before changing either side.
        """
        # Tensor.numpy() rejects grad-tracking and non-CPU tensors, and
        # base64 needs a C-contiguous buffer. For a CPU, contiguous,
        # grad-free tensor each of these calls returns the same data, so
        # the encoded bytes are unchanged for inputs that already worked.
        host_tensor = media.detach().cpu().contiguous()
        return base64.b64encode(host_tensor.numpy()).decode('utf-8')