"vscode:/vscode.git/clone" did not exist on "216035315185edec747dca8879d7197e7fb22c7d"
Unverified Commit 0a919dbb authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add registration mechanism for models (#6333)

* Model registration mechanism.

* Add overwrite options to the dataset prototype registration mechanism.

* Adding example models.

* Fix module filtering

* Fix linter

* Fix docs

* Make name optional if same as model builder

* Apply updates from code-review.

* fix minor bug

* Adding getter for model weight enum

* Support both strings and callables on get_model_weight.

* linter fixes

* Fixing mypy.

* Renaming `get_model_weight` to `get_model_weights`

* Registering all classification models.

* Registering all video models.

* Registering all detection models.

* Registering all optical flow models.

* Fixing mypy.

* Registering all segmentation models.

* Registering all quantization models.

* Fixing linter

* Registering all prototype depth perception models.

* Adding tests and updating existing tests.

* Fix linters

* Fix tests.

* Add beta annotation on docs.

* Fix tests.

* Apply changes from code-review.

* Adding documentation.

* Fix docs.
parent 63870514
......@@ -10,7 +10,7 @@ import torch.nn as nn
from ...ops import MLP, StochasticDepth
from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param
......@@ -461,6 +461,7 @@ class MViT_V1_B_Weights(WeightsEnum):
DEFAULT = KINETICS400_V1
@register_model()
def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
"""
Constructs a base MViTV1 architecture from
......
......@@ -6,7 +6,7 @@ from torch import Tensor
from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import Weights, WeightsEnum
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -373,6 +373,7 @@ class R2Plus1D_18_Weights(WeightsEnum):
DEFAULT = KINETICS400_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Construct 18 layer Resnet3D model.
......@@ -409,6 +410,7 @@ def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, *
)
@register_model()
@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Construct 18 layer Mixed Convolution network as in
......@@ -445,6 +447,7 @@ def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, *
)
@register_model()
@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
"""Construct 18 layer deep R(2+1)D network as in
......
......@@ -9,7 +9,7 @@ import torch.nn as nn
from ..ops.misc import Conv2dNormActivation, MLP
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import Weights, WeightsEnum
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
......@@ -596,6 +596,7 @@ class ViT_H_14_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_SWAG_E2E_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_16_Weights.IMAGENET1K_V1))
def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
......@@ -629,6 +630,7 @@ def vit_b_16(*, weights: Optional[ViT_B_16_Weights] = None, progress: bool = Tru
)
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_B_32_Weights.IMAGENET1K_V1))
def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
......@@ -662,6 +664,7 @@ def vit_b_32(*, weights: Optional[ViT_B_32_Weights] = None, progress: bool = Tru
)
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_16_Weights.IMAGENET1K_V1))
def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
......@@ -695,6 +698,7 @@ def vit_l_16(*, weights: Optional[ViT_L_16_Weights] = None, progress: bool = Tru
)
@register_model()
@handle_legacy_interface(weights=("pretrained", ViT_L_32_Weights.IMAGENET1K_V1))
def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
......@@ -728,6 +732,7 @@ def vit_l_32(*, weights: Optional[ViT_L_32_Weights] = None, progress: bool = Tru
)
@register_model()
def vit_h_14(*, weights: Optional[ViT_H_14_Weights] = None, progress: bool = True, **kwargs: Any) -> VisionTransformer:
"""
Constructs a vit_h_14 architecture from
......
......@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.optical_flow.raft as raft
from torch import Tensor
from torchvision.models._api import WeightsEnum
from torchvision.models._api import register_model, WeightsEnum
from torchvision.models.optical_flow._utils import grid_sample, make_coords_grid, upsample_flow
from torchvision.models.optical_flow.raft import FlowHead, MotionEncoder, ResidualBlock
from torchvision.ops import Conv2dNormActivation
......@@ -617,6 +617,7 @@ class Raft_Stereo_Base_Weights(WeightsEnum):
pass
@register_model()
def raft_stereo_realtime(
*, weights: Optional[Raft_Stereo_Realtime_Weights] = None, progress=True, **kwargs
) -> RaftStereo:
......@@ -676,6 +677,7 @@ def raft_stereo_realtime(
)
@register_model()
def raft_stereo_base(*, weights: Optional[Raft_Stereo_Base_Weights] = None, progress=True, **kwargs) -> RaftStereo:
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment