# _utils.py
import torch
2
3
from torch import Tensor
from torch.jit.annotations import List
4
5
6


def _cat(tensors, dim=0):
7
    # type: (List[Tensor], int) -> Tensor
8
9
10
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
    """
11
12
    # TODO add back the assert
    # assert isinstance(tensors, (list, tuple))
13
14
15
16
17
18
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def convert_boxes_to_roi_format(boxes):
    # type: (List[Tensor]) -> Tensor
    """
    Merge a list of 2-D box tensors into a single tensor whose first column
    records which list entry each row came from.

    Args:
        boxes: list of 2-D tensors, one per list entry (presumably one per
            image, each of shape (N_i, 4) -- confirm against callers).

    Returns:
        A tensor with ``sum(N_i)`` rows. Column 0 holds the position of the
        source tensor within ``boxes``; the remaining columns are the rows of
        the input tensors stacked in list order.
    """
    # NOTE(review): the comprehension only copies the list; it may be kept
    # for TorchScript's benefit -- confirm before simplifying.
    concat_boxes = _cat([b for b in boxes], dim=0)
    temp = []
    for i, b in enumerate(boxes):
        # b[:, :1] keeps a (N_i, 1) column so the index tensor matches the
        # dtype and device of its boxes and can be concatenated on dim=1.
        temp.append(torch.full_like(b[:, :1], i))
    ids = _cat(temp, dim=0)
    # Prepend the index column to the stacked boxes.
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois