loading_utils.py 5.18 KB
Newer Older
Dhruv Nair's avatar
Dhruv Nair committed
1
import os
2
import tempfile
3
from typing import Any, Callable, List, Optional, Tuple, Union
4
from urllib.parse import unquote, urlparse
Dhruv Nair's avatar
Dhruv Nair committed
5
6
7
8
9

import PIL.Image
import PIL.ImageOps
import requests

10
from .constants import DIFFUSERS_REQUEST_TIMEOUT
11
from .import_utils import BACKENDS_MAPPING, is_imageio_available
12

Dhruv Nair's avatar
Dhruv Nair committed
13

14
def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string starting with `http://` or `https://` is
            downloaded; any other string must be the path of an existing local file.
        convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*):
            A conversion method to apply to the image after loading it. When set to `None` the image will be converted
            to "RGB".

    Returns:
        `PIL.Image.Image`:
            A PIL Image, transposed according to its EXIF orientation tag (via `PIL.ImageOps.exif_transpose`).

    Raises:
        ValueError: If `image` is a string that is neither an `http(s)` URL nor an existing file, or if `image` is
            neither a string nor a PIL Image.
        requests.HTTPError: If downloading `image` returns a 4xx/5xx HTTP status.
    """
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            response = requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
            # Surface HTTP errors directly instead of letting PIL fail on an error page's body.
            response.raise_for_status()
            # `response.raw` is the undecoded byte stream; ask urllib3 to undo any gzip/deflate
            # Content-Encoding so PIL receives the actual image bytes.
            response.raw.decode_content = True
            image = PIL.Image.open(response.raw)
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            raise ValueError(
                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
            )
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
        )

    # Rotate/flip the pixels according to the EXIF Orientation tag, if present.
    image = PIL.ImageOps.exif_transpose(image)

    if convert_method is not None:
        image = convert_method(image)
    else:
        image = image.convert("RGB")

    return image
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84


def load_video(
    video: str,
    convert_method: Optional[Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]]] = None,
) -> List[PIL.Image.Image]:
    """
    Loads `video` to a list of PIL Image.

    Args:
        video (`str`):
            A URL or Path to a video to convert to a list of PIL Image format. URLs must start with `http://` or
            `https://` and are downloaded to a temporary file that is deleted before returning. Paths ending in
            `.gif` are decoded frame by frame with PIL; anything else is decoded with `imageio` (ffmpeg plugin).
        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
            A conversion method to apply to the list of frames after loading it. When set to `None` the frames are
            returned as decoded, without any mode conversion.

    Returns:
        `List[PIL.Image.Image]`:
            The video as a list of PIL images.

    Raises:
        ValueError: If `video` is neither an `http(s)` URL nor an existing file, or if the download does not return
            HTTP status 200.
        ImportError: If a non-GIF video is given and `imageio` is not available.
        AttributeError: If `imageio` cannot locate an ffmpeg executable.
    """
    is_url = video.startswith(("http://", "https://"))
    is_file = os.path.isfile(video)

    if not (is_url or is_file):
        raise ValueError(
            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
        )

    # Path of the temporary file holding a downloaded video; stays `None` for local inputs.
    downloaded_path = None
    try:
        if is_url:
            # Same timeout as `load_image`, so a stalled server cannot block the caller indefinitely.
            with requests.get(video, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT) as response:
                if response.status_code != 200:
                    raise ValueError(f"Failed to download video. Status code: {response.status_code}")

                # Keep the URL's extension (default `.mp4`) so the `.gif` check below and imageio's
                # format detection see a meaningful suffix.
                file_name = os.path.basename(unquote(urlparse(video).path))
                suffix = os.path.splitext(file_name)[1] or ".mp4"

                # `delete=False`: the file is reopened by name after this handle is closed, and is
                # removed explicitly in the `finally` clause below.
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
                    downloaded_path = tmp_file.name
                    for chunk in response.iter_content(chunk_size=8192):
                        tmp_file.write(chunk)

            video = downloaded_path

        pil_images = []
        if video.endswith(".gif"):
            # The appended frames are independent copies, so the GIF file can be closed (before any
            # temp-file removal) once all frames are read.
            with PIL.Image.open(video) as gif:
                try:
                    while True:
                        pil_images.append(gif.copy())
                        gif.seek(gif.tell() + 1)
                except EOFError:
                    # Seeking past the last frame ends the loop.
                    pass
        else:
            if is_imageio_available():
                import imageio
            else:
                raise ImportError(BACKENDS_MAPPING["imageio"][1].format("load_video"))

            try:
                imageio.plugins.ffmpeg.get_exe()
            except AttributeError:
                raise AttributeError(
                    "Unable to find an ffmpeg installation on your machine. "
                    "Please install via `pip install imageio-ffmpeg`"
                )

            with imageio.get_reader(video) as reader:
                # Read all frames
                for frame in reader:
                    pil_images.append(PIL.Image.fromarray(frame))
    finally:
        # Remove the downloaded copy even when decoding (or the imageio/ffmpeg checks) fails.
        if downloaded_path is not None:
            os.remove(downloaded_path)

    if convert_method is not None:
        pil_images = convert_method(pil_images)

    return pil_images
139
140
141
142
143
144
145
146
147
148
149
150
151


# Taken from `transformers`.
def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
    """
    Resolves a dotted `tensor_name` to the object that owns the tensor and the tensor's local name.

    Every dot-separated component except the last is looked up with `getattr`, starting from `module`; the last
    component is returned unchanged. For example, `"a.b.weight"` resolves to `(module.a.b, "weight")`, and a name
    without dots resolves to `(module, tensor_name)`.

    Raises:
        ValueError: If an intermediate attribute exists but is `None`.
        AttributeError: If an intermediate attribute is missing (raised by `getattr`).
    """
    *owner_path, leaf_name = tensor_name.split(".")
    for attr in owner_path:
        child = getattr(module, attr)
        if child is None:
            raise ValueError(f"{module} has no attribute {attr}.")
        module = child
    return module, leaf_name
152
153
154
155
156
157
158
159
160
161
162
163


def get_submodule_by_name(root_module, module_path: str):
    """
    Walks a dotted `module_path` from `root_module` and returns the object it points to.

    Components made only of digits are used as integer indices (e.g., for nn.ModuleList or nn.Sequential); every
    other component is looked up with `getattr`. An empty `module_path` returns `root_module` itself, matching
    `torch.nn.Module.get_submodule("")`.

    Raises:
        AttributeError: If a non-numeric component is not an attribute of the current object.
        IndexError: If a numeric component is out of range for the current container.
    """
    # Without this guard, "".split(".") == [""] would turn into getattr(root_module, "").
    if not module_path:
        return root_module

    node = root_module
    for token in module_path.split("."):
        node = node[int(token)] if token.isdigit() else getattr(node, token)
    return node