Unverified Commit 729178c7 authored by Piyush Singh's avatar Piyush Singh Committed by GitHub
Browse files

Replace asserts with ValueErrors (#5275)



* replace assert with valueerror

* pytest should raise ValueError not AssertionError

* minor edit

* raise assert changed to raise valueerror in test

* Update torchvision/models/detection/backbone_utils.py
Co-authored-by: default avatarAditya Oke <47158509+oke-aditya@users.noreply.github.com>

* Update torchvision/models/detection/backbone_utils.py
Co-authored-by: default avatarAditya Oke <47158509+oke-aditya@users.noreply.github.com>

* minor edits

* minor edits

* added one test

* added another test

* added another test

* test for mobilenet

* ufmt formatting

* cant have unused variables

* suggested changes

* minor edit

* corrected bug pointed out by datumbox

* corrected bug pointed out by datumbox

* bug correction and shorten msg

* ufmt stuff

* resolved last comment
Co-authored-by: default avatarAbhijit Deo <72816663+abhi-glitchhg@users.noreply.github.com>
Co-authored-by: default avatarAditya Oke <47158509+oke-aditya@users.noreply.github.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent fccc5867
...@@ -7,9 +7,8 @@ import torch ...@@ -7,9 +7,8 @@ import torch
from common_utils import set_rng_seed from common_utils import set_rng_seed
from torchvision import models from torchvision import models
from torchvision.models._utils import IntermediateLayerGetter from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from torchvision.models.feature_extraction import get_graph_node_names
def get_available_models(): def get_available_models():
...@@ -23,6 +22,23 @@ def test_resnet_fpn_backbone(backbone_name): ...@@ -23,6 +22,23 @@ def test_resnet_fpn_backbone(backbone_name):
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
assert list(y.keys()) == ["0", "1", "2", "3", "pool"] assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=6)
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3])
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5])
@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
mobilenet_backbone(backbone_name=backbone_name, pretrained=False, fpn=False, trainable_layers=-1)
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2])
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6])
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function # Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x): def leaf_function(x):
......
...@@ -3,8 +3,7 @@ import copy ...@@ -3,8 +3,7 @@ import copy
import pytest import pytest
import torch import torch
from common_utils import assert_equal from common_utils import assert_equal
from torchvision.models.detection import _utils from torchvision.models.detection import _utils, backbone_utils
from torchvision.models.detection import backbone_utils
from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.models.detection.transform import GeneralizedRCNNTransform
...@@ -54,7 +53,7 @@ class TestModelsDetectionUtils: ...@@ -54,7 +53,7 @@ class TestModelsDetectionUtils:
) )
assert ret == 3 assert ret == 3
# can't go beyond 5 # can't go beyond 5
with pytest.raises(AssertionError): with pytest.raises(ValueError, match=r"Trainable backbone layers should be in the range"):
ret = backbone_utils._validate_trainable_layers( ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3 pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3
) )
......
import warnings import warnings
from typing import Callable, Dict, Optional, List, Union from typing import Callable, Dict, List, Optional, Union
from torch import nn, Tensor from torch import nn, Tensor
from torchvision.ops import misc as misc_nn_ops from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock from torchvision.ops.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool
from .. import mobilenet from .. import mobilenet, resnet
from .. import resnet
from .._utils import IntermediateLayerGetter from .._utils import IntermediateLayerGetter
...@@ -111,7 +110,8 @@ def _resnet_fpn_extractor( ...@@ -111,7 +110,8 @@ def _resnet_fpn_extractor(
) -> BackboneWithFPN: ) -> BackboneWithFPN:
# select layers that wont be frozen # select layers that wont be frozen
assert 0 <= trainable_layers <= 5 if trainable_layers < 0 or trainable_layers > 5:
raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
if trainable_layers == 5: if trainable_layers == 5:
layers_to_train.append("bn1") layers_to_train.append("bn1")
...@@ -124,7 +124,8 @@ def _resnet_fpn_extractor( ...@@ -124,7 +124,8 @@ def _resnet_fpn_extractor(
if returned_layers is None: if returned_layers is None:
returned_layers = [1, 2, 3, 4] returned_layers = [1, 2, 3, 4]
assert min(returned_layers) > 0 and max(returned_layers) < 5 if min(returned_layers) <= 0 or max(returned_layers) >= 5:
raise ValueError(f"Each returned layer should be in the range [1,4]. Got {returned_layers}")
return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)} return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
in_channels_stage2 = backbone.inplanes // 8 in_channels_stage2 = backbone.inplanes // 8
...@@ -152,7 +153,10 @@ def _validate_trainable_layers( ...@@ -152,7 +153,10 @@ def _validate_trainable_layers(
# by default freeze first blocks # by default freeze first blocks
if trainable_backbone_layers is None: if trainable_backbone_layers is None:
trainable_backbone_layers = default_value trainable_backbone_layers = default_value
assert 0 <= trainable_backbone_layers <= max_value if trainable_backbone_layers < 0 or trainable_backbone_layers > max_value:
raise ValueError(
f"Trainable backbone layers should be in the range [0,{max_value}], got {trainable_backbone_layers} "
)
return trainable_backbone_layers return trainable_backbone_layers
...@@ -172,7 +176,7 @@ def mobilenet_backbone( ...@@ -172,7 +176,7 @@ def mobilenet_backbone(
def _mobilenet_extractor( def _mobilenet_extractor(
backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3], backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
fpn: bool, fpn: bool,
trainable_layers, trainable_layers: int,
returned_layers: Optional[List[int]] = None, returned_layers: Optional[List[int]] = None,
extra_blocks: Optional[ExtraFPNBlock] = None, extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module: ) -> nn.Module:
...@@ -183,7 +187,8 @@ def _mobilenet_extractor( ...@@ -183,7 +187,8 @@ def _mobilenet_extractor(
num_stages = len(stage_indices) num_stages = len(stage_indices)
# find the index of the layer from which we wont freeze # find the index of the layer from which we wont freeze
assert 0 <= trainable_layers <= num_stages if trainable_layers < 0 or trainable_layers > num_stages:
raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
for b in backbone[:freeze_before]: for b in backbone[:freeze_before]:
...@@ -197,7 +202,8 @@ def _mobilenet_extractor( ...@@ -197,7 +202,8 @@ def _mobilenet_extractor(
if returned_layers is None: if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1] returned_layers = [num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages if min(returned_layers) < 0 or max(returned_layers) >= num_stages:
raise ValueError(f"Each returned layer should be in the range [0,{num_stages - 1}], got {returned_layers} ")
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)} return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers] in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment