# _utils.py
1
2
from typing import List, Union

3
import torch
4
from torch import Tensor
5
6


7
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list

    Args:
        tensors: the tensors to join. Only ``len()``, indexing and
            ``torch.cat`` are applied to this argument.
        dim: dimension along which the tensors are concatenated.

    Returns:
        ``tensors[0]`` itself (the same object, not a copy) when exactly one
        tensor is given, otherwise the result of ``torch.cat(tensors, dim)``.
        Any other length, including zero, is forwarded to ``torch.cat``
        unchanged.
    """
    # TODO add back the assert
    # assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        # Hand back the input directly; torch.cat would build a new tensor.
        return tensors[0]
    return torch.cat(tensors, dim)


18
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    """
    Merge a list of box tensors into one tensor with a leading index column.

    Args:
        boxes: list of 2-D tensors, one per list position. They are
            concatenated along dim 0, so they must agree in every other
            dimension (presumably ``Tensor[L, 4]`` each; the width is not
            checked here).

    Returns:
        A tensor whose rows are all the input rows in list order. Column 0
        holds the position in ``boxes`` of the tensor the row came from and
        the remaining columns are the original row values. The index column
        has the same dtype and device as the boxes, since it is created with
        ``torch.full_like``.
    """
    # _cat only needs len(), indexing and torch.cat, so the list is passed
    # as-is instead of being copied first.
    concat_boxes = _cat(boxes, dim=0)
    temp = []
    for i, b in enumerate(boxes):
        # b[:, :1] is a single column with one entry per row of b; fill it
        # with the list position i.
        temp.append(torch.full_like(b[:, :1], i))
    ids = _cat(temp, dim=0)
    # Put the index column in front of the box columns.
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois
26
27


28
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]) -> None:
    """
    Validate the shape of ``boxes``.

    Args:
        boxes: either a list/tuple of tensors, each required to be 2-D with
            4 columns, or a single tensor required to be 2-D with 5 columns.
            An empty list/tuple passes because there is nothing to check.

    Raises:
        AssertionError: if a tensor has the wrong number of dimensions or
            columns, or if ``boxes`` is neither a list/tuple nor a tensor.
    """
    # The checks raise AssertionError explicitly rather than using ``assert``
    # statements, so they are still enforced when Python runs with -O.
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            # Test dim() first: size(1) on a tensor with fewer than two
            # dimensions would raise an unrelated IndexError.
            if _tensor.dim() != 2 or _tensor.size(1) != 4:
                raise AssertionError(
                    "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
                )
    elif isinstance(boxes, torch.Tensor):
        if boxes.dim() != 2 or boxes.size(1) != 5:
            raise AssertionError("The boxes tensor shape is not correct as Tensor[K, 5]")
    else:
        raise AssertionError("boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]")