Unverified Commit d59398b5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Detection recipe enhancements (#5715)

* Detection recipe enhancements

* Add back nesterov momentum
parent ec1c2a12
...@@ -230,7 +230,7 @@ def main(args): ...@@ -230,7 +230,7 @@ def main(args):
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
if args.norm_weight_decay is None: if args.norm_weight_decay is None:
parameters = model.parameters() parameters = [p for p in model.parameters() if p.requires_grad]
else: else:
param_groups = torchvision.ops._utils.split_normalization_params(model) param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay] wd_groups = [args.norm_weight_decay, args.weight_decay]
......
...@@ -3,7 +3,7 @@ import transforms as T ...@@ -3,7 +3,7 @@ import transforms as T
class DetectionPresetTrain: class DetectionPresetTrain:
def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)): def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
if data_augmentation == "hflip": if data_augmentation == "hflip":
self.transforms = T.Compose( self.transforms = T.Compose(
[ [
...@@ -12,6 +12,27 @@ class DetectionPresetTrain: ...@@ -12,6 +12,27 @@ class DetectionPresetTrain:
T.ConvertImageDtype(torch.float), T.ConvertImageDtype(torch.float),
] ]
) )
elif data_augmentation == "lsj":
self.transforms = T.Compose(
[
T.ScaleJitter(target_size=(1024, 1024)),
T.FixedSizeCrop(size=(1024, 1024), fill=mean),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "multiscale":
self.transforms = T.Compose(
[
T.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
T.RandomHorizontalFlip(p=hflip_prob),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
]
)
elif data_augmentation == "ssd": elif data_augmentation == "ssd":
self.transforms = T.Compose( self.transforms = T.Compose(
[ [
......
...@@ -68,6 +68,7 @@ def get_args_parser(add_help=True): ...@@ -68,6 +68,7 @@ def get_args_parser(add_help=True):
parser.add_argument( parser.add_argument(
"-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)" "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
) )
parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
parser.add_argument( parser.add_argument(
"--lr", "--lr",
default=0.02, default=0.02,
...@@ -84,6 +85,12 @@ def get_args_parser(add_help=True): ...@@ -84,6 +85,12 @@ def get_args_parser(add_help=True):
help="weight decay (default: 1e-4)", help="weight decay (default: 1e-4)",
dest="weight_decay", dest="weight_decay",
) )
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument( parser.add_argument(
"--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)" "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
) )
...@@ -176,6 +183,8 @@ def main(args): ...@@ -176,6 +183,8 @@ def main(args):
print("Creating model") print("Creating model")
kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers} kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
if args.data_augmentation in ["multiscale", "lsj"]:
kwargs["_skip_resize"] = True
if "rcnn" in args.model: if "rcnn" in args.model:
if args.rpn_score_thresh is not None: if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh kwargs["rpn_score_thresh"] = args.rpn_score_thresh
...@@ -191,8 +200,26 @@ def main(args): ...@@ -191,8 +200,26 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module model_without_ddp = model.module
params = [p for p in model.parameters() if p.requires_grad] if args.norm_weight_decay is None:
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) parameters = [p for p in model.parameters() if p.requires_grad]
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov="nesterov" in opt_name,
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
scaler = torch.cuda.amp.GradScaler() if args.amp else None scaler = torch.cuda.amp.GradScaler() if args.amp else None
......
...@@ -64,7 +64,6 @@ def test_get_weight(name, weight): ...@@ -64,7 +64,6 @@ def test_get_weight(name, weight):
) )
def test_naming_conventions(model_fn): def test_naming_conventions(model_fn):
weights_enum = _get_model_weights(model_fn) weights_enum = _get_model_weights(model_fn)
print(weights_enum)
assert weights_enum is not None assert weights_enum is not None
assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT") assert len(weights_enum) == 0 or hasattr(weights_enum, "DEFAULT")
......
...@@ -187,6 +187,7 @@ class FasterRCNN(GeneralizedRCNN): ...@@ -187,6 +187,7 @@ class FasterRCNN(GeneralizedRCNN):
box_batch_size_per_image=512, box_batch_size_per_image=512,
box_positive_fraction=0.25, box_positive_fraction=0.25,
bbox_reg_weights=None, bbox_reg_weights=None,
**kwargs,
): ):
if not hasattr(backbone, "out_channels"): if not hasattr(backbone, "out_channels"):
...@@ -268,7 +269,7 @@ class FasterRCNN(GeneralizedRCNN): ...@@ -268,7 +269,7 @@ class FasterRCNN(GeneralizedRCNN):
image_mean = [0.485, 0.456, 0.406] image_mean = [0.485, 0.456, 0.406]
if image_std is None: if image_std is None:
image_std = [0.229, 0.224, 0.225] image_std = [0.229, 0.224, 0.225]
transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
super().__init__(backbone, rpn, roi_heads, transform) super().__init__(backbone, rpn, roi_heads, transform)
......
...@@ -373,6 +373,7 @@ class FCOS(nn.Module): ...@@ -373,6 +373,7 @@ class FCOS(nn.Module):
nms_thresh: float = 0.6, nms_thresh: float = 0.6,
detections_per_img: int = 100, detections_per_img: int = 100,
topk_candidates: int = 1000, topk_candidates: int = 1000,
**kwargs,
): ):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
...@@ -410,7 +411,7 @@ class FCOS(nn.Module): ...@@ -410,7 +411,7 @@ class FCOS(nn.Module):
image_mean = [0.485, 0.456, 0.406] image_mean = [0.485, 0.456, 0.406]
if image_std is None: if image_std is None:
image_std = [0.229, 0.224, 0.225] image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
self.center_sampling_radius = center_sampling_radius self.center_sampling_radius = center_sampling_radius
self.score_thresh = score_thresh self.score_thresh = score_thresh
......
...@@ -198,6 +198,7 @@ class KeypointRCNN(FasterRCNN): ...@@ -198,6 +198,7 @@ class KeypointRCNN(FasterRCNN):
keypoint_head=None, keypoint_head=None,
keypoint_predictor=None, keypoint_predictor=None,
num_keypoints=None, num_keypoints=None,
**kwargs,
): ):
if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))): if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
...@@ -259,6 +260,7 @@ class KeypointRCNN(FasterRCNN): ...@@ -259,6 +260,7 @@ class KeypointRCNN(FasterRCNN):
box_batch_size_per_image, box_batch_size_per_image,
box_positive_fraction, box_positive_fraction,
bbox_reg_weights, bbox_reg_weights,
**kwargs,
) )
self.roi_heads.keypoint_roi_pool = keypoint_roi_pool self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
......
...@@ -195,6 +195,7 @@ class MaskRCNN(FasterRCNN): ...@@ -195,6 +195,7 @@ class MaskRCNN(FasterRCNN):
mask_roi_pool=None, mask_roi_pool=None,
mask_head=None, mask_head=None,
mask_predictor=None, mask_predictor=None,
**kwargs,
): ):
if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))): if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
...@@ -254,6 +255,7 @@ class MaskRCNN(FasterRCNN): ...@@ -254,6 +255,7 @@ class MaskRCNN(FasterRCNN):
box_batch_size_per_image, box_batch_size_per_image,
box_positive_fraction, box_positive_fraction,
bbox_reg_weights, bbox_reg_weights,
**kwargs,
) )
self.roi_heads.mask_roi_pool = mask_roi_pool self.roi_heads.mask_roi_pool = mask_roi_pool
......
...@@ -342,6 +342,7 @@ class RetinaNet(nn.Module): ...@@ -342,6 +342,7 @@ class RetinaNet(nn.Module):
fg_iou_thresh=0.5, fg_iou_thresh=0.5,
bg_iou_thresh=0.4, bg_iou_thresh=0.4,
topk_candidates=1000, topk_candidates=1000,
**kwargs,
): ):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
...@@ -383,7 +384,7 @@ class RetinaNet(nn.Module): ...@@ -383,7 +384,7 @@ class RetinaNet(nn.Module):
image_mean = [0.485, 0.456, 0.406] image_mean = [0.485, 0.456, 0.406]
if image_std is None: if image_std is None:
image_std = [0.229, 0.224, 0.225] image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.nms_thresh = nms_thresh self.nms_thresh = nms_thresh
......
...@@ -195,6 +195,7 @@ class SSD(nn.Module): ...@@ -195,6 +195,7 @@ class SSD(nn.Module):
iou_thresh: float = 0.5, iou_thresh: float = 0.5,
topk_candidates: int = 400, topk_candidates: int = 400,
positive_fraction: float = 0.25, positive_fraction: float = 0.25,
**kwargs: Any,
): ):
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
...@@ -227,7 +228,7 @@ class SSD(nn.Module): ...@@ -227,7 +228,7 @@ class SSD(nn.Module):
if image_std is None: if image_std is None:
image_std = [0.229, 0.224, 0.225] image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform( self.transform = GeneralizedRCNNTransform(
min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size, **kwargs
) )
self.score_thresh = score_thresh self.score_thresh = score_thresh
......
import math import math
from typing import List, Tuple, Dict, Optional from typing import List, Tuple, Dict, Optional, Any
import torch import torch
import torchvision import torchvision
...@@ -91,6 +91,7 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -91,6 +91,7 @@ class GeneralizedRCNNTransform(nn.Module):
image_std: List[float], image_std: List[float],
size_divisible: int = 32, size_divisible: int = 32,
fixed_size: Optional[Tuple[int, int]] = None, fixed_size: Optional[Tuple[int, int]] = None,
**kwargs: Any,
): ):
super().__init__() super().__init__()
if not isinstance(min_size, (list, tuple)): if not isinstance(min_size, (list, tuple)):
...@@ -101,6 +102,7 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -101,6 +102,7 @@ class GeneralizedRCNNTransform(nn.Module):
self.image_std = image_std self.image_std = image_std
self.size_divisible = size_divisible self.size_divisible = size_divisible
self.fixed_size = fixed_size self.fixed_size = fixed_size
self._skip_resize = kwargs.pop("_skip_resize", False)
def forward( def forward(
self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
...@@ -170,6 +172,8 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -170,6 +172,8 @@ class GeneralizedRCNNTransform(nn.Module):
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
h, w = image.shape[-2:] h, w = image.shape[-2:]
if self.training: if self.training:
if self._skip_resize:
return image, target
size = float(self.torch_choice(self.min_size)) size = float(self.torch_choice(self.min_size))
else: else:
# FIXME assume for now that testing uses the largest scale # FIXME assume for now that testing uses the largest scale
......
...@@ -43,7 +43,13 @@ def split_normalization_params( ...@@ -43,7 +43,13 @@ def split_normalization_params(
) -> Tuple[List[Tensor], List[Tensor]]: ) -> Tuple[List[Tensor], List[Tensor]]:
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501 # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes: if not norm_classes:
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm] norm_classes = [
nn.modules.batchnorm._BatchNorm,
nn.LayerNorm,
nn.GroupNorm,
nn.modules.instancenorm._InstanceNorm,
nn.LocalResponseNorm,
]
for t in norm_classes: for t in norm_classes:
if not issubclass(t, nn.Module): if not issubclass(t, nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment