import importlib
import inspect
import sys
from dataclasses import dataclass, fields
from inspect import signature
from typing import Any, Callable, Dict, cast

from torchvision._utils import StrEnum

from .._internally_replaced_utils import load_state_dict_from_url


__all__ = ["WeightsEnum", "Weights", "get_weight"]


@dataclass
class Weights:
    """
    This class is used to group important attributes associated with the pre-trained weights.

    Args:
        url (str): The URL from which the weights (state dict) are loaded.
        transforms (Callable): A callable that constructs the preprocessing method (or validation preset transforms)
            needed to use the model. A constructor is stored rather than an already constructed object because the
            specific object might hold state (memory), so its initialization is delayed until it is needed.
        meta (Dict[str, Any]): Stores meta-data related to the weights of the model and its configuration. These can be
            informative attributes (for example the number of parameters/flops, recipe link/methods used in training
            etc), configuration parameters (for example the `num_classes`) needed to construct the model or important
            meta-data (for example the `classes` of a classification model) needed to use the model.
    """

    # Location of the serialized weights.
    url: str
    # Factory for the preprocessing transforms; called lazily by users of the weights.
    transforms: Callable[..., Any]
    # Free-form metadata / configuration associated with these weights.
    meta: Dict[str, Any]


class WeightsEnum(StrEnum):
    """
    This class is the parent class of all model weights. Each model building method receives an optional `weights`
    parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
    `Weights`.

    The fields of the wrapped `Weights` record (`url`, `transforms`, `meta`) can be read directly on each member.

    Args:
        value (Weights): The data class entry with the weight information.
    """

    def __init__(self, value: Weights):
        # Store the ``Weights`` record as the member's value so that ``self.value`` returns it.
        self._value_ = value

    @classmethod
    def verify(cls, obj: Any) -> Any:
        """
        Normalize a user-supplied ``weights`` argument into a member of this enum.

        Args:
            obj (Any): ``None``, a member of this enum, or a string naming a member either bare
                (``"IMAGENET1K_V1"``) or qualified with this enum's class name
                (``"ResNet50_Weights.IMAGENET1K_V1"``).

        Returns:
            ``None`` if ``obj`` is ``None``; otherwise the corresponding member of ``cls``.

        Raises:
            TypeError: If ``obj`` is not ``None``, not a plain ``str`` and not an instance of ``cls``.
            Errors raised by ``cls.from_str`` for unknown member names propagate unchanged.
        """
        if obj is None:
            return obj
        # Exact type check (not isinstance): only plain strings are treated as member names.
        if type(obj) is str:
            prefix = cls.__name__ + "."
            # Strip only a leading "<ClassName>." qualifier; str.replace would also remove it
            # from anywhere else in the string.
            if obj.startswith(prefix):
                obj = obj[len(prefix) :]
            return cls.from_str(obj)
        if not isinstance(obj, cls):
            raise TypeError(
                f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
            )
        return obj

    def get_state_dict(self, progress: bool, **kwargs: Any) -> Dict[str, Any]:
        """
        Load the state dict referenced by this entry's ``url`` via ``load_state_dict_from_url``.

        Args:
            progress (bool): Forwarded as ``progress=`` to ``load_state_dict_from_url``.
            **kwargs (Any): Any further keyword arguments, forwarded unchanged to ``load_state_dict_from_url``.

        Returns:
            Dict[str, Any]: Whatever ``load_state_dict_from_url`` returns for ``self.url``.
        """
        return load_state_dict_from_url(self.url, progress=progress, **kwargs)

    def __repr__(self) -> str:
        # e.g. "ResNet50_Weights.IMAGENET1K_V1"
        return f"{self.__class__.__name__}.{self._name_}"

    def __getattr__(self, name: str) -> Any:
        """
        Expose the fields of the wrapped ``Weights`` record directly on the enum member.

        Python only calls ``__getattr__`` after normal attribute lookup has failed.

        Raises:
            AttributeError: If ``name`` is not a ``Weights`` field and no base class provides ``__getattr__``.
        """
        if any(f.name == name for f in fields(Weights)):
            return object.__getattribute__(self.value, name)
        # Defer to a parent ``__getattr__`` only if one exists. Calling ``super().__getattr__``
        # unconditionally fails with the misleading message
        # "'super' object has no attribute '__getattr__'" when no base class defines it.
        parent_getattr = getattr(super(), "__getattr__", None)
        if parent_getattr is not None:
            return parent_getattr(name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")


def get_weight(name: str) -> WeightsEnum:
    """
    Gets the weight enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

    The enum class is searched for in this module's parent package and in each direct submodule of that
    package that is itself a package (its ``__file__`` ends with ``__init__.py``).

    Args:
        name (str): The name of the weight enum entry, of the form "<EnumName>.<MemberName>".

    Returns:
        WeightsEnum: The requested weight enum.

    Raises:
        ValueError: If ``name`` does not contain exactly one ".", or if no ``WeightsEnum`` subclass
            named ``<EnumName>`` is found. Errors from ``from_str`` for an unknown member propagate unchanged.
    """
    try:
        enum_name, value_name = name.split(".")
    except ValueError:
        # Suppress the unrelated tuple-unpacking error from the traceback.
        raise ValueError(f"Invalid weight name provided: '{name}'.") from None

    # Parent package of this module, e.g. "a.b._api" -> "a.b".
    base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
    base_module = importlib.import_module(base_module_name)
    # Search the package plus its sub-packages. Some modules (namespace packages, built-ins)
    # have no ``__file__`` or have it set to None, so fall back to "" instead of raising.
    model_modules = [base_module] + [
        module
        for _, module in inspect.getmembers(base_module, inspect.ismodule)
        if (getattr(module, "__file__", None) or "").endswith("__init__.py")
    ]

    weights_enum = None
    for m in model_modules:
        potential_class = m.__dict__.get(enum_name, None)
        # issubclass() raises TypeError for non-class objects that merely share the name.
        if isinstance(potential_class, type) and issubclass(potential_class, WeightsEnum):
            weights_enum = potential_class
            break

    if weights_enum is None:
        raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

    return weights_enum.from_str(value_name)


def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
    """
    Internal method that gets the weight enum of a specific model builder method.
    Might be removed after the handle_legacy_interface is removed.

    Args:
        fn (Callable): The builder method used to create the model. Its ``weights`` parameter must be
            annotated either with a ``WeightsEnum`` subclass or with a typing construct whose arguments
            include one (e.g. ``Optional[SomeWeights]``).

    Returns:
        WeightsEnum: The requested weight enum.

    Raises:
        ValueError: If ``fn`` has no ``weights`` parameter, or its annotation does not reference a
            ``WeightsEnum`` subclass.
    """
    sig = signature(fn)
    if "weights" not in sig.parameters:
        raise ValueError("The method is missing the 'weights' argument.")

    ann = sig.parameters["weights"].annotation
    weights_enum = None
    if isinstance(ann, type) and issubclass(ann, WeightsEnum):
        weights_enum = ann
    else:
        # Handle cases like Union[Optional, T]. Plain classes, a missing annotation
        # (inspect.Parameter.empty) and string annotations have no usable ``__args__``;
        # treat them as "no arguments" so they reach the ValueError below rather than
        # raising AttributeError.
        # TODO: Replace with typing.get_args(ann) after python >= 3.8
        for t in getattr(ann, "__args__", None) or ():
            if isinstance(t, type) and issubclass(t, WeightsEnum):
                weights_enum = t
                break

    if weights_enum is None:
        raise ValueError(
            "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
        )

    return cast(WeightsEnum, weights_enum)