Unverified commit dd1adb07, authored by Vasilis Vryniotis, committed by GitHub
Browse files

Adding multiweight support to FasterRCNN (#4847)

* Aligning exception with all other models.

* Adding prototype preprocessing on video references.

* Adding the rest of model builders on faster_rcnn.
parent 8bb6b0e2
...@@ -33,6 +33,12 @@ from engine import train_one_epoch, evaluate ...@@ -33,6 +33,12 @@ from engine import train_one_epoch, evaluate
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
try:
from torchvision.prototype import models as PM
except ImportError:
PM = None
def get_dataset(name, image_set, transform, data_path): def get_dataset(name, image_set, transform, data_path):
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
p, ds_fn, num_classes = paths[name] p, ds_fn, num_classes = paths[name]
...@@ -41,8 +47,15 @@ def get_dataset(name, image_set, transform, data_path): ...@@ -41,8 +47,15 @@ def get_dataset(name, image_set, transform, data_path):
return ds, num_classes return ds, num_classes
def get_transform(train, args):
    """Select the transform pipeline for a detection dataset split.

    Args:
        train: When truthy, the training preset is returned.
        args: Parsed command-line namespace. ``data_augmentation`` is read for
            training; ``weights`` and ``model`` are read for evaluation.

    Returns:
        ``presets.DetectionPresetTrain(args.data_augmentation)`` for training,
        ``presets.DetectionPresetEval()`` for evaluation when no weights name
        was given, otherwise the transforms attached to the requested
        prototype weights.

    Raises:
        ImportError: If a weights name was given but the prototype module
            could not be imported.
    """
    if train:
        return presets.DetectionPresetTrain(args.data_augmentation)
    if not args.weights:
        return presets.DetectionPresetEval()

    # This runs before the model is built, so without this guard a missing
    # prototype module surfaces as an AttributeError on ``None``. Raise the
    # same ImportError that model creation uses instead.
    if PM is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")

    fn = PM.detection.__dict__[args.model]
    weights = PM._api.get_weight(fn, args.weights)
    return weights.transforms()
def get_args_parser(add_help=True): def get_args_parser(add_help=True):
...@@ -128,6 +141,9 @@ def get_args_parser(add_help=True): ...@@ -128,6 +141,9 @@ def get_args_parser(add_help=True):
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser return parser
...@@ -143,10 +159,8 @@ def main(args): ...@@ -143,10 +159,8 @@ def main(args):
# Data loading code # Data loading code
print("Loading data") print("Loading data")
dataset, num_classes = get_dataset( dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
args.dataset, "train", get_transform(True, args.data_augmentation), args.data_path dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)
)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)
print("Creating data loaders") print("Creating data loaders")
if args.distributed: if args.distributed:
...@@ -175,9 +189,14 @@ def main(args): ...@@ -175,9 +189,14 @@ def main(args):
if "rcnn" in args.model: if "rcnn" in args.model:
if args.rpn_score_thresh is not None: if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh kwargs["rpn_score_thresh"] = args.rpn_score_thresh
if not args.weights:
model = torchvision.models.detection.__dict__[args.model]( model = torchvision.models.detection.__dict__[args.model](
num_classes=num_classes, pretrained=args.pretrained, **kwargs pretrained=args.pretrained, num_classes=num_classes, **kwargs
) )
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model.to(device) model.to(device)
if args.distributed and args.sync_bn: if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
......
...@@ -48,6 +48,13 @@ def test_classification_model(model_fn, dev): ...@@ -48,6 +48,13 @@ def test_classification_model(model_fn, dev):
TM.test_classification_model(model_fn, dev) TM.test_classification_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_detection_model(model_fn, dev):
    """Run the shared detection-model test for one builder on one device.

    Parametrized over every builder collected from ``models.detection`` and
    every device yielded by ``cpu_and_gpu()``. Skipped while the
    ``PYTORCH_TEST_WITH_PROTOTYPE`` environment variable is unset or "0".
    """
    # The actual checks live in the TM module; this wrapper only supplies the parameters.
    TM.test_detection_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization)) @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_quantized_classification_model(model_fn): def test_quantized_classification_model(model_fn):
...@@ -71,6 +78,7 @@ def test_video_model(model_fn, dev): ...@@ -71,6 +78,7 @@ def test_video_model(model_fn, dev):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_fn, module_name", "model_fn, module_name",
get_models_with_module_names(models) get_models_with_module_names(models)
+ get_models_with_module_names(models.detection)
+ get_models_with_module_names(models.quantization) + get_models_with_module_names(models.quantization)
+ get_models_with_module_names(models.segmentation) + get_models_with_module_names(models.segmentation)
+ get_models_with_module_names(models.video), + get_models_with_module_names(models.video),
...@@ -82,6 +90,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev): ...@@ -82,6 +90,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
"models": { "models": {
"input_shape": (1, 3, 224, 224), "input_shape": (1, 3, 224, 224),
}, },
"detection": {
"input_shape": (3, 300, 300),
},
"quantization": { "quantization": {
"input_shape": (1, 3, 224, 224), "input_shape": (1, 3, 224, 224),
}, },
...@@ -95,7 +106,10 @@ def test_old_vs_new_factory(model_fn, module_name, dev): ...@@ -95,7 +106,10 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
model_name = model_fn.__name__ model_name = model_fn.__name__
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})} kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape") input_shape = kwargs.pop("input_shape")
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
x = torch.rand(input_shape).to(device=dev) x = torch.rand(input_shape).to(device=dev)
if module_name == "detection":
x = [x]
# compare with new model builder parameterized in the old fashion way # compare with new model builder parameterized in the old fashion way
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev) model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
......
...@@ -162,7 +162,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa ...@@ -162,7 +162,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa
if pretrained: if pretrained:
model_url = model_urls[arch] model_url = model_urls[arch]
if model_url is None: if model_url is None:
raise NotImplementedError(f"pretrained {arch} is not supported as of now") raise ValueError(f"No checkpoint is available for model type {arch}")
else: else:
state_dict = load_state_dict_from_url(model_url, progress=progress) state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
......
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional, Union
from ....models.detection.faster_rcnn import ( from ....models.detection.faster_rcnn import (
_validate_trainable_layers, _mobilenet_extractor,
_resnet_fpn_extractor, _resnet_fpn_extractor,
_validate_trainable_layers,
AnchorGenerator,
FasterRCNN, FasterRCNN,
misc_nn_ops, misc_nn_ops,
overwrite_eps, overwrite_eps,
...@@ -11,10 +13,22 @@ from ....models.detection.faster_rcnn import ( ...@@ -11,10 +13,22 @@ from ....models.detection.faster_rcnn import (
from ...transforms.presets import CocoEval from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES from .._meta import _COCO_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50 from ..resnet import ResNet50Weights, resnet50
__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"] __all__ = [
"FasterRCNN",
"FasterRCNNResNet50FPNWeights",
"FasterRCNNMobileNetV3LargeFPNWeights",
"FasterRCNNMobileNetV3Large320FPNWeights",
"fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
]
_common_meta = {"categories": _COCO_CATEGORIES}
class FasterRCNNResNet50FPNWeights(Weights): class FasterRCNNResNet50FPNWeights(Weights):
...@@ -22,13 +36,37 @@ class FasterRCNNResNet50FPNWeights(Weights): ...@@ -22,13 +36,37 @@ class FasterRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval, transforms=CocoEval,
meta={ meta={
"categories": _COCO_CATEGORIES, **_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0, "map": 37.0,
}, },
) )
class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
    # Checkpoints selectable for the ``fasterrcnn_mobilenet_v3_large_fpn`` builder.
    #
    # Coco_RefV1: checkpoint downloaded from the URL below, paired with the
    # ``CocoEval`` transforms. Its metadata holds the shared COCO category
    # list, a link to the training recipe and a "map" score
    # (presumably COCO box mAP -- TODO confirm the metric definition).
    Coco_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        transforms=CocoEval,
        meta={
            **_common_meta,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
            "map": 32.8,
        },
    )
class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
    # Checkpoints selectable for the ``fasterrcnn_mobilenet_v3_large_320_fpn`` builder.
    #
    # Coco_RefV1: checkpoint downloaded from the URL below, paired with the
    # ``CocoEval`` transforms. Its metadata holds the shared COCO category
    # list, a link to the training recipe and a "map" score
    # (presumably COCO box mAP -- TODO confirm the metric definition).
    Coco_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        transforms=CocoEval,
        meta={
            **_common_meta,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
            "map": 22.8,
        },
    )
def fasterrcnn_resnet50_fpn( def fasterrcnn_resnet50_fpn(
weights: Optional[FasterRCNNResNet50FPNWeights] = None, weights: Optional[FasterRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None, weights_backbone: Optional[ResNet50Weights] = None,
...@@ -64,3 +102,109 @@ def fasterrcnn_resnet50_fpn( ...@@ -64,3 +102,109 @@ def fasterrcnn_resnet50_fpn(
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
return model return model
def _fasterrcnn_mobilenet_v3_large_fpn(
    weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None,
    weights_backbone: Optional[MobileNetV3LargeWeights] = None,
    progress: bool = True,
    num_classes: int = 91,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """Assemble a Faster R-CNN model on a MobileNetV3-Large backbone.

    Shared implementation behind the public MobileNetV3 Faster R-CNN builders.

    Args:
        weights: Detection weights to load into the finished model. When
            given, ``weights_backbone`` is discarded and ``num_classes`` is
            replaced by the number of categories in the weights' metadata.
        weights_backbone: Weights forwarded to ``mobilenet_v3_large``.
        progress: Forwarded to the backbone builder and to the state-dict
            download.
        num_classes: Number of output classes when ``weights`` is ``None``.
        trainable_backbone_layers: Passed through
            ``_validate_trainable_layers`` before reaching the extractor.
        **kwargs: Extra keyword arguments forwarded to ``FasterRCNN``.

    Returns:
        The constructed ``FasterRCNN`` model.
    """
    has_detection_weights = weights is not None
    if has_detection_weights:
        # The full-model checkpoint supersedes any backbone-only weights and
        # dictates the size of the classification head.
        weights_backbone = None
        num_classes = len(weights.meta["categories"])

    is_pretrained = has_detection_weights or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_pretrained, trainable_backbone_layers, 6, 3)

    trunk = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
    feature_extractor = _mobilenet_extractor(trunk, True, trainable_backbone_layers)

    # Same five anchor sizes repeated three times, with one aspect-ratio tuple per entry.
    anchor_sizes = ((32, 64, 128, 256, 512),) * 3
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

    model = FasterRCNN(feature_extractor, num_classes, rpn_anchor_generator=anchor_generator, **kwargs)

    if has_detection_weights:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model
def fasterrcnn_mobilenet_v3_large_fpn(
    weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
    weights_backbone: Optional[MobileNetV3LargeWeights] = None,
    progress: bool = True,
    num_classes: int = 91,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """Build a Faster R-CNN model with a MobileNetV3-Large FPN backbone.

    Args:
        weights: Detection weights entry, or ``None``.
        weights_backbone: Backbone weights entry, or ``None``.
        progress: Forwarded to the shared builder.
        num_classes: Forwarded to the shared builder.
        trainable_backbone_layers: Forwarded to the shared builder.
        **kwargs: Extra model arguments. The deprecated ``pretrained`` and
            ``pretrained_backbone`` flags are still accepted here: each emits
            a warning and overrides the corresponding ``weights*`` argument.
            ``rpn_score_thresh`` defaults to 0.05 unless supplied.

    Returns:
        The constructed ``FasterRCNN`` model.
    """
    # Deprecated boolean flag -> reference detection checkpoint or nothing.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        use_reference_weights = kwargs.pop("pretrained")
        weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if use_reference_weights else None
    weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)

    # Deprecated boolean flag -> reference backbone checkpoint or nothing.
    if "pretrained_backbone" in kwargs:
        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
        use_reference_backbone = kwargs.pop("pretrained_backbone")
        weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if use_reference_backbone else None
    weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

    # Caller-supplied values win over the builder default.
    model_kwargs = dict({"rpn_score_thresh": 0.05}, **kwargs)

    return _fasterrcnn_mobilenet_v3_large_fpn(
        weights=weights,
        weights_backbone=weights_backbone,
        progress=progress,
        num_classes=num_classes,
        trainable_backbone_layers=trainable_backbone_layers,
        **model_kwargs,
    )
def fasterrcnn_mobilenet_v3_large_320_fpn(
    weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None,
    weights_backbone: Optional[MobileNetV3LargeWeights] = None,
    progress: bool = True,
    num_classes: int = 91,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> FasterRCNN:
    """Build the 320-variant Faster R-CNN model with a MobileNetV3-Large FPN backbone.

    Args:
        weights: Detection weights entry, or ``None``.
        weights_backbone: Backbone weights entry, or ``None``.
        progress: Forwarded to the shared builder.
        num_classes: Forwarded to the shared builder.
        trainable_backbone_layers: Forwarded to the shared builder.
        **kwargs: Extra model arguments. The deprecated ``pretrained`` and
            ``pretrained_backbone`` flags are still accepted here: each emits
            a warning and overrides the corresponding ``weights*`` argument.
            Unless supplied, ``min_size`` is 320, ``max_size`` is 640, both
            ``rpn_pre_nms_top_n_test`` and ``rpn_post_nms_top_n_test`` are 150
            and ``rpn_score_thresh`` is 0.05.

    Returns:
        The constructed ``FasterRCNN`` model.
    """
    # Deprecated boolean flag -> reference detection checkpoint or nothing.
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        use_reference_weights = kwargs.pop("pretrained")
        weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if use_reference_weights else None
    weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)

    # Deprecated boolean flag -> reference backbone checkpoint or nothing.
    if "pretrained_backbone" in kwargs:
        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
        use_reference_backbone = kwargs.pop("pretrained_backbone")
        weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if use_reference_backbone else None
    weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)

    # Variant-specific defaults; caller-supplied values win over them.
    variant_defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }
    model_kwargs = dict(variant_defaults, **kwargs)

    return _fasterrcnn_mobilenet_v3_large_fpn(
        weights=weights,
        weights_backbone=weights_backbone,
        progress=progress,
        num_classes=num_classes,
        trainable_backbone_layers=trainable_backbone_layers,
        **model_kwargs,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment