"model/vscode:/vscode.git/clone" did not exist on "470af8ab899aca6a72571f0c1e2ac6f9049aca29"
_utils.py 2.29 KB
Newer Older
1
from typing import List, Optional, Tuple, Union
2

3
import torch
4
from torch import nn, Tensor
5
6


7
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
8
9
10
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
    """
11
12
    # TODO add back the assert
    # assert isinstance(tensors, (list, tuple))
13
14
15
16
17
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


18
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    """
    Convert a list of per-image box tensors into a single RoI tensor.

    Each tensor in ``boxes`` is 2-D with shape ``[L_i, C]``. The result stacks all of them along
    dim 0 and prepends one column holding the position ``i`` in ``boxes`` that each row came
    from. The result therefore has shape ``[sum(L_i), C + 1]`` and rows of the form ``[i, *box]``.

    Args:
        boxes (List[Tensor]): per-image box tensors.

    Returns:
        Tensor: the concatenated RoIs with the index column first.

    Note:
        An empty ``boxes`` list is not supported: the underlying ``torch.cat`` raises.
    """
    concat_boxes = _cat(boxes, dim=0)

    # One column of indices per input tensor. full_like gives the column the same dtype and
    # device as the boxes it will be concatenated with.
    index_columns = []
    for i, b in enumerate(boxes):
        index_columns.append(torch.full_like(b[:, :1], i))
    ids = _cat(index_columns, dim=0)

    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois
26
27


28
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]) -> None:
    """
    Validate that ``boxes`` is in one of the two accepted RoI box formats.

    Accepted formats:
        * a list or tuple of 2-D tensors, each of shape ``[L, 4]``;
        * a single 2-D tensor of shape ``[K, 5]``.

    Raises:
        ValueError: if a tensor is not 2-D or does not have the expected number of columns.
        TypeError: if ``boxes`` is neither a list/tuple nor a Tensor.
    """
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            # Check dim() first so a 1-D tensor reports a ValueError instead of an IndexError from size(1).
            if _tensor.dim() != 2 or _tensor.size(1) != 4:
                raise ValueError("The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]].")
    elif isinstance(boxes, torch.Tensor):
        if boxes.dim() != 2 or boxes.size(1) != 5:
            raise ValueError("The boxes tensor shape is not correct as Tensor[K, 5].")
    else:
        raise TypeError(f"boxes is expected to be a Tensor[K, 5] or a List[Tensor[L, 4]], instead got {type(boxes)}")
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


def split_normalization_params(
    model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
    """
    Split the trainable parameters of ``model`` into normalization and non-normalization groups.

    Every module yielded by ``model.modules()`` is visited once:
        * a module with children contributes only the parameters registered directly on it
          (``recurse=False``) to the second group, since its children are visited separately;
        * a childless module that is an instance of one of ``norm_classes`` contributes all of
          its parameters to the first group;
        * any other childless module contributes all of its parameters to the second group.
    Only parameters with ``requires_grad=True`` are returned.

    Args:
        model (nn.Module): the model whose parameters are split.
        norm_classes (List[type], optional): module classes treated as normalization layers. If
            ``None`` or empty, defaults to ``[_BatchNorm, nn.LayerNorm, nn.GroupNorm]``.

    Returns:
        Tuple[List[Tensor], List[Tensor]]: ``(norm_params, other_params)``.

    Raises:
        ValueError: if a class in ``norm_classes`` is not a subclass of ``nn.Module``.
    """
    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
    if not norm_classes:
        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params = []
    other_params = []
    for module in model.modules():
        # Compare against None explicitly. Containers such as an empty nn.Sequential or
        # nn.ModuleList define __len__ and are falsy. A truthiness test would therefore treat the
        # parent of such a child as a leaf, and collect its whole subtree a second time.
        if next(module.children(), None) is not None:
            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
        elif isinstance(module, classes):
            norm_params.extend(p for p in module.parameters() if p.requires_grad)
        else:
            other_params.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params