Unverified Commit dd1adb07 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding multiweight support to FasterRCNN (#4847)

* Aligning exception with all other models.

* Adding prototype preprocessing on video references.

* Adding the rest of model builders on faster_rcnn.
parent 8bb6b0e2
......@@ -33,6 +33,12 @@ from engine import train_one_epoch, evaluate
from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
try:
from torchvision.prototype import models as PM
except ImportError:
PM = None
def get_dataset(name, image_set, transform, data_path):
paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
p, ds_fn, num_classes = paths[name]
......@@ -41,8 +47,15 @@ def get_dataset(name, image_set, transform, data_path):
return ds, num_classes
def get_transform(train, data_augmentation):
return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval()
def get_transform(train, args):
if train:
return presets.DetectionPresetTrain(args.data_augmentation)
elif not args.weights:
return presets.DetectionPresetEval()
else:
fn = PM.detection.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
return weights.transforms()
def get_args_parser(add_help=True):
......@@ -128,6 +141,9 @@ def get_args_parser(add_help=True):
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
return parser
......@@ -143,10 +159,8 @@ def main(args):
# Data loading code
print("Loading data")
dataset, num_classes = get_dataset(
args.dataset, "train", get_transform(True, args.data_augmentation), args.data_path
)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path)
dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)
print("Creating data loaders")
if args.distributed:
......@@ -175,9 +189,14 @@ def main(args):
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
if not args.weights:
model = torchvision.models.detection.__dict__[args.model](
num_classes=num_classes, pretrained=args.pretrained, **kwargs
pretrained=args.pretrained, num_classes=num_classes, **kwargs
)
else:
if PM is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
......
......@@ -48,6 +48,13 @@ def test_classification_model(model_fn, dev):
TM.test_classification_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection))
@pytest.mark.parametrize("dev", cpu_and_gpu())
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_detection_model(model_fn, dev):
TM.test_detection_model(model_fn, dev)
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
def test_quantized_classification_model(model_fn):
......@@ -71,6 +78,7 @@ def test_video_model(model_fn, dev):
@pytest.mark.parametrize(
"model_fn, module_name",
get_models_with_module_names(models)
+ get_models_with_module_names(models.detection)
+ get_models_with_module_names(models.quantization)
+ get_models_with_module_names(models.segmentation)
+ get_models_with_module_names(models.video),
......@@ -82,6 +90,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
"models": {
"input_shape": (1, 3, 224, 224),
},
"detection": {
"input_shape": (3, 300, 300),
},
"quantization": {
"input_shape": (1, 3, 224, 224),
},
......@@ -95,7 +106,10 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
model_name = model_fn.__name__
kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})}
input_shape = kwargs.pop("input_shape")
kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models
x = torch.rand(input_shape).to(device=dev)
if module_name == "detection":
x = [x]
# compare with new model builder parameterized in the old fashion way
model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev)
......
......@@ -162,7 +162,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa
if pretrained:
model_url = model_urls[arch]
if model_url is None:
raise NotImplementedError(f"pretrained {arch} is not supported as of now")
raise ValueError(f"No checkpoint is available for model type {arch}")
else:
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
......
import warnings
from typing import Any, Optional
from typing import Any, Optional, Union
from ....models.detection.faster_rcnn import (
_validate_trainable_layers,
_mobilenet_extractor,
_resnet_fpn_extractor,
_validate_trainable_layers,
AnchorGenerator,
FasterRCNN,
misc_nn_ops,
overwrite_eps,
......@@ -11,10 +13,22 @@ from ....models.detection.faster_rcnn import (
from ...transforms.presets import CocoEval
from .._api import Weights, WeightEntry
from .._meta import _COCO_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large
from ..resnet import ResNet50Weights, resnet50
__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
__all__ = [
"FasterRCNN",
"FasterRCNNResNet50FPNWeights",
"FasterRCNNMobileNetV3LargeFPNWeights",
"FasterRCNNMobileNetV3Large320FPNWeights",
"fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
]
_common_meta = {"categories": _COCO_CATEGORIES}
class FasterRCNNResNet50FPNWeights(Weights):
......@@ -22,13 +36,37 @@ class FasterRCNNResNet50FPNWeights(Weights):
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval,
meta={
"categories": _COCO_CATEGORIES,
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
"map": 37.0,
},
)
class FasterRCNNMobileNetV3LargeFPNWeights(Weights):
Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval,
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
"map": 32.8,
},
)
class FasterRCNNMobileNetV3Large320FPNWeights(Weights):
Coco_RefV1 = WeightEntry(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval,
meta={
**_common_meta,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
"map": 22.8,
},
)
def fasterrcnn_resnet50_fpn(
weights: Optional[FasterRCNNResNet50FPNWeights] = None,
weights_backbone: Optional[ResNet50Weights] = None,
......@@ -64,3 +102,109 @@ def fasterrcnn_resnet50_fpn(
overwrite_eps(model, 0.0)
return model
def _fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if weights is not None:
weights_backbone = None
num_classes = len(weights.meta["categories"])
trainable_backbone_layers = _validate_trainable_layers(
weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3
)
backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
anchor_sizes = (
(
32,
64,
128,
256,
512,
),
) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
model = FasterRCNN(
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
)
if weights is not None:
model.load_state_dict(weights.state_dict(progress=progress))
return model
def fasterrcnn_mobilenet_v3_large_fpn(
weights: Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
defaults = {
"rpn_score_thresh": 0.05,
}
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights,
weights_backbone,
progress,
num_classes,
trainable_backbone_layers,
**kwargs,
)
def fasterrcnn_mobilenet_v3_large_320_fpn(
weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None,
weights_backbone: Optional[MobileNetV3LargeWeights] = None,
progress: bool = True,
num_classes: int = 91,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
) -> FasterRCNN:
if "pretrained" in kwargs:
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights)
if "pretrained_backbone" in kwargs:
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone)
defaults = {
"min_size": 320,
"max_size": 640,
"rpn_pre_nms_top_n_test": 150,
"rpn_post_nms_top_n_test": 150,
"rpn_score_thresh": 0.05,
}
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(
weights,
weights_backbone,
progress,
num_classes,
trainable_backbone_layers,
**kwargs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment