"vscode:/vscode.git/clone" did not exist on "dc87f526d413ff1763c5a4a6021fa93baa117694"
Unverified Commit 4176556e authored by YosuaMichael's avatar YosuaMichael Committed by GitHub
Browse files

Add weight for mnasnet0_75 and mnasnet1_3 (#6019)

* Add weight for mnasnet0_75 and mnasnet1_3

* Fix missing comma

* Add PR url as recipe, and update the metrics

* Add weights to legacy handler

* Update docs to specify there are weights available
parent 9e788719
...@@ -235,8 +235,20 @@ class MNASNet0_5_Weights(WeightsEnum): ...@@ -235,8 +235,20 @@ class MNASNet0_5_Weights(WeightsEnum):
class MNASNet0_75_Weights(WeightsEnum): class MNASNet0_75_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in mnasnet0_75 IMAGENET1K_V1 = Weights(
pass url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/6019",
"num_params": 3170208,
"metrics": {
"acc@1": 71.180,
"acc@5": 90.496,
},
},
)
DEFAULT = IMAGENET1K_V1
class MNASNet1_0_Weights(WeightsEnum): class MNASNet1_0_Weights(WeightsEnum):
...@@ -256,8 +268,20 @@ class MNASNet1_0_Weights(WeightsEnum): ...@@ -256,8 +268,20 @@ class MNASNet1_0_Weights(WeightsEnum):
class MNASNet1_3_Weights(WeightsEnum): class MNASNet1_3_Weights(WeightsEnum):
# If a default model is added here the corresponding changes need to be done in mnasnet1_3 IMAGENET1K_V1 = Weights(
pass url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
transforms=partial(ImageClassification, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"recipe": "https://github.com/pytorch/vision/pull/6019",
"num_params": 6282256,
"metrics": {
"acc@1": 76.506,
"acc@5": 93.522,
},
},
)
DEFAULT = IMAGENET1K_V1
def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
...@@ -299,15 +323,17 @@ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = ...@@ -299,15 +323,17 @@ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool =
return _mnasnet(0.5, weights, progress, **kwargs) return _mnasnet(0.5, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", None)) @handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 0.75 from """MNASNet with depth multiplier of 0.75 from
`MnasNet: Platform-Aware Neural Architecture Search for Mobile `MnasNet: Platform-Aware Neural Architecture Search for Mobile
<https://arxiv.org/pdf/1807.11626.pdf>`_ paper. <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.
Args: Args:
weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): Currently weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
no pre-trained weights are available and by default no pre-trained pretrained weights to use. See
:class:`~torchvision.models.MNASNet0_75_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used. weights are used.
progress (bool, optional): If True, displays a progress bar of the progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True. download to stderr. Default is True.
...@@ -351,15 +377,17 @@ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = ...@@ -351,15 +377,17 @@ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool =
return _mnasnet(1.0, weights, progress, **kwargs) return _mnasnet(1.0, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", None)) @handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
"""MNASNet with depth multiplier of 1.3 from """MNASNet with depth multiplier of 1.3 from
`MnasNet: Platform-Aware Neural Architecture Search for Mobile `MnasNet: Platform-Aware Neural Architecture Search for Mobile
<https://arxiv.org/pdf/1807.11626.pdf>`_ paper. <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.
Args: Args:
weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): Currently weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
no pre-trained weights are available and by default no pre-trained pretrained weights to use. See
:class:`~torchvision.models.MNASNet1_3_Weights` below for
more details, and possible values. By default, no pre-trained
weights are used. weights are used.
progress (bool, optional): If True, displays a progress bar of the progress (bool, optional): If True, displays a progress bar of the
download to stderr. Default is True. download to stderr. Default is True.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment