utils.py 2.67 KB
Newer Older
1
2
from __future__ import annotations

3
4
5
from typing import Any, Callable, List, Tuple, Type, Union

import PIL.Image
6
from torchvision import datapoints
7
8

from torchvision._utils import sequence_to_str
9
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
10
11


12
13
14
15
16
17
def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes:
    """Return the first ``datapoints.BoundingBoxes`` instance found in ``flat_inputs``.

    Args:
        flat_inputs: Flat list of sample items to search.

    Returns:
        The first item that is an instance of ``datapoints.BoundingBoxes``.

    Raises:
        ValueError: If no item in ``flat_inputs`` is a ``datapoints.BoundingBoxes``.
    """
    # This assumes there is only one bbox per sample as per the general convention,
    # so the first match is returned and any later ones are ignored.
    for inpt in flat_inputs:
        if isinstance(inpt, datapoints.BoundingBoxes):
            return inpt
    # Raised outside of any ``except`` clause so the error does not carry an
    # unrelated ``StopIteration`` as implicit exception context.
    raise ValueError("No bounding boxes were found in the sample")
18
19
20
21
22
23


def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]:
    """Return the single ``(c, h, w)`` dimensions shared by the images/videos in a sample.

    Only items that are pure tensors (per ``is_pure_tensor``), ``datapoints.Image``,
    ``PIL.Image.Image`` or ``datapoints.Video`` are considered; everything else in
    ``flat_inputs`` is ignored.

    Args:
        flat_inputs: Flat list of sample items.

    Returns:
        The ``(c, h, w)`` triple obtained from ``get_dimensions`` for the matching items.

    Raises:
        TypeError: If no matching item is present in ``flat_inputs``.
        ValueError: If the matching items do not all report the same dimensions.
    """
    # Collect into a set so identical dimensions collapse to a single entry.
    chws = {
        tuple(get_dimensions(inpt))
        for inpt in flat_inputs
        if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video))
    }
    if not chws:
        raise TypeError("No image or video was found in the sample")
    elif len(chws) > 1:
        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
    # Exactly one entry is left; unpacking also enforces that it has three components.
    c, h, w = chws.pop()
    return c, h, w


Philip Meier's avatar
Philip Meier committed
34
def query_size(flat_inputs: List[Any]) -> Tuple[int, int]:
    """Return the single ``(h, w)`` size shared by the sized items in a sample.

    Only items that are pure tensors (per ``is_pure_tensor``), ``datapoints.Image``,
    ``PIL.Image.Image``, ``datapoints.Video``, ``datapoints.Mask`` or
    ``datapoints.BoundingBoxes`` are considered; everything else in ``flat_inputs``
    is ignored.

    Args:
        flat_inputs: Flat list of sample items.

    Returns:
        The ``(h, w)`` pair obtained from ``get_size`` for the matching items.

    Raises:
        TypeError: If no matching item is present in ``flat_inputs``.
        ValueError: If the matching items do not all report the same size.
    """
    # Collect into a set so identical sizes collapse to a single entry.
    sizes = {
        tuple(get_size(inpt))
        for inpt in flat_inputs
        if check_type(
            inpt,
            (
                is_pure_tensor,
                datapoints.Image,
                PIL.Image.Image,
                datapoints.Video,
                datapoints.Mask,
                datapoints.BoundingBoxes,
            ),
        )
    }
    if not sizes:
        raise TypeError("No image, video, mask or bounding box was found in the sample")
    elif len(sizes) > 1:
        raise ValueError(f"Found multiple HxW dimensions in the sample: {sequence_to_str(sorted(sizes))}")
    # Exactly one entry is left; unpacking also enforces that it has two components.
    h, w = sizes.pop()
    return h, w


def check_type(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
    """Tell whether ``obj`` satisfies at least one entry of ``types_or_checks``.

    Each entry is either a class, which is tested with ``isinstance``, or a
    callable predicate, which is called with ``obj``. Entries are tried in order
    and evaluation stops at the first one that matches.

    Args:
        obj: The object to test.
        types_or_checks: Classes and/or predicates to test ``obj`` against.

    Returns:
        ``True`` if any entry matches, ``False`` otherwise (including when
        ``types_or_checks`` is empty).
    """

    def _matches(candidate: Union[Type, Callable[[Any], bool]]) -> bool:
        # Classes are matched by instance check, anything else is treated as a predicate.
        if isinstance(candidate, type):
            return isinstance(obj, candidate)
        return candidate(obj)

    return any(_matches(candidate) for candidate in types_or_checks)


def has_any(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    """Tell whether at least one item of ``flat_inputs`` matches any of ``types_or_checks``.

    Matching of a single item is delegated to ``check_type``. Items are examined
    in order and the search stops at the first match.

    Args:
        flat_inputs: Flat list of sample items.
        *types_or_checks: Classes and/or predicates an item may satisfy.

    Returns:
        ``True`` if some item matches, ``False`` otherwise (including when
        ``flat_inputs`` is empty).
    """
    return any(check_type(item, types_or_checks) for item in flat_inputs)


def has_all(flat_inputs: List[Any], *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
    """Tell whether every entry of ``types_or_checks`` is matched by some item of ``flat_inputs``.

    Each entry is either a class, tested with ``isinstance``, or a callable
    predicate, called with the item. Different entries may be satisfied by
    different items.

    Args:
        flat_inputs: Flat list of sample items.
        *types_or_checks: Classes and/or predicates that must each be satisfied.

    Returns:
        ``True`` if every entry is satisfied by at least one item (trivially so
        when no entries are given), ``False`` otherwise.
    """

    def _satisfies(item: Any, candidate: Union[Type, Callable[[Any], bool]]) -> bool:
        # Classes are matched by instance check, anything else is treated as a predicate.
        if isinstance(candidate, type):
            return isinstance(item, candidate)
        return candidate(item)

    return all(
        any(_satisfies(item, candidate) for item in flat_inputs)
        for candidate in types_or_checks
    )