loading_utils.py 3.78 KB
Newer Older
Dhruv Nair's avatar
Dhruv Nair committed
1
import os
2
3
import tempfile
from typing import Callable, List, Optional, Union
Dhruv Nair's avatar
Dhruv Nair committed
4
5
6
7
8

import PIL.Image
import PIL.ImageOps
import requests

9
10
from .import_utils import BACKENDS_MAPPING, is_opencv_available

Dhruv Nair's avatar
Dhruv Nair committed
11

12
def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string starting with `http://` or `https://` is
            downloaded; any other string must be the path to an existing local file.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When set to `None` the image will be converted
            to "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image. The EXIF orientation, if any, is applied before `convert_method` / the RGB conversion.

    Raises:
        ValueError: If `image` is a string that is neither an `http(s)` URL nor an existing file, or if `image` is
            neither a string nor a PIL Image.
    """
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Without an explicit timeout `requests` can block indefinitely on an unresponsive server.
            image = PIL.Image.open(requests.get(image, stream=True, timeout=60).raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    # Apply the EXIF orientation tag so the pixel data matches how the image is meant to be displayed.
    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is not None:
        image = convert_method(image)
    else:
        image = image.convert("RGB")

    return image
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122


def _download_video_to_tempfile(url: str) -> str:
    """Download `url` into a new named temporary file and return its path; the caller is responsible for deleting it."""
    import shutil
    from urllib.parse import urlparse

    # Take the extension from the URL *path* so a query string or fragment (e.g. `clip.gif?raw=true`) does not leak
    # into the suffix: that would defeat the `.gif` dispatch and produce an invalid file name on some platforms.
    suffix = os.path.splitext(urlparse(url).path)[1]

    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly on HTTP errors instead of decoding an error page as a (zero-frame) video.
        response.raise_for_status()
        tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        try:
            with tmp:
                shutil.copyfileobj(response.raw, tmp)
        except BaseException:
            # Do not leave a partially written file behind if the download fails.
            os.remove(tmp.name)
            raise
    return tmp.name


def _load_gif_frames(path: str) -> List[PIL.Image.Image]:
    """Return an independent copy of every frame of the GIF at `path`, in order."""
    frames = []
    # The context manager closes the file handle; `copy()` makes each frame independent of it.
    with PIL.Image.open(path) as gif:
        try:
            while True:
                frames.append(gif.copy())
                gif.seek(gif.tell() + 1)
        except EOFError:
            pass  # PIL signals "no more frames" by raising EOFError on seek.
    return frames


def _load_video_frames_opencv(path: str) -> List[PIL.Image.Image]:
    """Decode every frame of the video at `path` with OpenCV and return them as RGB PIL Images."""
    if not is_opencv_available():
        raise ImportError(BACKENDS_MAPPING["opencv"][1].format("load_video"))
    import cv2

    frames = []
    capture = cv2.VideoCapture(path)
    try:
        success, frame = capture.read()
        while success:
            # OpenCV decodes frames in BGR channel order; PIL expects RGB.
            frames.append(PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            success, frame = capture.read()
    finally:
        capture.release()
    return frames


def load_video(
    video: str,
    convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
    """
    Loads `video` to a list of PIL Image.

    Args:
        video (`str`):
            A URL or Path to a video to convert to a list of PIL Image format. URLs must start with `http://` or
            `https://` and are downloaded to a temporary file that is removed afterwards. Files ending in `.gif` are
            decoded with PIL; any other file is decoded with OpenCV, which must be installed.
        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
            A conversion method to apply to the video after loading it. When set to `None` the images will be converted
            to "RGB".

    Returns:
        `List[PIL.Image.Image]`:
            The video as a list of PIL images.

    Raises:
        ValueError: If `video` is neither an `http(s)` URL nor an existing file.
        ImportError: If a non-GIF video is given and OpenCV is not available.
        requests.HTTPError: If downloading `video` returns an HTTP error status.
    """
    is_url = video.startswith(("http://", "https://"))
    is_file = os.path.isfile(video)

    if not (is_url or is_file):
        raise ValueError(
            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
        )

    video_path = _download_video_to_tempfile(video) if is_url else video
    try:
        if video_path.lower().endswith(".gif"):
            pil_images = _load_gif_frames(video_path)
        else:
            pil_images = _load_video_frames_opencv(video_path)
    finally:
        # Always clean up the downloaded copy, even if decoding failed.
        if is_url:
            os.remove(video_path)

    if convert_method is not None:
        pil_images = convert_method(pil_images)
    else:
        # GIF frames typically come back in palette ("P") mode; honour the documented RGB contract.
        pil_images = [frame if frame.mode == "RGB" else frame.convert("RGB") for frame in pil_images]

    return pil_images