# loading_utils.py
import os
2
import tempfile
3
from typing import Any, Callable, List, Optional, Tuple, Union
4
from urllib.parse import unquote, urlparse
Dhruv Nair's avatar
Dhruv Nair committed
5
6
7
8
9

import PIL.Image
import PIL.ImageOps
import requests

10
from .import_utils import BACKENDS_MAPPING, is_imageio_available
11

Dhruv Nair's avatar
Dhruv Nair committed
12

13
def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string starting with `http://` or `https://` is
            fetched as a URL; any other string must be the path of an existing local file.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When set to `None` the image will be converted
            to "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image.

    Raises:
        ValueError:
            If `image` is a string that is neither an http(s) URL nor an existing file, or if `image` is neither a
            string nor a PIL image.
    """
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Hand the raw (undecoded) response stream to PIL so it reads the bytes directly.
            image = PIL.Image.open(requests.get(image, stream=True).raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    # Bake any EXIF orientation into the pixel data before converting.
    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is not None:
        return convert_method(image)
    return image.convert("RGB")
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83


def load_video(
    video: str,
    convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
    """
    Loads `video` to a list of PIL Image.

    Args:
        video (`str`):
            A URL or Path to a video to convert to a list of PIL Image format.
        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
            A conversion method to apply to the list of frames after loading it. When set to `None` the frames are
            returned as decoded, without any further conversion.

    Returns:
        `List[PIL.Image.Image]`:
            The video as a list of PIL images.

    Raises:
        ValueError:
            If `video` is neither an http(s) URL nor an existing file, or if the download does not return HTTP 200.
        ImportError:
            If a non-GIF video is given and `imageio` is not available.
        AttributeError:
            If `imageio` is available but its ffmpeg plugin cannot be located.
    """
    is_url = video.startswith(("http://", "https://"))
    is_file = os.path.isfile(video)

    if not (is_url or is_file):
        raise ValueError(
            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
        )

    # Path of the temporary download, if any; removed in `finally` even when decoding fails.
    downloaded_path = None
    try:
        if is_url:
            with requests.get(video, stream=True) as response:
                if response.status_code != 200:
                    raise ValueError(f"Failed to download video. Status code: {response.status_code}")

                # Keep the remote file's extension so the decoder can pick the right format.
                file_name = os.path.basename(unquote(urlparse(video).path))
                suffix = os.path.splitext(file_name)[1] or ".mp4"

                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
                    downloaded_path = tmp_file.name
                    for chunk in response.iter_content(chunk_size=8192):
                        tmp_file.write(chunk)

            video = downloaded_path

        pil_images = []
        if video.endswith(".gif"):
            # Close the GIF handle before the file is removed; the copied frames stay valid.
            with PIL.Image.open(video) as gif:
                try:
                    while True:
                        pil_images.append(gif.copy())
                        gif.seek(gif.tell() + 1)
                except EOFError:
                    # Seeking past the last frame ends the loop.
                    pass
        else:
            if is_imageio_available():
                import imageio
            else:
                raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video"))

            try:
                imageio.plugins.ffmpeg.get_exe()
            except AttributeError:
                raise AttributeError(
                    "`Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg"
                )

            with imageio.get_reader(video) as reader:
                # Read all frames
                for frame in reader:
                    pil_images.append(PIL.Image.fromarray(frame))
    finally:
        if downloaded_path is not None:
            os.remove(downloaded_path)

    if convert_method is not None:
        pil_images = convert_method(pil_images)

    return pil_images
138
139
140
141
142
143
144
145
146
147
148
149
150


# Taken from `transformers`.
def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
    """
    Resolve a dotted attribute path down to its parent object and final name.

    Args:
        module:
            The object at which the lookup starts.
        tensor_name (`str`):
            A possibly dotted name such as `"block.layer.weight"`.

    Returns:
        `Tuple[Any, str]`:
            The object reached by following every component but the last, and the last component itself. A name
            without dots returns `module` unchanged together with `tensor_name`.

    Raises:
        ValueError:
            If an intermediate attribute exists but is `None`. A missing attribute surfaces as the `AttributeError`
            raised by `getattr`.
    """
    *parent_names, leaf_name = tensor_name.split(".")
    for name in parent_names:
        child = getattr(module, name)
        if child is None:
            raise ValueError(f"{module} has no attribute {name}.")
        module = child
    return module, leaf_name