# base.py (3.51 KB) — recovered from a GitLab blame view; lines attributed to zhangwenwei.
from abc import ABCMeta, abstractmethod

import torch.nn as nn


class Base3DDetector(nn.Module, metaclass=ABCMeta):
    """Abstract base class for 3D detectors.

    Subclasses must implement :meth:`extract_feat`, :meth:`forward_train`,
    :meth:`simple_test` and :meth:`aug_test`. :meth:`forward` dispatches to
    :meth:`forward_train` or :meth:`forward_test` depending on
    ``return_loss``.
    """

    def __init__(self):
        super().__init__()
        # Not read anywhere in this class; presumably consumed by
        # mixed-precision helpers elsewhere — confirm before relying on it.
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the detector has a ``neck`` attribute that is not None."""
        return getattr(self, 'neck', None) is not None

    @property
    def with_shared_head(self):
        """bool: whether the detector has a ``shared_head`` that is not None."""
        return getattr(self, 'shared_head', None) is not None

    @property
    def with_bbox(self):
        """bool: whether the detector has a ``bbox_head`` that is not None."""
        return getattr(self, 'bbox_head', None) is not None

    @property
    def with_mask(self):
        """bool: whether the detector has a ``mask_head`` that is not None."""
        return getattr(self, 'mask_head', None) is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Extract features from a single input; implemented by subclasses."""
        pass

    def extract_feats(self, imgs):
        """Lazily apply :meth:`extract_feat` to every element of ``imgs``.

        Args:
            imgs (list): inputs passed one at a time to :meth:`extract_feat`.

        Returns:
            generator: yields ``self.extract_feat(img)`` for each ``img``.
        """
        # Validate eagerly. Inside a generator function this check would only
        # run on the first ``next()``, so a bad argument could go unnoticed.
        assert isinstance(imgs, list)
        return (self.extract_feat(img) for img in imgs)

    @abstractmethod
    def forward_train(self, **kwargs):
        """Training forward pass; implemented by subclasses."""
        pass

    @abstractmethod
    def simple_test(self, **kwargs):
        """Inference without test-time augmentation; implemented by subclasses.

        :meth:`forward_test` calls it positionally as
        ``simple_test(points, img_metas, imgs, **kwargs)``.
        """
        pass

    @abstractmethod
    def aug_test(self, **kwargs):
        """Inference with test-time augmentation; implemented by subclasses.

        :meth:`forward_test` calls it positionally as
        ``aug_test(points, img_metas, imgs, **kwargs)``.
        """
        pass

    def init_weights(self, pretrained=None):
        """Log the source of pretrained weights, if one is given.

        The base implementation loads nothing itself. It only logs
        ``pretrained`` through the mmdet3d root logger.

        Args:
            pretrained (str, optional): presumably a checkpoint path or URL.
                Defaults to None, in which case nothing happens.
        """
        if pretrained is None:
            return
        # Imported lazily, as in the original, so that importing this module
        # does not require mmdet3d.utils.
        from mmdet3d.utils import get_root_logger
        logger = get_root_logger()
        logger.info('load model from: {}'.format(pretrained))

    def forward_test(self, points, img_metas, imgs=None, **kwargs):
        """Run inference, with or without test-time augmentation.

        Args:
            points (List[Tensor]): the outer list indexes test-time
                augmentations. Originally documented as "inner Tensor of
                shape NxC containing all points in the batch".
                NOTE(review): ``len(points[0])`` is asserted to equal 1
                below, which suggests each entry is really a per-sample
                sequence. Confirm against callers.
            img_metas (List[List[dict]]): the outer list indexes test-time
                augmentations (multiscale, flip, etc.) and the inner list
                indexes samples in the batch.
            imgs (List[Tensor], optional): the outer list indexes test-time
                augmentations; each inner tensor is presumably NxCxHxW and
                holds all images in the batch. Defaults to None.
            **kwargs: forwarded unchanged to simple_test / aug_test.

        Returns:
            The result of :meth:`simple_test` when there is exactly one
            augmentation, otherwise the result of :meth:`aug_test`.

        Raises:
            TypeError: if ``points`` or ``img_metas`` is not a list.
            ValueError: if ``points`` and ``img_metas`` differ in length or
                contain no augmentation at all.
        """
        for var, name in [(points, 'points'), (img_metas, 'img_metas')]:
            if not isinstance(var, list):
                raise TypeError('{} must be a list, but got {}'.format(
                    name, type(var)))

        num_augs = len(points)
        if num_augs != len(img_metas):
            raise ValueError(
                'num of augmentations ({}) != num of image meta ({})'.format(
                    num_augs, len(img_metas)))
        if num_augs == 0:
            # Without this check ``points[0]`` below fails with a bare
            # IndexError that does not say what went wrong.
            raise ValueError('points must contain at least one augmentation')

        # TODO: remove the restriction of imgs_per_gpu == 1 when prepared
        samples_per_gpu = len(points[0])
        assert samples_per_gpu == 1

        if num_augs == 1:
            # ``imgs`` is optional; simple_test receives None when absent.
            first_imgs = None if imgs is None else imgs[0]
            return self.simple_test(points[0], img_metas[0], first_imgs,
                                    **kwargs)
        return self.aug_test(points, img_metas, imgs, **kwargs)

    def forward(self, return_loss=True, **kwargs):
        """Dispatch to :meth:`forward_train` or :meth:`forward_test`.

        The value of ``return_loss`` changes the expected inputs. When
        ``return_loss=True``, all keyword arguments go to
        :meth:`forward_train`, whose inputs are defined by the subclass
        (presumably single-nested, e.g. Tensor and List[dict]). When
        ``return_loss=False``, they go to :meth:`forward_test`. That method
        expects ``points`` and ``img_metas`` (and optionally ``imgs``) as
        lists whose outer level indexes test-time augmentations.
        """
        if return_loss:
            return self.forward_train(**kwargs)
        return self.forward_test(**kwargs)