# registry.py (1.14 KB) -- committed by Kai Chen.
import torch.nn as nn


class Registry(object):
    """A mapping from class name to ``nn.Module`` subclass, filled by decorator.

    Classes are keyed by their ``__name__``; registering two classes with the
    same name in one registry raises ``KeyError``.

    Args:
        name (str): Human-readable registry name, used in error messages.
    """

    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        return '{}(name={}, items={})'.format(self.__class__.__name__,
                                              self._name,
                                              list(self._module_dict.keys()))

    @property
    def name(self):
        """str: The registry's name."""
        return self._name

    @property
    def module_dict(self):
        """dict: Mapping from registered class name to the class itself.

        This is the live internal dict, not a copy.
        """
        return self._module_dict

    def get(self, key):
        """Look up a registered class by name.

        Args:
            key (str): The class name it was registered under.

        Returns:
            type or None: The registered class, or None if ``key`` is unknown.
        """
        return self._module_dict.get(key, None)

    def _register_module(self, module_class):
        """Register a module class under its ``__name__``.

        Args:
            module_class (type): Subclass of :obj:`nn.Module` to register.

        Raises:
            TypeError: If ``module_class`` is not a class, or is a class that
                does not subclass ``nn.Module``.
            KeyError: If a class with the same name is already registered.
        """
        # issubclass() raises its own opaque TypeError for non-class arguments
        # (functions, instances); check for a class first so callers always
        # get this message instead.
        if (not isinstance(module_class, type)
                or not issubclass(module_class, nn.Module)):
            raise TypeError('module must be a child of nn.Module, but got {}'.
                            format(module_class))
        module_name = module_class.__name__
        if module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        self._module_dict[module_name] = module_class

    def register_module(self, cls):
        """Register ``cls`` in this registry; usable as a class decorator.

        Args:
            cls (type): Subclass of :obj:`nn.Module` to register.

        Returns:
            type: ``cls`` unchanged, so the decorated name still refers to it.
        """
        self._register_module(cls)
        return cls


# One registry per detector component type; component classes register
# themselves with the matching registry's ``register_module`` decorator.
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
DETECTORS = Registry('detector')