registry.py 927 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
#!/usr/bin/env python


class Registry:
    """Name-to-callable lookup table that is filled through a decorator.

    Callables are stored under unique names with :meth:`register`, fetched
    by name with :meth:`get_callable`, and iterating over the registry
    yields the stored callables (not their names) in the underlying dict's
    order, i.e. registration order on Python 3.7+.
    """

    def __init__(self):
        # Maps each registered name to its callable.
        self._registry = {}
        # Cursor used only by the legacy ``next(registry)`` protocol; it is
        # re-armed by every ``iter(registry)`` call.
        self._cursor = iter(())

    def register(self, name):
        """Return a decorator that stores its target under ``name``.

        The decorator hands the callable back unchanged, so the decorated
        name stays bound to the original callable instead of ``None``.

        Args:
            name: key under which the decorated callable is stored.

        Returns:
            A one-argument decorator.

        Raises:
            AssertionError: if ``name`` is already registered, either when
                the decorator is created or when it is applied.
        """
        self._ensure_unused(name)

        def _register(callable_):
            # Checked again at decoration time: two decorators for the same
            # name may have been created before either one was applied.
            self._ensure_unused(name)
            self._registry[name] = callable_
            return callable_

        return _register

    def _ensure_unused(self, name):
        """Raise ``AssertionError`` if ``name`` is already taken."""
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if name in self._registry:
            raise AssertionError(
                '{!r} is already registered'.format(name))

    def get_callable(self, name: str):
        """Return the callable registered under ``name``.

        Raises:
            KeyError: if nothing is registered under ``name``.
        """
        return self._registry[name]

    def __iter__(self):
        """Return an independent iterator over the registered callables.

        Each call takes a snapshot of the current entries, so nested or
        interleaved loops over the same registry do not disturb each other.
        """
        snapshot = list(self._registry.values())
        # Keep ``iter(reg); next(reg)`` working for callers that drive the
        # registry object itself as an iterator.
        self._cursor = iter(snapshot)
        return iter(snapshot)

    def __next__(self):
        """Advance the cursor armed by the most recent ``iter(registry)``.

        Raises:
            StopIteration: when that cursor is exhausted.
        """
        return next(self._cursor)


# Shared module-level registries. Judging by their names, one is meant for
# non-distributed component functions and the other for model-parallel
# ones -- confirm against the modules that register into them.
non_distributed_component_funcs = Registry()
model_parallel_component_funcs = Registry()

# Names exported by a star-import of this module.
__all__ = [
    'non_distributed_component_funcs',
    'model_parallel_component_funcs',
]