_utils.py 1.19 KB
Newer Older
1
import torch
2
from torch import Tensor
vfdev's avatar
vfdev committed
3
from torch.jit.annotations import List
4
5


6
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """
    Concatenate ``tensors`` along ``dim``, like ``torch.cat``.

    When the list holds exactly one tensor, that tensor object itself is
    returned, skipping the copy ``torch.cat`` would otherwise make. Every
    other length (including an empty list) is forwarded to ``torch.cat``
    unchanged.
    """
    # TODO add back the input-type check: isinstance(tensors, (list, tuple))
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    # Single element: hand back the same object, no allocation.
    return tensors[0]


17
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    """
    Pack a list of per-image box tensors into one tensor with a leading index column.

    Each ``boxes[i]`` is sliced as ``boxes[i][:, :1]`` below, so it must be at
    least 2-D; presumably it is ``[L_i, 4]`` (TODO confirm against callers).
    Every row of ``boxes[i]`` is prefixed with the value ``i``. That column is
    built with ``torch.full_like``, so it shares the box tensor's dtype and
    device. Rows are stacked in input order.

    Returns:
        Tensor of shape ``[sum(L_i), 1 + C]``, where ``C`` is the (common)
        column count of the input tensors.
    """
    flat_boxes = _cat([b for b in boxes], dim=0)

    index_cols = []
    for img_idx, img_boxes in enumerate(boxes):
        # ``[:, :1]`` keeps a (rows, 1) slice so full_like produces one index per box.
        index_cols.append(torch.full_like(img_boxes[:, :1], img_idx))
    batch_ids = _cat(index_cols, dim=0)

    # Index column first, then the original box coordinates.
    return torch.cat([batch_ids, flat_boxes], dim=1)
25
26


27
def check_roi_boxes_shape(boxes: Tensor):
    """
    Validate the shape of ``boxes``.

    Two layouts are accepted:

    * a list/tuple of 2-D tensors, each with 4 columns (``List[Tensor[L, 4]]``);
    * a single 2-D tensor with 5 columns (``Tensor[K, 5]``).

    Raises:
        AssertionError: if ``boxes`` matches neither layout. Raised explicitly
            instead of via ``assert`` so the check is not stripped under
            ``python -O``.
    """
    if isinstance(boxes, (list, tuple)):
        for box_tensor in boxes:
            # Check the rank first: ``size(1)`` on a 1-D tensor would raise an
            # unrelated IndexError instead of this message.
            if box_tensor.dim() != 2 or box_tensor.size(1) != 4:
                raise AssertionError(
                    'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]')
    elif isinstance(boxes, torch.Tensor):
        if boxes.dim() != 2 or boxes.size(1) != 5:
            raise AssertionError('The boxes tensor shape is not correct as Tensor[K, 5]')
    else:
        raise AssertionError('boxes is expected to be a Tensor[K, 5] or a List[Tensor[L, 4]]')