loading_utils.py 3.93 KB
Newer Older
Dhruv Nair's avatar
Dhruv Nair committed
1
import os
2
3
import tempfile
from typing import Callable, List, Optional, Union
Dhruv Nair's avatar
Dhruv Nair committed
4
5
6
7
8

import PIL.Image
import PIL.ImageOps
import requests

9
from .import_utils import BACKENDS_MAPPING, is_imageio_available
10

Dhruv Nair's avatar
Dhruv Nair committed
11

12
def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string starting with `http://` or `https://` is
            downloaded; any other string must be the path of an existing local file.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When set to `None` the image will be converted
            to "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image.

    Raises:
        ValueError: If `image` is a string that is neither an `http(s)` URL nor an existing file, or if `image` is
            neither a string nor a `PIL.Image.Image`.
        requests.HTTPError: If `image` is a URL and the server answers with an HTTP error status.
    """
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            response = requests.get(image, stream=True)
            # Fail with the real HTTP error instead of letting PIL try (and fail) to decode an error page.
            response.raise_for_status()
            image = PIL.Image.open(response.raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    # Apply the EXIF orientation tag (if any) so the pixel data matches the intended display orientation.
    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is not None:
        image = convert_method(image)
    else:
        image = image.convert("RGB")

    return image
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83


def _load_gif_frames(path: str) -> List[PIL.Image.Image]:
    """Read every frame of the GIF at `path` into a list of independent PIL images."""
    frames = []
    # The context manager closes the file handle; `copy()` gives each frame its own pixel buffer, so the frames stay
    # usable after the file is closed (and after a downloaded temporary file is deleted).
    with PIL.Image.open(path) as gif:
        try:
            while True:
                frames.append(gif.copy())
                gif.seek(gif.tell() + 1)
        except EOFError:
            # `seek` raises EOFError once it moves past the last frame: normal end of the sequence.
            pass
    return frames


def _load_imageio_frames(path: str) -> List[PIL.Image.Image]:
    """Decode every frame of the (non-GIF) video at `path` with imageio and convert each one to a PIL image."""
    if is_imageio_available():
        import imageio
    else:
        raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video"))

    try:
        imageio.plugins.ffmpeg.get_exe()
    except AttributeError as err:
        # An AttributeError here is treated as "no usable ffmpeg installation".
        raise AttributeError(
            "Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg`"
        ) from err

    with imageio.get_reader(path) as reader:
        return [PIL.Image.fromarray(frame) for frame in reader]


def _load_frames(path: str) -> List[PIL.Image.Image]:
    """Dispatch on the file extension: GIFs are read with PIL, everything else with imageio."""
    if path.lower().endswith(".gif"):
        return _load_gif_frames(path)
    return _load_imageio_frames(path)


def load_video(
    video: str,
    convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
    """
    Loads `video` to a list of PIL Image.

    Args:
        video (`str`):
            A URL or Path to a video to convert to a list of PIL Image format. URLs must start with `http://` or
            `https://`; they are downloaded to a temporary file which is always removed before this function returns.
        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
            A conversion method to apply to the list of frames after loading them. When set to `None` the frames are
            returned as decoded, without any conversion.

    Returns:
        `List[PIL.Image.Image]`:
            The video as a list of PIL images.

    Raises:
        ValueError: If `video` is neither an `http(s)` URL nor an existing file.
        requests.HTTPError: If `video` is a URL and the server answers with an HTTP error status.
        ImportError: If a non-GIF video is given and `imageio` is not available.
        AttributeError: If a non-GIF video is given and no ffmpeg executable can be found.
    """
    is_url = video.startswith(("http://", "https://"))
    is_file = os.path.isfile(video)

    if not (is_url or is_file):
        raise ValueError(
            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
        )

    if is_url:
        from urllib.parse import urlparse

        response = requests.get(video, stream=True)
        response.raise_for_status()
        # Take the extension from the URL *path* only, so a query string or fragment (e.g. "clip.gif?raw=true")
        # neither leaks into the temporary file name nor defeats the ".gif" check.
        suffix = os.path.splitext(urlparse(video).path)[1] or ".mp4"
        # delete=False: the file must survive being closed so the decoder can reopen it by name.
        tmp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        try:
            with tmp_file:
                tmp_file.write(response.raw.read())
            pil_images = _load_frames(tmp_file.name)
        finally:
            # Remove the download even when writing or decoding fails.
            os.remove(tmp_file.name)
    else:
        pil_images = _load_frames(video)

    if convert_method is not None:
        pil_images = convert_method(pil_images)

    return pil_images