"vscode:/vscode.git/clone" did not exist on "6356cbdc4f652aa6004b08b0fc05f6a4ec9d0b97"
_utils.py 2.26 KB
Newer Older
1
from typing import List, Optional, Tuple, Union
2

3
import torch
4
from torch import nn, Tensor
5
6


7
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
8
9
10
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
    """
11
12
    # TODO add back the assert
    # assert isinstance(tensors, (list, tuple))
13
14
15
16
17
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


18
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    """
    Flatten a list of per-image box tensors into one tensor of RoIs.

    Every row of the result is the index of the source tensor within ``boxes``
    followed by that row's original box columns, so the output has one more
    column than each input tensor. The inputs are expected to be 2-D (they are
    sliced with ``[:, :1]``).
    """
    stacked_boxes = _cat([b for b in boxes], dim=0)

    index_columns: List[Tensor] = []
    for image_idx, b in enumerate(boxes):
        # (rows, 1) column filled with image_idx; full_like keeps b's dtype and device.
        index_columns.append(torch.full_like(b[:, :1], image_idx))
    id_column = _cat(index_columns, dim=0)

    return torch.cat([id_column, stacked_boxes], dim=1)
26
27


28
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]) -> None:
    """
    Validate that ``boxes`` uses one of the two supported layouts.

    Accepted layouts:

    * a list/tuple of 2-D tensors, each shaped ``Tensor[L, 4]``, or
    * a single 2-D tensor shaped ``Tensor[K, 5]``.

    Raises:
        AssertionError: (via ``torch._assert``) if ``boxes`` matches neither layout.
    """
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            # Check the rank first (short-circuit) so a 0-D/1-D tensor fails with this
            # message instead of an IndexError raised by .size(1).
            torch._assert(
                _tensor.dim() == 2 and _tensor.size(1) == 4,
                "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]",
            )
    elif isinstance(boxes, torch.Tensor):
        torch._assert(
            boxes.dim() == 2 and boxes.size(1) == 5,
            "The boxes tensor shape is not correct as Tensor[K, 5]",
        )
    else:
        # Letters match the two messages above: K rows for the packed tensor, L per list entry.
        torch._assert(False, "boxes is expected to be a Tensor[K, 5] or a List[Tensor[L, 4]]")
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


def split_normalization_params(
    model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
    """
    Split the trainable parameters of ``model`` into normalization and other parameters.

    Args:
        model: module whose parameters are partitioned.
        norm_classes: module classes treated as normalization layers. When ``None``
            or empty, defaults to ``_BatchNorm``, ``LayerNorm`` and ``GroupNorm``.

    Returns:
        ``(norm_params, other_params)``. ``norm_params`` holds the parameters with
        ``requires_grad=True`` of leaf modules (modules without children) that are
        instances of ``norm_classes``. ``other_params`` holds every other trainable
        parameter. Parameters registered directly on a module that has children
        always go to ``other_params``, even if that module is a norm class.

    Raises:
        ValueError: if an entry of ``norm_classes`` is not a subclass of ``nn.Module``.
    """
    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
    if not norm_classes:
        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params: List[Tensor] = []
    other_params: List[Tensor] = []
    for module in model.modules():
        # Compare with None explicitly: empty containers such as nn.Sequential() define
        # __len__ and are falsy. A bare truthiness test would misclassify a parent whose
        # first child is an empty container as a leaf. Its recursive parameters() would
        # then be collected, duplicating every descendant parameter.
        if next(module.children(), None) is not None:
            # Non-leaf: take only its own parameters; children are visited separately.
            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
        elif isinstance(module, classes):
            norm_params.extend(p for p in module.parameters() if p.requires_grad)
        else:
            other_params.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params