builder.py 1.33 KB
Newer Older
dingchang's avatar
dingchang committed
1
# Copyright (c) OpenMMLab. All rights reserved.
2
3
from typing import Union

4
from mmengine.registry import Registry
5
from torch import nn as nn
6
7

# Registry of PointNet2 set abstraction (SA) module classes, looked up by
# their registered type name in ``build_sa_module``.
SA_MODULES = Registry(name='point_sa_module')


def build_sa_module(cfg: Union[dict, None], *args, **kwargs) -> nn.Module:
    """Build a PointNet2 set abstraction (SA) module from a config.

    Args:
        cfg (dict or None): The SA module config. When ``None``, the
            default ``dict(type='PointSAModule')`` is used. Otherwise it
            must be a dict containing:

            - type (str): Name under which the module is registered in
              ``SA_MODULES``.
            - module args: Any remaining keys, forwarded as keyword
              arguments to the module's ``__init__``.
        args (argument list): Positional arguments passed to the
            ``__init__`` method of the corresponding SA module.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding SA module.

    Returns:
        nn.Module: Created SA module.

    Raises:
        TypeError: If ``cfg`` is neither ``None`` nor a dict.
        KeyError: If ``cfg`` lacks the ``"type"`` key, or if its type is
            not registered in ``SA_MODULES``.
    """
    if cfg is None:
        cfg_ = dict(type='PointSAModule')
    else:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        # Shallow copy so that popping 'type' below does not mutate the
        # caller's config dict.
        cfg_ = cfg.copy()

    module_type = cfg_.pop('type')
    # Single registry lookup: ``Registry.get`` returns ``None`` for unknown
    # keys (``key in registry`` is itself implemented via ``get``).
    sa_module = SA_MODULES.get(module_type)
    if sa_module is None:
        raise KeyError(f'Unrecognized module type {module_type}')

    # A key present in both ``kwargs`` and the config raises TypeError
    # (duplicate keyword argument) instead of silently overriding one side.
    return sa_module(*args, **kwargs, **cfg_)