registry.py 3.81 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import re
from typing import (
    Callable,
    Dict,
    Generator,
    Generic,
    Literal,
    Optional,
    Tuple,
    Type,
    TypeVar,
    overload,
)

T = TypeVar("T")  # type of the objects stored in a registry
R = TypeVar("R")  # type of the object passed through the ``add`` decorator


class BaseRegistry(Generic[T]):
    """A class-level, name-indexed registry of objects.

    The class is used without instantiation: every method is a classmethod.
    Each registry class keeps its own storage dict mapping a name to an
    ``(object, description)`` tuple. In addition, every registry class that
    has had ``add`` called on it is recorded, by class name, in a dict shared
    by all registries (see ``registries``).

    Registered names double as regular expressions during ``get``: when no
    entry has exactly the requested name, an entry is selected if its name,
    interpreted as a pattern, matches the start of the requested name
    (``re.match`` semantics).
    """

    # Shared by every registry class: class name -> registry class.
    _registry_of_registries: Dict[str, Type["BaseRegistry"]] = {}
    # Per-class storage, created lazily by ``_get_storage``.
    _registry_storage: Dict[str, Tuple[T, Optional[str]]]

    @classmethod
    def _add_to_registry_of_registries(cls) -> None:
        """Record this registry class under its class name, if not already recorded."""
        # setdefault keeps the first class registered under a given name.
        cls._registry_of_registries.setdefault(cls.__name__, cls)

    @classmethod
    def registries(cls) -> Generator[Tuple[str, Type["BaseRegistry"]], None, None]:
        """Yield ``(class name, registry class)`` for all recorded registries, sorted by name."""
        yield from sorted(cls._registry_of_registries.items(), key=lambda kv: kv[0])

    @classmethod
    def _get_storage(cls) -> Dict[str, Tuple[T, Optional[str]]]:
        """Return the storage dict owned by this class, creating it on first use."""
        # Check the class's own __dict__ instead of hasattr(): hasattr() also
        # finds a storage dict created on a parent class, which would make this
        # registry silently share (and mutate) the parent's entries.
        if "_registry_storage" not in cls.__dict__:
            cls._registry_storage = {}
        return cls._registry_storage  # pyright: ignore

    @classmethod
    def items(cls) -> Generator[Tuple[str, T], None, None]:
        """Yield ``(name, object)`` for every entry in the registry, sorted by name."""
        # Sort on the name only so stored objects never need to be comparable.
        for name, (obj, _) in sorted(cls._get_storage().items(), key=lambda kv: kv[0]):
            yield name, obj

    @classmethod
    def items_with_description(cls) -> Generator[Tuple[str, T, Optional[str]], None, None]:
        """Yield ``(name, object, description)`` for every entry, sorted by name."""
        for name, (obj, desc) in sorted(cls._get_storage().items(), key=lambda kv: kv[0]):
            yield name, obj, desc

    @classmethod
    def add(cls, name: str, desc: Optional[str] = None) -> Callable[[R], R]:
        """Return a decorator that registers an object under ``name``.

        Calling ``add`` also records this registry class in the registry of
        registries, even before the returned decorator is applied.

        Args:
            name: Name to register the object under.
            desc: Optional human-readable description stored with the object.

        Returns:
            A callable that stores the object it receives and returns it unchanged.

        Raises:
            ValueError: (from the returned callable) if a different object is
                already registered under exactly ``name`` and the new object
                does not come from the ``__main__`` module.
        """

        # Add the registry to the registry of registries
        cls._add_to_registry_of_registries()

        def _add(
            inner_self: T,
            inner_name: str = name,
            inner_desc: Optional[str] = desc,
            inner_cls: Type["BaseRegistry"] = cls,
        ) -> T:
            """Store ``inner_self`` in ``inner_cls`` under ``inner_name``."""
            storage = inner_cls._get_storage()

            # Conflict detection uses the exact name. Going through get() would
            # apply pattern matching, so registering "foo_v2" after "foo" would
            # be rejected as a duplicate of "foo".
            entry = storage.get(inner_name)

            # ``entry is not None`` (rather than truthiness of the stored
            # object) so that falsy registered objects are protected as well.
            if entry is not None and entry[0] != inner_self:
                # An object defined in __main__ is returned without being
                # registered instead of raising. getattr() because instances
                # of builtin types have no __module__ attribute.
                if getattr(inner_self, "__module__", None) == "__main__":
                    return inner_self

                raise ValueError(f"Tagger {inner_name} already exists")

            storage[inner_name] = (inner_self, inner_desc)
            return inner_self

        return _add  # type: ignore

    @classmethod
    def remove(cls, name: str) -> bool:
        """Remove the entry named exactly ``name``; return whether it existed."""
        # Stored values are always tuples, so None unambiguously means "absent".
        return cls._get_storage().pop(name, None) is not None

    @classmethod
    def has(cls, name: str) -> bool:
        """Return whether an entry named exactly ``name`` exists (no pattern matching)."""
        return name in cls._get_storage()

    @overload
    @classmethod
    def get(cls, name: str) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[True]) -> T: ...

    @overload
    @classmethod
    def get(cls, name: str, raise_on_missing: Literal[False]) -> Optional[T]: ...

    @classmethod
    def get(cls, name: str, raise_on_missing: bool = True) -> Optional[T]:
        """Look up an object by name.

        An entry registered under exactly ``name`` always wins. Otherwise each
        registered name is treated as a regular expression and tried against
        the start of ``name`` with ``re.match``.

        Args:
            name: Name to look up.
            raise_on_missing: If true (the default), raise when nothing
                matches; if false, return ``None`` instead.

        Returns:
            The registered object, or ``None`` when nothing matches and
            ``raise_on_missing`` is false.

        Raises:
            ValueError: If more than one registered pattern matches ``name``,
                or if nothing matches and ``raise_on_missing`` is true.
        """
        storage = cls._get_storage()

        # Exact name takes precedence; without this, "foo" and "foo_v2" both
        # match a lookup of "foo_v2" and the lookup fails as ambiguous.
        exact = storage.get(name)
        if exact is not None:
            return exact[0]

        def _pattern_matches(pattern: str) -> bool:
            try:
                return re.match(pattern, name) is not None
            except re.error:
                # A registered name that is not a valid regex (e.g. "c++")
                # can only be retrieved by its exact name.
                return False

        matches = [registered for registered in storage if _pattern_matches(registered)]

        if len(matches) > 1:
            raise ValueError(f"Multiple taggers match {name}: {', '.join(matches)}")

        if not matches:
            if raise_on_missing:
                tagger_names = ", ".join(tn for tn, _ in cls.items())
                raise ValueError(f"Unknown tagger {name}; available taggers: {tagger_names}")
            return None

        return storage[matches[0]][0]