Unverified Commit 460b37d2 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Multi-weight support for LRASPP prototype segmentation models (#4750)

* Adding multi-weight support to LRASPP

* Adding tests for segmentation models.

* Skip segmentation test by default.
parent a078e6fb
...@@ -38,10 +38,17 @@ def test_classification_model(model_fn, dev): ...@@ -38,10 +38,17 @@ def test_classification_model(model_fn, dev):
TM.test_classification_model(model_fn, dev) TM.test_classification_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models)) @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_segmentation_model(model_fn, dev):
TM.test_segmentation_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models) + TM.get_models_from_module(models.segmentation))
@pytest.mark.parametrize("dev", cpu_and_gpu()) @pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_old_vs_new_classification_factory(model_fn, dev): def test_old_vs_new_factory(model_fn, dev):
defaults = { defaults = {
"pretrained": True, "pretrained": True,
"input_shape": (1, 3, 224, 224), "input_shape": (1, 3, 224, 224),
......
from .fcn import * from .fcn import *
from .lraspp import *
import warnings
from functools import partial
from typing import Any, Optional
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry
from .._meta import _VOC_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
__all__ = ["LRASPP", "LRASPPMobileNetV3LargeWeights", "lraspp_mobilenet_v3_large"]
class LRASPPMobileNetV3LargeWeights(Weights):
CocoWithVocLabels_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
transforms=partial(VocEval, resize_size=520),
meta={
"categories": _VOC_CATEGORIES,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
"mIoU": 57.9,
"acc": 91.2,
},
)
def lraspp_mobilenet_v3_large(
weights: Optional[LRASPPMobileNetV3LargeWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 21,
**kwargs: Any,
) -> LRASPP:
if kwargs.pop("aux_loss", False):
raise NotImplementedError("This model does not use auxiliary loss")
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = LRASPPMobileNetV3LargeWeights.CocoWithVocLabels_RefV1 if kwargs.pop("pretrained") else None
weights = LRASPPMobileNetV3LargeWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
backbone = mobilenet_v3_large(weights=weights_backbone, dilated=True)
model = _lraspp_mobilenetv3(backbone, num_classes)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment