#!/usr/bin/env python
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

# Names exported by a star-import of this module (its public API).
__all__ = ["ModelZooRegistry", "ModelAttribute", "model_zoo"]


@dataclass
class ModelAttribute:
    """
    Attributes of a model.

    Both flags default to ``False``.

    Args:
        has_control_flow (bool): Whether the model contains branching in its forward method.
        has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability.
            Often seen in the torchvision models.
    """

    has_control_flow: bool = False
    has_stochastic_depth_prob: bool = False


class ModelZooRegistry(dict):
    """
    A registry to map model names to model and data generation functions.

    Each entry maps a model name to the tuple
    ``(model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)``.
    """

    def register(
        self,
        name: str,
        model_fn: Callable,
        data_gen_fn: Callable,
        output_transform_fn: Callable,
        loss_fn: Optional[Callable] = None,
        model_attribute: Optional["ModelAttribute"] = None,
    ) -> None:
        """
        Register a model and data generation function.

        Registering a name that already exists overwrites the previous entry.

        Examples:

        ```python
        # normal forward workflow
        model = resnet18()
        data = resnet18_data_gen()
        output = model(**data)
        transformed_output = output_transform_fn(output)
        loss = loss_fn(transformed_output)

        # Register
        model_zoo = ModelZooRegistry()
        model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
        ```

        Args:
            name (str): Name of the model.
            model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
            data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
            output_transform_fn (Callable): A function that transforms the output of the model into Dict.
            loss_fn (Callable): a function to compute the loss from the given output. Defaults to None
            model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
        """
        self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)

    def get_sub_registry(self, keyword: Union[str, List[str]]) -> dict:
        """
        Get a sub registry with models whose name contains at least one keyword.

        Matching is a plain substring test against the registered model name.
        Entries keep the order in which they were registered.

        Args:
            keyword (Union[str, List[str]]): A keyword, or a list/tuple of keywords, to filter models.

        Returns:
            dict: A plain ``dict`` mapping each matching model name to its registered tuple.

        Raises:
            AssertionError: If ``keyword`` is neither a string nor a list/tuple, or
                if no registered model name matches any keyword.
        """
        keyword_list = [keyword] if isinstance(keyword, str) else keyword

        # Raise explicitly instead of using `assert`, so the checks still run under
        # `python -O`; AssertionError is kept so existing callers see the same type.
        if not isinstance(keyword_list, (list, tuple)):
            raise AssertionError(
                f"keyword must be a str or a list/tuple of str, got {type(keyword).__name__}"
            )

        new_dict = {
            name: entry
            for name, entry in self.items()
            if any(kw in name for kw in keyword_list)
        }

        if not new_dict:
            raise AssertionError(f"No model found with keyword {keyword}")
        return new_dict


# Module-level ModelZooRegistry instance, created empty at import time and listed in __all__.
model_zoo = ModelZooRegistry()