from __future__ import annotations

from typing import Any, Callable, List, Tuple, Type, Union

import PIL.Image
from torchvision import datapoints

from torchvision._utils import sequence_to_str
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor


def query_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
    bounding_boxes = [inpt for inpt in flat_inputs if isinstance(inpt, datapoints.BoundingBoxes)]
    if not bounding_boxes:
        raise TypeError("No bounding boxes were found in the sample")
    elif len(bounding_boxes) > 1:
        raise ValueError("Found multiple bounding boxes instances in the sample")
    return bounding_boxes.pop()
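
# Illustrative usage sketch (not part of the original module): `query_bounding_boxes`
# expects exactly one `datapoints.BoundingBoxes` instance in the flattened sample, e.g.
#
#   boxes = datapoints.BoundingBoxes(torch.tensor([[0, 0, 10, 10]]), format="XYXY", canvas_size=(32, 32))
#   query_bounding_boxes([image, boxes, {"labels": labels}])  # -> boxes
#
# The `format`/`canvas_size` keyword names are an assumption about the datapoints API
# and may differ between torchvision versions.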


def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
    chws = {
        tuple(get_dimensions(inpt))
        for inpt in flat_inputs
        if check_type(inpt, (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
    }
    if not chws:
        raise TypeError("No image or video was found in the sample")
    elif len(chws) > 1:
        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
    c, h, w = chws.pop()
    return c, h, w
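
# Illustrative usage sketch (not part of the original module), assuming the
# `datapoints.Image` constructor accepts a CHW tensor as in recent torchvision versions:
#
#   import torch
#   sample = [datapoints.Image(torch.rand(3, 32, 32)), {"label": 7}]
#   c, h, w = query_chw(sample)  # -> (3, 32, 32)
#
# All image/video entries in the sample must agree on their CxHxW dimensions,
# otherwise a ValueError is raised.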


def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
    sizes = {
        tuple(get_size(inpt))
        for inpt in flat_inputs
        if check_type(
            inpt,
            (
                is_simple_tensor,
                datapoints.Image,
                PIL.Image.Image,
                datapoints.Video,
                datapoints.Mask,
                datapoints.BoundingBoxes,
            ),
        )
    }
    if not sizes:
        raise TypeError("No image, video, mask or bounding box was found in the sample")
    elif len(sizes) > 1:
        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
    h, w = sizes.pop()
    return h, w
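
# Illustrative usage sketch (not part of the original module): unlike `query_chw`,
# `query_size` also accepts masks and bounding boxes, which carry a height and width
# but no channel dimension:
#
#   import torch
#   sample = [datapoints.Image(torch.rand(3, 32, 32)), datapoints.Mask(torch.zeros(32, 32))]
#   h, w = query_size(sample)  # -> (32, 32)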


def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
    for type_or_check in types_or_checks:
        # Each entry is either a type (checked with isinstance) or a predicate callable
        # that is applied to `obj` directly.
        if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
            return True
    return False


def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    for inpt in flat_inputs:
        if check_type(inpt, types_or_checks):
            return True
    return False


def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    for type_or_check in types_or_checks:
        for inpt in flat_inputs:
            if isinstance(inpt, type_or_check) if isinstance(type_or_check, type) else type_or_check(inpt):
                break
        else:
            # The inner loop never hit `break`: no input matched this type/check,
            # so not all requested types are present in the sample.
            return False
    return True
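

# Minimal self-check sketch (illustrative, not part of the original module): exercises
# `check_type`, `has_any` and `has_all` with a plain tensor and a PIL image. It assumes
# `is_simple_tensor` returns True for a bare torch.Tensor that is not a datapoint.
if __name__ == "__main__":
    import torch

    sample = [torch.rand(3, 8, 8), PIL.Image.new("RGB", (8, 8))]

    assert check_type(sample[0], (is_simple_tensor,))
    assert has_any(sample, PIL.Image.Image)
    assert has_all(sample, is_simple_tensor, PIL.Image.Image)
    assert not has_all(sample, datapoints.Mask)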