Commit 5f0edb97 authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
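For context, a short sketch of the style the new tooling enforces; it is illustrative only and not part of this diff. usort splits imports into standard-library, third-party, and local groups, and black re-wraps long signatures one argument per line when a magic trailing comma is present. All names below are invented for the example.

import math  # standard-library group
from collections import OrderedDict

from torch import nn  # third-party group

# within torchvision itself, relative imports such as
# `from .anchor_utils import AnchorGenerator` form a further, local group


def example_classification_head(
    in_channels,
    num_anchors,
    num_classes,
    prior_probability=0.01,
):
    # the trailing comma above makes black keep one argument per line,
    # which accounts for most of the mechanical churn in the diff below
    conv = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
    nn.init.constant_(conv.bias, -math.log((1 - prior_probability) / prior_probability))
    return OrderedDict([("cls_logits", conv)])
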
import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..._internally_replaced_utils import load_state_dict_from_url
from .faster_rcnn import FasterRCNN
from ._utils import overwrite_eps
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from .faster_rcnn import FasterRCNN
__all__ = [
"KeypointRCNN", "keypointrcnn_resnet50_fpn"
]
__all__ = ["KeypointRCNN", "keypointrcnn_resnet50_fpn"]
class KeypointRCNN(FasterRCNN):
......@@ -151,27 +147,47 @@ class KeypointRCNN(FasterRCNN):
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=None, max_size=1333,
image_mean=None, image_std=None,
# RPN parameters
rpn_anchor_generator=None, rpn_head=None,
rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None,
# keypoint parameters
keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None,
num_keypoints=17):
def __init__(
self,
backbone,
num_classes=None,
# transform parameters
min_size=None,
max_size=1333,
image_mean=None,
image_std=None,
# RPN parameters
rpn_anchor_generator=None,
rpn_head=None,
rpn_pre_nms_top_n_train=2000,
rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000,
rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7,
rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256,
rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None,
box_head=None,
box_predictor=None,
box_score_thresh=0.05,
box_nms_thresh=0.5,
box_detections_per_img=100,
box_fg_iou_thresh=0.5,
box_bg_iou_thresh=0.5,
box_batch_size_per_image=512,
box_positive_fraction=0.25,
bbox_reg_weights=None,
# keypoint parameters
keypoint_roi_pool=None,
keypoint_head=None,
keypoint_predictor=None,
num_keypoints=17,
):
assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
if min_size is None:
......@@ -184,10 +200,7 @@ class KeypointRCNN(FasterRCNN):
out_channels = backbone.out_channels
if keypoint_roi_pool is None:
keypoint_roi_pool = MultiScaleRoIAlign(
featmap_names=['0', '1', '2', '3'],
output_size=14,
sampling_ratio=2)
keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
if keypoint_head is None:
keypoint_layers = tuple(512 for _ in range(8))
......@@ -198,24 +211,39 @@ class KeypointRCNN(FasterRCNN):
keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)
super(KeypointRCNN, self).__init__(
backbone, num_classes,
backbone,
num_classes,
# transform parameters
min_size, max_size,
image_mean, image_std,
min_size,
max_size,
image_mean,
image_std,
# RPN-specific parameters
rpn_anchor_generator, rpn_head,
rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
rpn_anchor_generator,
rpn_head,
rpn_pre_nms_top_n_train,
rpn_pre_nms_top_n_test,
rpn_post_nms_top_n_train,
rpn_post_nms_top_n_test,
rpn_nms_thresh,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
rpn_fg_iou_thresh,
rpn_bg_iou_thresh,
rpn_batch_size_per_image,
rpn_positive_fraction,
rpn_score_thresh,
# Box parameters
box_roi_pool, box_head, box_predictor,
box_score_thresh, box_nms_thresh, box_detections_per_img,
box_fg_iou_thresh, box_bg_iou_thresh,
box_batch_size_per_image, box_positive_fraction,
bbox_reg_weights)
box_roi_pool,
box_head,
box_predictor,
box_score_thresh,
box_nms_thresh,
box_detections_per_img,
box_fg_iou_thresh,
box_bg_iou_thresh,
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
)
self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
self.roi_heads.keypoint_head = keypoint_head
......@@ -249,9 +277,7 @@ class KeypointRCNNPredictor(nn.Module):
stride=2,
padding=deconv_kernel // 2 - 1,
)
nn.init.kaiming_normal_(
self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu"
)
nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
nn.init.constant_(self.kps_score_lowres.bias, 0)
self.up_scale = 2
self.out_channels = num_keypoints
......@@ -265,16 +291,20 @@ class KeypointRCNNPredictor(nn.Module):
model_urls = {
# legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606
'keypointrcnn_resnet50_fpn_coco_legacy':
'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth',
'keypointrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth',
"keypointrcnn_resnet50_fpn_coco_legacy": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
"keypointrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
}
def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=2, num_keypoints=17,
pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
def keypointrcnn_resnet50_fpn(
pretrained=False,
progress=True,
num_classes=2,
num_keypoints=17,
pretrained_backbone=True,
trainable_backbone_layers=None,
**kwargs,
):
"""
Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.
......@@ -331,19 +361,19 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if pretrained:
key = 'keypointrcnn_resnet50_fpn_coco'
if pretrained == 'legacy':
key += '_legacy'
state_dict = load_state_dict_from_url(model_urls[key],
progress=progress)
key = "keypointrcnn_resnet50_fpn_coco"
if pretrained == "legacy":
key += "_legacy"
state_dict = load_state_dict_from_url(model_urls[key], progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
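
# Usage sketch (illustrative, not part of this diff): building the model with the
# helper above and running inference without downloading any weights. In eval
# mode the output is a list of dicts with "boxes", "labels", "scores",
# "keypoints" and "keypoints_scores" entries, one dict per input image.
model = keypointrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
model.eval()
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
with torch.no_grad():
    predictions = model(images)
print(predictions[0]["keypoints"].shape)  # (num detections, num_keypoints, 3)
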
from collections import OrderedDict
from torch import nn
from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..._internally_replaced_utils import load_state_dict_from_url
from .faster_rcnn import FasterRCNN
from ._utils import overwrite_eps
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from .faster_rcnn import FasterRCNN
__all__ = [
"MaskRCNN", "maskrcnn_resnet50_fpn",
"MaskRCNN",
"maskrcnn_resnet50_fpn",
]
......@@ -149,26 +148,46 @@ class MaskRCNN(FasterRCNN):
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=800, max_size=1333,
image_mean=None, image_std=None,
# RPN parameters
rpn_anchor_generator=None, rpn_head=None,
rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None,
# Mask parameters
mask_roi_pool=None, mask_head=None, mask_predictor=None):
def __init__(
self,
backbone,
num_classes=None,
# transform parameters
min_size=800,
max_size=1333,
image_mean=None,
image_std=None,
# RPN parameters
rpn_anchor_generator=None,
rpn_head=None,
rpn_pre_nms_top_n_train=2000,
rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000,
rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7,
rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256,
rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None,
box_head=None,
box_predictor=None,
box_score_thresh=0.05,
box_nms_thresh=0.5,
box_detections_per_img=100,
box_fg_iou_thresh=0.5,
box_bg_iou_thresh=0.5,
box_batch_size_per_image=512,
box_positive_fraction=0.25,
bbox_reg_weights=None,
# Mask parameters
mask_roi_pool=None,
mask_head=None,
mask_predictor=None,
):
assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))
......@@ -179,10 +198,7 @@ class MaskRCNN(FasterRCNN):
out_channels = backbone.out_channels
if mask_roi_pool is None:
mask_roi_pool = MultiScaleRoIAlign(
featmap_names=['0', '1', '2', '3'],
output_size=14,
sampling_ratio=2)
mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)
if mask_head is None:
mask_layers = (256, 256, 256, 256)
......@@ -192,28 +208,42 @@ class MaskRCNN(FasterRCNN):
if mask_predictor is None:
mask_predictor_in_channels = 256 # == mask_layers[-1]
mask_dim_reduced = 256
mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels,
mask_dim_reduced, num_classes)
mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes)
super(MaskRCNN, self).__init__(
backbone, num_classes,
backbone,
num_classes,
# transform parameters
min_size, max_size,
image_mean, image_std,
min_size,
max_size,
image_mean,
image_std,
# RPN-specific parameters
rpn_anchor_generator, rpn_head,
rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test,
rpn_post_nms_top_n_train, rpn_post_nms_top_n_test,
rpn_anchor_generator,
rpn_head,
rpn_pre_nms_top_n_train,
rpn_pre_nms_top_n_test,
rpn_post_nms_top_n_train,
rpn_post_nms_top_n_test,
rpn_nms_thresh,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
rpn_fg_iou_thresh,
rpn_bg_iou_thresh,
rpn_batch_size_per_image,
rpn_positive_fraction,
rpn_score_thresh,
# Box parameters
box_roi_pool, box_head, box_predictor,
box_score_thresh, box_nms_thresh, box_detections_per_img,
box_fg_iou_thresh, box_bg_iou_thresh,
box_batch_size_per_image, box_positive_fraction,
bbox_reg_weights)
box_roi_pool,
box_head,
box_predictor,
box_score_thresh,
box_nms_thresh,
box_detections_per_img,
box_fg_iou_thresh,
box_bg_iou_thresh,
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
)
self.roi_heads.mask_roi_pool = mask_roi_pool
self.roi_heads.mask_head = mask_head
......@@ -232,8 +262,8 @@ class MaskRCNNHeads(nn.Sequential):
next_feature = in_channels
for layer_idx, layer_features in enumerate(layers, 1):
d["mask_fcn{}".format(layer_idx)] = nn.Conv2d(
next_feature, layer_features, kernel_size=3,
stride=1, padding=dilation, dilation=dilation)
next_feature, layer_features, kernel_size=3, stride=1, padding=dilation, dilation=dilation
)
d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True)
next_feature = layer_features
......@@ -247,11 +277,15 @@ class MaskRCNNHeads(nn.Sequential):
class MaskRCNNPredictor(nn.Sequential):
def __init__(self, in_channels, dim_reduced, num_classes):
super(MaskRCNNPredictor, self).__init__(OrderedDict([
("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
("relu", nn.ReLU(inplace=True)),
("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
]))
super(MaskRCNNPredictor, self).__init__(
OrderedDict(
[
("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
("relu", nn.ReLU(inplace=True)),
("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
]
)
)
for name, param in self.named_parameters():
if "weight" in name:
......@@ -261,13 +295,13 @@ class MaskRCNNPredictor(nn.Sequential):
model_urls = {
'maskrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth',
"maskrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
}
def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
def maskrcnn_resnet50_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
"""
Constructs a Mask R-CNN model with a ResNet-50-FPN backbone.
......@@ -324,16 +358,16 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = MaskRCNN(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
progress=progress)
state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
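
# Fine-tuning sketch (illustrative, not part of this diff): swapping the mask
# predictor defined above for a custom number of classes, along the lines of the
# usual torchvision fine-tuning recipe; the box predictor would typically be
# replaced in the same way. The value num_classes=3 is an arbitrary example.
model = maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes=3)
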
import math
from collections import OrderedDict
import warnings
from collections import OrderedDict
from typing import Dict, List, Tuple, Optional
import torch
from torch import nn, Tensor
from typing import Dict, List, Tuple, Optional
from ._utils import overwrite_eps
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import sigmoid_focal_loss
from ...ops import boxes as box_ops
from ...ops.feature_pyramid_network import LastLevelP6P7
from . import _utils as det_utils
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...ops import sigmoid_focal_loss
from ...ops import boxes as box_ops
from .transform import GeneralizedRCNNTransform
__all__ = [
"RetinaNet", "retinanet_resnet50_fpn"
]
__all__ = ["RetinaNet", "retinanet_resnet50_fpn"]
def _sum(x: List[Tensor]) -> Tensor:
......@@ -48,16 +45,13 @@ class RetinaNetHead(nn.Module):
def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
return {
'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
"classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
"bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
}
def forward(self, x):
# type: (List[Tensor]) -> Dict[str, Tensor]
return {
'cls_logits': self.classification_head(x),
'bbox_regression': self.regression_head(x)
}
return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}
class RetinaNetClassificationHead(nn.Module):
......@@ -100,7 +94,7 @@ class RetinaNetClassificationHead(nn.Module):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
losses = []
cls_logits = head_outputs['cls_logits']
cls_logits = head_outputs["cls_logits"]
for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
# determine only the foreground
......@@ -111,18 +105,21 @@ class RetinaNetClassificationHead(nn.Module):
gt_classes_target = torch.zeros_like(cls_logits_per_image)
gt_classes_target[
foreground_idxs_per_image,
targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]]
targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
] = 1.0
# find indices for which anchors should be ignored
valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS
# compute the classification loss
losses.append(sigmoid_focal_loss(
cls_logits_per_image[valid_idxs_per_image],
gt_classes_target[valid_idxs_per_image],
reduction='sum',
) / max(1, num_foreground))
losses.append(
sigmoid_focal_loss(
cls_logits_per_image[valid_idxs_per_image],
gt_classes_target[valid_idxs_per_image],
reduction="sum",
)
/ max(1, num_foreground)
)
return _sum(losses) / len(targets)
......@@ -153,8 +150,9 @@ class RetinaNetRegressionHead(nn.Module):
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
"box_coder": det_utils.BoxCoder,
}
def __init__(self, in_channels, num_anchors):
......@@ -181,16 +179,17 @@ class RetinaNetRegressionHead(nn.Module):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
losses = []
bbox_regression = head_outputs['bbox_regression']
bbox_regression = head_outputs["bbox_regression"]
for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \
zip(targets, bbox_regression, anchors, matched_idxs):
for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
targets, bbox_regression, anchors, matched_idxs
):
# determine only the foreground indices, ignore the rest
foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
num_foreground = foreground_idxs_per_image.numel()
# select only the foreground boxes
matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]]
matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
......@@ -198,11 +197,10 @@ class RetinaNetRegressionHead(nn.Module):
target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
# compute the loss
losses.append(torch.nn.functional.l1_loss(
bbox_regression_per_image,
target_regression,
reduction='sum'
) / max(1, num_foreground))
losses.append(
torch.nn.functional.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
/ max(1, num_foreground)
)
return _sum(losses) / max(1, len(targets))
......@@ -309,30 +307,40 @@ class RetinaNet(nn.Module):
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
"box_coder": det_utils.BoxCoder,
"proposal_matcher": det_utils.Matcher,
}
def __init__(self, backbone, num_classes,
# transform parameters
min_size=800, max_size=1333,
image_mean=None, image_std=None,
# Anchor parameters
anchor_generator=None, head=None,
proposal_matcher=None,
score_thresh=0.05,
nms_thresh=0.5,
detections_per_img=300,
fg_iou_thresh=0.5, bg_iou_thresh=0.4,
topk_candidates=1000):
def __init__(
self,
backbone,
num_classes,
# transform parameters
min_size=800,
max_size=1333,
image_mean=None,
image_std=None,
# Anchor parameters
anchor_generator=None,
head=None,
proposal_matcher=None,
score_thresh=0.05,
nms_thresh=0.5,
detections_per_img=300,
fg_iou_thresh=0.5,
bg_iou_thresh=0.4,
topk_candidates=1000,
):
super().__init__()
if not hasattr(backbone, "out_channels"):
raise ValueError(
"backbone should contain an attribute out_channels "
"specifying the number of output channels (assumed to be the "
"same for all the levels)")
"same for all the levels)"
)
self.backbone = backbone
assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
......@@ -340,9 +348,7 @@ class RetinaNet(nn.Module):
if anchor_generator is None:
anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(
anchor_sizes, aspect_ratios
)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
self.anchor_generator = anchor_generator
if head is None:
......@@ -385,20 +391,21 @@ class RetinaNet(nn.Module):
# type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
matched_idxs = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
if targets_per_image['boxes'].numel() == 0:
matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64,
device=anchors_per_image.device))
if targets_per_image["boxes"].numel() == 0:
matched_idxs.append(
torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
)
continue
match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
matched_idxs.append(self.proposal_matcher(match_quality_matrix))
return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
def postprocess_detections(self, head_outputs, anchors, image_shapes):
# type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
class_logits = head_outputs['cls_logits']
box_regression = head_outputs['bbox_regression']
class_logits = head_outputs["cls_logits"]
box_regression = head_outputs["bbox_regression"]
num_images = len(image_shapes)
......@@ -413,8 +420,9 @@ class RetinaNet(nn.Module):
image_scores = []
image_labels = []
for box_regression_per_level, logits_per_level, anchors_per_level in \
zip(box_regression_per_image, logits_per_image, anchors_per_image):
for box_regression_per_level, logits_per_level, anchors_per_level in zip(
box_regression_per_image, logits_per_image, anchors_per_image
):
num_classes = logits_per_level.shape[-1]
# remove low scoring boxes
......@@ -428,11 +436,12 @@ class RetinaNet(nn.Module):
scores_per_level, idxs = scores_per_level.topk(num_topk)
topk_idxs = topk_idxs[idxs]
anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')
anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
labels_per_level = topk_idxs % num_classes
boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs],
anchors_per_level[anchor_idxs])
boxes_per_level = self.box_coder.decode_single(
box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
)
boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
image_boxes.append(boxes_per_level)
......@@ -445,13 +454,15 @@ class RetinaNet(nn.Module):
# non-maximum suppression
keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
keep = keep[:self.detections_per_img]
detections.append({
'boxes': image_boxes[keep],
'scores': image_scores[keep],
'labels': image_labels[keep],
})
keep = keep[: self.detections_per_img]
detections.append(
{
"boxes": image_boxes[keep],
"scores": image_scores[keep],
"labels": image_labels[keep],
}
)
return detections
......@@ -478,12 +489,11 @@ class RetinaNet(nn.Module):
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor"
"of shape [N, 4], got {:}.".format(
boxes.shape))
raise ValueError(
"Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape)
)
else:
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))
raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes)))
# get the original image sizes
original_image_sizes: List[Tuple[int, int]] = []
......@@ -505,14 +515,15 @@ class RetinaNet(nn.Module):
# print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}."
.format(degen_bb, target_idx))
raise ValueError(
"All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}.".format(degen_bb, target_idx)
)
# get the features from the backbone
features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
features = OrderedDict([("0", features)])
# TODO: Do we want a list or a dict?
features = list(features.values())
......@@ -536,7 +547,7 @@ class RetinaNet(nn.Module):
HW = 0
for v in num_anchors_per_level:
HW += v
HWA = head_outputs['cls_logits'].size(1)
HWA = head_outputs["cls_logits"].size(1)
A = HWA // HW
num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
......@@ -559,13 +570,13 @@ class RetinaNet(nn.Module):
model_urls = {
'retinanet_resnet50_fpn_coco':
'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth',
"retinanet_resnet50_fpn_coco": "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
}
def retinanet_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
def retinanet_resnet50_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
"""
Constructs a RetinaNet model with a ResNet-50-FPN backbone.
......@@ -613,18 +624,23 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
# skip P2 because it generates too many anchors (according to their paper)
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, returned_layers=[2, 3, 4],
extra_blocks=LastLevelP6P7(256, 256), trainable_layers=trainable_backbone_layers)
backbone = resnet_fpn_backbone(
"resnet50",
pretrained_backbone,
returned_layers=[2, 3, 4],
extra_blocks=LastLevelP6P7(256, 256),
trainable_layers=trainable_backbone_layers,
)
model = RetinaNet(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
progress=progress)
state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
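
# Usage sketch (illustrative, not part of this diff): constructing RetinaNet for a
# custom dataset. Extra keyword arguments such as score_thresh and
# detections_per_img are forwarded to the RetinaNet constructor above;
# num_classes=5 is an arbitrary example value.
model = retinanet_resnet50_fpn(
    pretrained=False,
    pretrained_backbone=False,
    num_classes=5,
    score_thresh=0.3,
    detections_per_img=100,
)
model.eval()
images = [torch.rand(3, 480, 640)]
with torch.no_grad():
    detections = model(images)  # list of dicts with "boxes", "scores", "labels"
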
import torch
import torchvision
from typing import Optional, List, Dict, Tuple
import torch
import torch.nn.functional as F
import torchvision
from torch import nn, Tensor
from torchvision.ops import boxes as box_ops
from torchvision.ops import roi_align
from . import _utils as det_utils
from typing import Optional, List, Dict, Tuple
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
......@@ -46,7 +43,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset],
beta=1 / 9,
reduction='sum',
reduction="sum",
)
box_loss = box_loss / labels.numel()
......@@ -95,7 +92,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
matched_idxs = matched_idxs.to(boxes)
rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
gt_masks = gt_masks[:, None].to(rois)
return roi_align(gt_masks, rois, (M, M), 1.)[:, 0]
return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
......@@ -113,8 +110,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
discretization_size = mask_logits.shape[-1]
labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
mask_targets = [
project_masks_on_boxes(m, p, i, discretization_size)
for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
]
labels = torch.cat(labels, dim=0)
......@@ -167,59 +163,72 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
return heatmaps, valid
def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
widths_i, heights_i, offset_x_i, offset_y_i):
def _onnx_heatmaps_to_keypoints(
maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)
width_correction = widths_i / roi_map_width
height_correction = heights_i / roi_map_height
roi_map = F.interpolate(
maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0]
maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
)[:, 0]
w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
x_int = (pos % w)
y_int = ((pos - x_int) // w)
x_int = pos % w
y_int = (pos - x_int) // w
x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * \
width_correction.to(dtype=torch.float32)
y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * \
height_correction.to(dtype=torch.float32)
x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
dtype=torch.float32
)
y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
dtype=torch.float32
)
xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
xy_preds_i = torch.stack([xy_preds_i_0.to(dtype=torch.float32),
xy_preds_i_1.to(dtype=torch.float32),
xy_preds_i_2.to(dtype=torch.float32)], 0)
xy_preds_i = torch.stack(
[
xy_preds_i_0.to(dtype=torch.float32),
xy_preds_i_1.to(dtype=torch.float32),
xy_preds_i_2.to(dtype=torch.float32),
],
0,
)
# TODO: simplify when indexing without rank will be supported by ONNX
base = num_keypoints * num_keypoints + num_keypoints + 1
ind = torch.arange(num_keypoints)
ind = ind.to(dtype=torch.int64) * base
end_scores_i = roi_map.index_select(1, y_int.to(dtype=torch.int64)) \
.index_select(2, x_int.to(dtype=torch.int64)).view(-1).index_select(0, ind.to(dtype=torch.int64))
end_scores_i = (
roi_map.index_select(1, y_int.to(dtype=torch.int64))
.index_select(2, x_int.to(dtype=torch.int64))
.view(-1)
.index_select(0, ind.to(dtype=torch.int64))
)
return xy_preds_i, end_scores_i
@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil,
widths, heights, offset_x, offset_y, num_keypoints):
def _onnx_heatmaps_to_keypoints_loop(
maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)
for i in range(int(rois.size(0))):
xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(maps, maps[i],
widths_ceil[i], heights_ceil[i],
widths[i], heights[i],
offset_x[i], offset_y[i])
xy_preds = torch.cat((xy_preds.to(dtype=torch.float32),
xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
end_scores = torch.cat((end_scores.to(dtype=torch.float32),
end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0)
xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
)
xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
end_scores = torch.cat(
(end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
)
return xy_preds, end_scores
......@@ -246,10 +255,17 @@ def heatmaps_to_keypoints(maps, rois):
num_keypoints = maps.shape[1]
if torchvision._is_tracing():
xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(maps, rois,
widths_ceil, heights_ceil, widths, heights,
offset_x, offset_y,
torch.scalar_tensor(num_keypoints, dtype=torch.int64))
xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
maps,
rois,
widths_ceil,
heights_ceil,
widths,
heights,
offset_x,
offset_y,
torch.scalar_tensor(num_keypoints, dtype=torch.int64),
)
return xy_preds.permute(0, 2, 1), end_scores
xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)
......@@ -260,13 +276,14 @@ def heatmaps_to_keypoints(maps, rois):
width_correction = widths[i] / roi_map_width
height_correction = heights[i] / roi_map_height
roi_map = F.interpolate(
maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0]
maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
)[:, 0]
# roi_map_probs = scores_to_probs(roi_map.copy())
w = roi_map.shape[2]
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
x_int = pos % w
y_int = torch.div(pos - x_int, w, rounding_mode='floor')
y_int = torch.div(pos - x_int, w, rounding_mode="floor")
# assert (roi_map_probs[k, y_int, x_int] ==
# roi_map_probs[k, :, :].max())
x = (x_int.float() + 0.5) * width_correction
......@@ -288,9 +305,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
valid = []
for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
kp = gt_kp_in_image[midx]
heatmaps_per_image, valid_per_image = keypoints_to_heatmap(
kp, proposals_per_image, discretization_size
)
heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
heatmaps.append(heatmaps_per_image.view(-1))
valid.append(valid_per_image.view(-1))
......@@ -327,10 +342,10 @@ def keypointrcnn_inference(x, boxes):
def _onnx_expand_boxes(boxes, scale):
# type: (Tensor, float) -> Tensor
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
w_half = w_half.to(dtype=torch.float32) * scale
h_half = h_half.to(dtype=torch.float32) * scale
......@@ -350,10 +365,10 @@ def expand_boxes(boxes, scale):
# type: (Tensor, float) -> Tensor
if torchvision._is_tracing():
return _onnx_expand_boxes(boxes, scale)
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
y_c = (boxes[:, 3] + boxes[:, 1]) * .5
w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
w_half *= scale
h_half *= scale
......@@ -395,7 +410,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
mask = mask.expand((1, 1, -1, -1))
# Resize mask
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
mask = mask[0][0]
im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
......@@ -404,9 +419,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
y_0 = max(box[1], 0)
y_1 = min(box[3] + 1, im_h)
im_mask[y_0:y_1, x_0:x_1] = mask[
(y_0 - box[1]):(y_1 - box[1]), (x_0 - box[0]):(x_1 - box[0])
]
im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
return im_mask
......@@ -414,8 +427,8 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
one = torch.ones(1, dtype=torch.int64)
zero = torch.zeros(1, dtype=torch.int64)
w = (box[2] - box[0] + one)
h = (box[3] - box[1] + one)
w = box[2] - box[0] + one
h = box[3] - box[1] + one
w = torch.max(torch.cat((w, one)))
h = torch.max(torch.cat((h, one)))
......@@ -423,7 +436,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
# Resize mask
mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
mask = mask[0][0]
x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
......@@ -431,23 +444,18 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))
unpaded_im_mask = mask[(y_0 - box[1]):(y_1 - box[1]),
(x_0 - box[0]):(x_1 - box[0])]
unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
# TODO : replace below with a dynamic padding when support is added in ONNX
# pad y
zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
concat_0 = torch.cat((zeros_y0,
unpaded_im_mask.to(dtype=torch.float32),
zeros_y1), 0)[0:im_h, :]
concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]
# pad x
zeros_x0 = torch.zeros(concat_0.size(0), x_0)
zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
im_mask = torch.cat((zeros_x0,
concat_0,
zeros_x1), 1)[:, :im_w]
im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
return im_mask
......@@ -468,13 +476,10 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
im_h, im_w = img_shape
if torchvision._is_tracing():
return _onnx_paste_masks_in_image_loop(masks, boxes,
torch.scalar_tensor(im_h, dtype=torch.int64),
torch.scalar_tensor(im_w, dtype=torch.int64))[:, None]
res = [
paste_mask_in_image(m[0], b, im_h, im_w)
for m, b in zip(masks, boxes)
]
return _onnx_paste_masks_in_image_loop(
masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
)[:, None]
res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
if len(res) > 0:
ret = torch.stack(res, dim=0)[:, None]
else:
......@@ -484,46 +489,44 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
class RoIHeads(nn.Module):
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
"box_coder": det_utils.BoxCoder,
"proposal_matcher": det_utils.Matcher,
"fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
}
def __init__(self,
box_roi_pool,
box_head,
box_predictor,
# Faster R-CNN training
fg_iou_thresh, bg_iou_thresh,
batch_size_per_image, positive_fraction,
bbox_reg_weights,
# Faster R-CNN inference
score_thresh,
nms_thresh,
detections_per_img,
# Mask
mask_roi_pool=None,
mask_head=None,
mask_predictor=None,
keypoint_roi_pool=None,
keypoint_head=None,
keypoint_predictor=None,
):
def __init__(
self,
box_roi_pool,
box_head,
box_predictor,
# Faster R-CNN training
fg_iou_thresh,
bg_iou_thresh,
batch_size_per_image,
positive_fraction,
bbox_reg_weights,
# Faster R-CNN inference
score_thresh,
nms_thresh,
detections_per_img,
# Mask
mask_roi_pool=None,
mask_head=None,
mask_predictor=None,
keypoint_roi_pool=None,
keypoint_head=None,
keypoint_predictor=None,
):
super(RoIHeads, self).__init__()
self.box_similarity = box_ops.box_iou
# assign ground-truth boxes for each proposal
self.proposal_matcher = det_utils.Matcher(
fg_iou_thresh,
bg_iou_thresh,
allow_low_quality_matches=False)
self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)
self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
batch_size_per_image,
positive_fraction)
self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
if bbox_reg_weights is None:
bbox_reg_weights = (10., 10., 5., 5.)
bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
self.box_coder = det_utils.BoxCoder(bbox_reg_weights)
self.box_roi_pool = box_roi_pool
......@@ -572,9 +575,7 @@ class RoIHeads(nn.Module):
clamped_matched_idxs_in_image = torch.zeros(
(proposals_in_image.shape[0],), dtype=torch.int64, device=device
)
labels_in_image = torch.zeros(
(proposals_in_image.shape[0],), dtype=torch.int64, device=device
)
labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
else:
# set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
......@@ -601,19 +602,14 @@ class RoIHeads(nn.Module):
# type: (List[Tensor]) -> List[Tensor]
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
sampled_inds = []
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
zip(sampled_pos_inds, sampled_neg_inds)
):
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
sampled_inds.append(img_sampled_inds)
return sampled_inds
def add_gt_proposals(self, proposals, gt_boxes):
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
proposals = [
torch.cat((proposal, gt_box))
for proposal, gt_box in zip(proposals, gt_boxes)
]
proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
return proposals
......@@ -625,10 +621,11 @@ class RoIHeads(nn.Module):
if self.has_mask():
assert all(["masks" in t for t in targets])
def select_training_samples(self,
proposals, # type: List[Tensor]
targets # type: Optional[List[Dict[str, Tensor]]]
):
def select_training_samples(
self,
proposals, # type: List[Tensor]
targets, # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
self.check_targets(targets)
assert targets is not None
......@@ -661,12 +658,13 @@ class RoIHeads(nn.Module):
regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
return proposals, matched_idxs, labels, regression_targets
def postprocess_detections(self,
class_logits, # type: Tensor
box_regression, # type: Tensor
proposals, # type: List[Tensor]
image_shapes # type: List[Tuple[int, int]]
):
def postprocess_detections(
self,
class_logits, # type: Tensor
box_regression, # type: Tensor
proposals, # type: List[Tensor]
image_shapes, # type: List[Tuple[int, int]]
):
# type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
device = class_logits.device
num_classes = class_logits.shape[-1]
......@@ -710,7 +708,7 @@ class RoIHeads(nn.Module):
# non-maximum suppression, independently done per class
keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
# keep only topk scoring predictions
keep = keep[:self.detections_per_img]
keep = keep[: self.detections_per_img]
boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
all_boxes.append(boxes)
......@@ -719,12 +717,13 @@ class RoIHeads(nn.Module):
return all_boxes, all_scores, all_labels
def forward(self,
features, # type: Dict[str, Tensor]
proposals, # type: List[Tensor]
image_shapes, # type: List[Tuple[int, int]]
targets=None # type: Optional[List[Dict[str, Tensor]]]
):
def forward(
self,
features, # type: Dict[str, Tensor]
proposals, # type: List[Tensor]
image_shapes, # type: List[Tuple[int, int]]
targets=None, # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
"""
Args:
......@@ -737,10 +736,10 @@ class RoIHeads(nn.Module):
for t in targets:
# TODO: https://github.com/pytorch/pytorch/issues/26731
floating_point_types = (torch.float, torch.double, torch.half)
assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type'
assert t["labels"].dtype == torch.int64, 'target labels must of int64 type'
assert t["boxes"].dtype in floating_point_types, "target boxes must of float type"
assert t["labels"].dtype == torch.int64, "target labels must of int64 type"
if self.has_keypoint():
assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type'
assert t["keypoints"].dtype == torch.float32, "target keypoints must of float type"
if self.training:
proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
......@@ -757,12 +756,8 @@ class RoIHeads(nn.Module):
losses = {}
if self.training:
assert labels is not None and regression_targets is not None
loss_classifier, loss_box_reg = fastrcnn_loss(
class_logits, box_regression, labels, regression_targets)
losses = {
"loss_classifier": loss_classifier,
"loss_box_reg": loss_box_reg
}
loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
else:
boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
num_images = len(boxes)
......@@ -805,12 +800,8 @@ class RoIHeads(nn.Module):
gt_masks = [t["masks"] for t in targets]
gt_labels = [t["labels"] for t in targets]
rcnn_loss_mask = maskrcnn_loss(
mask_logits, mask_proposals,
gt_masks, gt_labels, pos_matched_idxs)
loss_mask = {
"loss_mask": rcnn_loss_mask
}
rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
loss_mask = {"loss_mask": rcnn_loss_mask}
else:
labels = [r["labels"] for r in result]
masks_probs = maskrcnn_inference(mask_logits, labels)
......@@ -821,8 +812,11 @@ class RoIHeads(nn.Module):
# keep none checks in if conditional so torchscript will conditionally
# compile each branch
if self.keypoint_roi_pool is not None and self.keypoint_head is not None \
and self.keypoint_predictor is not None:
if (
self.keypoint_roi_pool is not None
and self.keypoint_head is not None
and self.keypoint_predictor is not None
):
keypoint_proposals = [p["boxes"] for p in result]
if self.training:
# during training, only focus on positive boxes
......@@ -848,11 +842,9 @@ class RoIHeads(nn.Module):
gt_keypoints = [t["keypoints"] for t in targets]
rcnn_loss_keypoint = keypointrcnn_loss(
keypoint_logits, keypoint_proposals,
gt_keypoints, pos_matched_idxs)
loss_keypoint = {
"loss_keypoint": rcnn_loss_keypoint
}
keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
)
loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
else:
assert keypoint_logits is not None
assert keypoint_proposals is not None
......
import torch
from torch.nn import functional as F
from torch import nn, Tensor
from typing import List, Optional, Dict, Tuple
import torch
import torchvision
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops
from . import _utils as det_utils
from .image_list import ImageList
from typing import List, Optional, Dict, Tuple
# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator
from .image_list import ImageList
@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
# type: (Tensor, int) -> Tuple[int, int]
from torch.onnx import operators
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
pre_nms_top_n = torch.min(torch.cat(
(torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
num_anchors), 0))
pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
return num_anchors, pre_nms_top_n
......@@ -37,13 +35,9 @@ class RPNHead(nn.Module):
def __init__(self, in_channels, num_anchors):
super(RPNHead, self).__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
self.bbox_pred = nn.Conv2d(
in_channels, num_anchors * 4, kernel_size=1, stride=1
)
self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
for layer in self.children():
torch.nn.init.normal_(layer.weight, std=0.01)
......@@ -76,21 +70,15 @@ def concat_box_prediction_layers(box_cls, box_regression):
# same format as the labels. Note that the labels are computed for
# all feature levels concatenated, so we keep the same representation
# for the objectness and the box_regression
for box_cls_per_level, box_regression_per_level in zip(
box_cls, box_regression
):
for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
N, AxC, H, W = box_cls_per_level.shape
Ax4 = box_regression_per_level.shape[1]
A = Ax4 // 4
C = AxC // A
box_cls_per_level = permute_and_flatten(
box_cls_per_level, N, A, C, H, W
)
box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
box_cls_flattened.append(box_cls_per_level)
box_regression_per_level = permute_and_flatten(
box_regression_per_level, N, A, 4, H, W
)
box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
box_regression_flattened.append(box_regression_per_level)
# concatenate on the first dimension (representing the feature levels), to
# take into account the way the labels were generated (with all feature maps
......@@ -125,22 +113,30 @@ class RegionProposalNetwork(torch.nn.Module):
nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
'pre_nms_top_n': Dict[str, int],
'post_nms_top_n': Dict[str, int],
"box_coder": det_utils.BoxCoder,
"proposal_matcher": det_utils.Matcher,
"fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
"pre_nms_top_n": Dict[str, int],
"post_nms_top_n": Dict[str, int],
}
def __init__(self,
anchor_generator,
head,
#
fg_iou_thresh, bg_iou_thresh,
batch_size_per_image, positive_fraction,
#
pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0):
def __init__(
self,
anchor_generator,
head,
#
fg_iou_thresh,
bg_iou_thresh,
batch_size_per_image,
positive_fraction,
#
pre_nms_top_n,
post_nms_top_n,
nms_thresh,
score_thresh=0.0,
):
super(RegionProposalNetwork, self).__init__()
self.anchor_generator = anchor_generator
self.head = head
......@@ -155,9 +151,7 @@ class RegionProposalNetwork(torch.nn.Module):
allow_low_quality_matches=True,
)
self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
batch_size_per_image, positive_fraction
)
self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
# used during testing
self._pre_nms_top_n = pre_nms_top_n
self._post_nms_top_n = post_nms_top_n
......@@ -167,13 +161,13 @@ class RegionProposalNetwork(torch.nn.Module):
def pre_nms_top_n(self):
if self.training:
return self._pre_nms_top_n['training']
return self._pre_nms_top_n['testing']
return self._pre_nms_top_n["training"]
return self._pre_nms_top_n["testing"]
def post_nms_top_n(self):
if self.training:
return self._post_nms_top_n['training']
return self._post_nms_top_n['testing']
return self._post_nms_top_n["training"]
return self._post_nms_top_n["testing"]
def assign_targets_to_anchors(self, anchors, targets):
# type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
......@@ -235,8 +229,7 @@ class RegionProposalNetwork(torch.nn.Module):
objectness = objectness.reshape(num_images, -1)
levels = [
torch.full((n,), idx, dtype=torch.int64, device=device)
for idx, n in enumerate(num_anchors_per_level)
torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
]
levels = torch.cat(levels, 0)
levels = levels.reshape(1, -1).expand_as(objectness)
......@@ -271,7 +264,7 @@ class RegionProposalNetwork(torch.nn.Module):
keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
# keep only topk scoring predictions
keep = keep[:self.post_nms_top_n()]
keep = keep[: self.post_nms_top_n()]
boxes, scores = boxes[keep], scores[keep]
final_boxes.append(boxes)
......@@ -303,24 +296,26 @@ class RegionProposalNetwork(torch.nn.Module):
labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)
box_loss = F.smooth_l1_loss(
pred_bbox_deltas[sampled_pos_inds],
regression_targets[sampled_pos_inds],
beta=1 / 9,
reduction='sum',
) / (sampled_inds.numel())
objectness_loss = F.binary_cross_entropy_with_logits(
objectness[sampled_inds], labels[sampled_inds]
box_loss = (
F.smooth_l1_loss(
pred_bbox_deltas[sampled_pos_inds],
regression_targets[sampled_pos_inds],
beta=1 / 9,
reduction="sum",
)
/ (sampled_inds.numel())
)
objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])
return objectness_loss, box_loss
def forward(self,
images, # type: ImageList
features, # type: Dict[str, Tensor]
targets=None # type: Optional[List[Dict[str, Tensor]]]
):
def forward(
self,
images, # type: ImageList
features, # type: Dict[str, Tensor]
targets=None, # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
"""
Args:
......@@ -346,8 +341,7 @@ class RegionProposalNetwork(torch.nn.Module):
num_images = len(anchors)
num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
objectness, pred_bbox_deltas = \
concat_box_prediction_layers(objectness, pred_bbox_deltas)
objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
# apply pred_bbox_deltas to anchors to obtain the decoded proposals
# note that we detach the deltas because Faster R-CNN do not backprop through
# the proposals
......@@ -361,7 +355,8 @@ class RegionProposalNetwork(torch.nn.Module):
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
loss_objectness, loss_rpn_box_reg = self.compute_loss(
objectness, pred_bbox_deltas, labels, regression_targets)
objectness, pred_bbox_deltas, labels, regression_targets
)
losses = {
"loss_objectness": loss_objectness,
"loss_rpn_box_reg": loss_rpn_box_reg,
......
import torch
import torch.nn.functional as F
import warnings
from collections import OrderedDict
from torch import nn, Tensor
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import boxes as box_ops
from .. import vgg
from . import _utils as det_utils
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .transform import GeneralizedRCNNTransform
from .. import vgg
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import boxes as box_ops
__all__ = ['SSD', 'ssd300_vgg16']
__all__ = ["SSD", "ssd300_vgg16"]
model_urls = {
'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth',
"ssd300_vgg16_coco": "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
}
backbone_urls = {
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
'vgg16_features': 'https://download.pytorch.org/models/vgg16_features-amdegroot.pth'
"vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth"
}
......@@ -43,8 +43,8 @@ class SSDHead(nn.Module):
def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
return {
'bbox_regression': self.regression_head(x),
'cls_logits': self.classification_head(x),
"bbox_regression": self.regression_head(x),
"cls_logits": self.classification_head(x),
}
......@@ -159,31 +159,38 @@ class SSD(nn.Module):
proposals used during the training of the classification head. It is used to estimate the negative to
positive ratio.
"""
__annotations__ = {
'box_coder': det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher,
"box_coder": det_utils.BoxCoder,
"proposal_matcher": det_utils.Matcher,
}
def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator,
size: Tuple[int, int], num_classes: int,
image_mean: Optional[List[float]] = None, image_std: Optional[List[float]] = None,
head: Optional[nn.Module] = None,
score_thresh: float = 0.01,
nms_thresh: float = 0.45,
detections_per_img: int = 200,
iou_thresh: float = 0.5,
topk_candidates: int = 400,
positive_fraction: float = 0.25):
def __init__(
self,
backbone: nn.Module,
anchor_generator: DefaultBoxGenerator,
size: Tuple[int, int],
num_classes: int,
image_mean: Optional[List[float]] = None,
image_std: Optional[List[float]] = None,
head: Optional[nn.Module] = None,
score_thresh: float = 0.01,
nms_thresh: float = 0.45,
detections_per_img: int = 200,
iou_thresh: float = 0.5,
topk_candidates: int = 400,
positive_fraction: float = 0.25,
):
super().__init__()
self.backbone = backbone
self.anchor_generator = anchor_generator
self.box_coder = det_utils.BoxCoder(weights=(10., 10., 5., 5.))
self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
if head is None:
if hasattr(backbone, 'out_channels'):
if hasattr(backbone, "out_channels"):
out_channels = backbone.out_channels
else:
out_channels = det_utils.retrieve_out_channels(backbone, size)
......@@ -200,8 +207,9 @@ class SSD(nn.Module):
image_mean = [0.485, 0.456, 0.406]
if image_std is None:
image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min(size), max(size), image_mean, image_std,
size_divisible=1, fixed_size=size)
self.transform = GeneralizedRCNNTransform(
min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size
)
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
......@@ -213,45 +221,58 @@ class SSD(nn.Module):
self._has_warned = False
@torch.jit.unused
def eager_outputs(self, losses: Dict[str, Tensor],
detections: List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
def eager_outputs(
self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training:
return losses
return detections
def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor], anchors: List[Tensor],
matched_idxs: List[Tensor]) -> Dict[str, Tensor]:
bbox_regression = head_outputs['bbox_regression']
cls_logits = head_outputs['cls_logits']
def compute_loss(
self,
targets: List[Dict[str, Tensor]],
head_outputs: Dict[str, Tensor],
anchors: List[Tensor],
matched_idxs: List[Tensor],
) -> Dict[str, Tensor]:
bbox_regression = head_outputs["bbox_regression"]
cls_logits = head_outputs["cls_logits"]
# Match original targets with default boxes
num_foreground = 0
bbox_loss = []
cls_targets = []
for (targets_per_image, bbox_regression_per_image, cls_logits_per_image, anchors_per_image,
matched_idxs_per_image) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs):
for (
targets_per_image,
bbox_regression_per_image,
cls_logits_per_image,
anchors_per_image,
matched_idxs_per_image,
) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs):
# produce the matching between boxes and targets
foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image]
num_foreground += foreground_matched_idxs_per_image.numel()
# Calculate regression loss
matched_gt_boxes_per_image = targets_per_image['boxes'][foreground_matched_idxs_per_image]
matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image]
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
bbox_loss.append(torch.nn.functional.smooth_l1_loss(
bbox_regression_per_image,
target_regression,
reduction='sum'
))
bbox_loss.append(
torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
)
# Estimate ground truth for class targets
gt_classes_target = torch.zeros((cls_logits_per_image.size(0), ), dtype=targets_per_image['labels'].dtype,
device=targets_per_image['labels'].device)
gt_classes_target[foreground_idxs_per_image] = \
targets_per_image['labels'][foreground_matched_idxs_per_image]
gt_classes_target = torch.zeros(
(cls_logits_per_image.size(0),),
dtype=targets_per_image["labels"].dtype,
device=targets_per_image["labels"].device,
)
gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][
foreground_matched_idxs_per_image
]
cls_targets.append(gt_classes_target)
bbox_loss = torch.stack(bbox_loss)
......@@ -259,30 +280,29 @@ class SSD(nn.Module):
# Calculate classification loss
num_classes = cls_logits.size(-1)
cls_loss = F.cross_entropy(
cls_logits.view(-1, num_classes),
cls_targets.view(-1),
reduction='none'
).view(cls_targets.size())
cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view(
cls_targets.size()
)
# Hard Negative Sampling
foreground_idxs = cls_targets > 0
num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True)
# num_negative[num_negative < self.neg_to_pos_ratio] = self.neg_to_pos_ratio
negative_loss = cls_loss.clone()
negative_loss[foreground_idxs] = -float('inf')  # use -inf to detect positive values that crept into the sample
negative_loss[foreground_idxs] = -float("inf")  # use -inf to detect positive values that crept into the sample
values, idx = negative_loss.sort(1, descending=True)
# background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values))
background_idxs = idx.sort(1)[1] < num_negative
N = max(1, num_foreground)
return {
'bbox_regression': bbox_loss.sum() / N,
'classification': (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N,
"bbox_regression": bbox_loss.sum() / N,
"classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N,
}
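# Illustrative toy example of the hard-negative selection used in compute_loss above:
# keep every positive anchor, then only the `neg_to_pos_ratio * num_positive` background
# anchors with the highest classification loss (ranked via the double-sort trick).
import torch

neg_to_pos_ratio = 3
cls_loss = torch.tensor([[0.9, 0.1, 0.8, 0.05, 0.4, 0.7]])  # one image, six anchors
cls_targets = torch.tensor([[1, 0, 0, 0, 0, 0]])            # anchor 0 is the only positive

foreground_idxs = cls_targets > 0
num_negative = neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True)  # -> 3 negatives to keep

negative_loss = cls_loss.clone()
negative_loss[foreground_idxs] = -float("inf")   # positives never compete as negatives
values, idx = negative_loss.sort(1, descending=True)
background_idxs = idx.sort(1)[1] < num_negative  # rank of each anchor in the sorted losses < k

print(background_idxs)  # tensor([[False, False,  True, False,  True,  True]]) -> anchors 2, 4, 5 kept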
def forward(self, images: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
def forward(
self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training and targets is None:
raise ValueError("In training mode, targets should be passed")
......@@ -292,12 +312,11 @@ class SSD(nn.Module):
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor "
"of shape [N, 4], got {:}.".format(
boxes.shape))
raise ValueError(
"Expected target boxes to be a tensor " "of shape [N, 4], got {:}.".format(boxes.shape)
)
else:
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))
raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes)))
# get the original image sizes
original_image_sizes: List[Tuple[int, int]] = []
......@@ -317,14 +336,15 @@ class SSD(nn.Module):
if degenerate_boxes.any():
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}."
.format(degen_bb, target_idx))
raise ValueError(
"All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}.".format(degen_bb, target_idx)
)
# get the features from the backbone
features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
features = OrderedDict([("0", features)])
features = list(features.values())
......@@ -341,12 +361,13 @@ class SSD(nn.Module):
matched_idxs = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
if targets_per_image['boxes'].numel() == 0:
matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64,
device=anchors_per_image.device))
if targets_per_image["boxes"].numel() == 0:
matched_idxs.append(
torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
)
continue
match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image)
match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
matched_idxs.append(self.proposal_matcher(match_quality_matrix))
losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs)
......@@ -361,10 +382,11 @@ class SSD(nn.Module):
return losses, detections
return self.eager_outputs(losses, detections)
def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor],
image_shapes: List[Tuple[int, int]]) -> List[Dict[str, Tensor]]:
bbox_regression = head_outputs['bbox_regression']
pred_scores = F.softmax(head_outputs['cls_logits'], dim=-1)
def postprocess_detections(
self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]]
) -> List[Dict[str, Tensor]]:
bbox_regression = head_outputs["bbox_regression"]
pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1)
num_classes = pred_scores.size(-1)
device = pred_scores.device
......@@ -400,13 +422,15 @@ class SSD(nn.Module):
# non-maximum suppression
keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
keep = keep[:self.detections_per_img]
detections.append({
'boxes': image_boxes[keep],
'scores': image_scores[keep],
'labels': image_labels[keep],
})
keep = keep[: self.detections_per_img]
detections.append(
{
"boxes": image_boxes[keep],
"scores": image_scores[keep],
"labels": image_labels[keep],
}
)
return detections
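# Illustrative check of the batched_nms call above: boxes only suppress each other within the
# same label, and the returned indices are ordered by decreasing score before the top-k cut.
import torch
from torchvision.ops import batched_nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0],    # overlaps box 0 (IoU ~ 0.68)
                      [0.0, 0.0, 10.0, 10.0]])   # identical to box 0 but a different class
scores = torch.tensor([0.9, 0.8, 0.7])
labels = torch.tensor([1, 1, 2])

keep = batched_nms(boxes, scores, labels, iou_threshold=0.5)
print(keep)        # tensor([0, 2]) -- box 1 suppressed by box 0, box 2 survives (other class)
keep = keep[:200]  # afterwards truncated to detections_per_img, as in postprocess_detections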
......@@ -423,45 +447,47 @@ class SSDFeatureExtractorVGG(nn.Module):
self.scale_weight = nn.Parameter(torch.ones(512) * 20)
# Multiple Feature maps - page 4, Fig 2 of SSD paper
self.features = nn.Sequential(
*backbone[:maxpool4_pos] # until conv4_3
)
self.features = nn.Sequential(*backbone[:maxpool4_pos]) # until conv4_3
# SSD300 case - page 4, Fig 2 of SSD paper
extra = nn.ModuleList([
nn.Sequential(
nn.Conv2d(1024, 256, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(512, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3), # conv10_2
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3), # conv11_2
nn.ReLU(inplace=True),
)
])
extra = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(1024, 256, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(512, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3), # conv10_2
nn.ReLU(inplace=True),
),
nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=3), # conv11_2
nn.ReLU(inplace=True),
),
]
)
if highres:
# Additional layers for the SSD512 case. See page 11, footnote 5.
extra.append(nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=4), # conv12_2
nn.ReLU(inplace=True),
))
extra.append(
nn.Sequential(
nn.Conv2d(256, 128, kernel_size=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, kernel_size=4), # conv12_2
nn.ReLU(inplace=True),
)
)
_xavier_init(extra)
fc = nn.Sequential(
......@@ -469,13 +495,16 @@ class SSDFeatureExtractorVGG(nn.Module):
nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7
nn.ReLU(inplace=True)
nn.ReLU(inplace=True),
)
_xavier_init(fc)
extra.insert(0, nn.Sequential(
*backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5
fc,
))
extra.insert(
0,
nn.Sequential(
*backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5
fc,
),
)
self.extra = extra
def forward(self, x: Tensor) -> Dict[str, Tensor]:
......@@ -495,7 +524,7 @@ class SSDFeatureExtractorVGG(nn.Module):
def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
if backbone_name in backbone_urls:
# Use custom backbones more appropriate for SSD
arch = backbone_name.split('_')[0]
arch = backbone_name.split("_")[0]
backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features
if pretrained:
state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress)
......@@ -519,8 +548,14 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained
return SSDFeatureExtractorVGG(backbone, highres)
def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any):
def ssd300_vgg16(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 91,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
):
"""Constructs an SSD model with input size 300x300 and a VGG16 backbone.
Reference: `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.
......@@ -569,16 +604,19 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
warnings.warn("The size of the model is already fixed; ignoring the argument.")
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5)
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5
)
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers)
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]],
scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
steps=[8, 16, 32, 64, 100, 300])
anchor_generator = DefaultBoxGenerator(
[[2], [2, 3], [2, 3], [2, 3], [2], [2]],
scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
steps=[8, 16, 32, 64, 100, 300],
)
defaults = {
# Rescale the input in a way compatible to the backbone
......@@ -588,7 +626,7 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
kwargs = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if pretrained:
weights_name = 'ssd300_vgg16_coco'
weights_name = "ssd300_vgg16_coco"
if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name))
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
......
import torch
import warnings
from collections import OrderedDict
from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import nn, Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
from .. import mobilenet
from . import _utils as det_utils
from .ssd import SSD, SSDScoringHead
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .. import mobilenet
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
from .ssd import SSD, SSDScoringHead
__all__ = ['ssdlite320_mobilenet_v3_large']
__all__ = ["ssdlite320_mobilenet_v3_large"]
model_urls = {
'ssdlite320_mobilenet_v3_large_coco':
'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth'
"ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth"
}
# Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper
def _prediction_block(in_channels: int, out_channels: int, kernel_size: int,
norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
def _prediction_block(
in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module]
) -> nn.Sequential:
return nn.Sequential(
# 3x3 depthwise with stride 1 and padding 1
ConvNormActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels,
norm_layer=norm_layer, activation_layer=nn.ReLU6),
ConvNormActivation(
in_channels,
in_channels,
kernel_size=kernel_size,
groups=in_channels,
norm_layer=norm_layer,
activation_layer=nn.ReLU6,
),
# 1x1 projection to output channels
nn.Conv2d(in_channels, out_channels, 1)
nn.Conv2d(in_channels, out_channels, 1),
)
......@@ -41,16 +46,23 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
intermediate_channels = out_channels // 2
return nn.Sequential(
# 1x1 projection to half output channels
ConvNormActivation(in_channels, intermediate_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(
in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
),
# 3x3 depthwise with stride 2 and padding 1
ConvNormActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2,
groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(
intermediate_channels,
intermediate_channels,
kernel_size=3,
stride=2,
groups=intermediate_channels,
norm_layer=norm_layer,
activation_layer=activation,
),
# 1x1 projection to output channels
ConvNormActivation(intermediate_channels, out_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation),
ConvNormActivation(
intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
),
)
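# Quick shape check (illustrative values) for the two SSDLite building blocks defined above,
# reusing _prediction_block and _extra_block from this module:
import torch
from functools import partial
from torch import nn

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
x = torch.rand(1, 672, 10, 10)
print(_prediction_block(672, 6 * 4, 3, norm_layer)(x).shape)  # depthwise 3x3 keeps 10x10 -> [1, 24, 10, 10]
print(_extra_block(672, 512, norm_layer)(x).shape)            # stride-2 depthwise halves H, W -> [1, 512, 5, 5]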
......@@ -63,22 +75,24 @@ def _normal_init(conv: nn.Module):
class SSDLiteHead(nn.Module):
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
norm_layer: Callable[..., nn.Module]):
def __init__(
self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
):
super().__init__()
self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)
def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
return {
'bbox_regression': self.regression_head(x),
'cls_logits': self.classification_head(x),
"bbox_regression": self.regression_head(x),
"cls_logits": self.classification_head(x),
}
class SSDLiteClassificationHead(SSDScoringHead):
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int,
norm_layer: Callable[..., nn.Module]):
def __init__(
self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
):
cls_logits = nn.ModuleList()
for channels, anchors in zip(in_channels, num_anchors):
cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
......@@ -96,24 +110,33 @@ class SSDLiteRegressionHead(SSDScoringHead):
class SSDLiteFeatureExtractorMobileNet(nn.Module):
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], width_mult: float = 1.0,
min_depth: int = 16, **kwargs: Any):
def __init__(
self,
backbone: nn.Module,
c4_pos: int,
norm_layer: Callable[..., nn.Module],
width_mult: float = 1.0,
min_depth: int = 16,
**kwargs: Any,
):
super().__init__()
assert not backbone[c4_pos].use_res_connect
self.features = nn.Sequential(
# As described in section 6.3 of MobileNetV3 paper
nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer
nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]), # from C4 depthwise until end
nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]), # from C4 depthwise until end
)
get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731
extra = nn.ModuleList([
_extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
_extra_block(get_depth(512), get_depth(256), norm_layer),
_extra_block(get_depth(256), get_depth(256), norm_layer),
_extra_block(get_depth(256), get_depth(128), norm_layer),
])
extra = nn.ModuleList(
[
_extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
_extra_block(get_depth(512), get_depth(256), norm_layer),
_extra_block(get_depth(256), get_depth(256), norm_layer),
_extra_block(get_depth(256), get_depth(128), norm_layer),
]
)
_normal_init(extra)
self.extra = extra
......@@ -132,10 +155,17 @@ class SSDLiteFeatureExtractorMobileNet(nn.Module):
return OrderedDict([(str(i), v) for i, v in enumerate(output)])
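# Illustrative look at the multi-scale feature maps this extractor returns, built through the
# public constructor defined below (random weights; exact channel counts are indicative only):
import torch
from torchvision.models.detection import ssdlite320_mobilenet_v3_large

model = ssdlite320_mobilenet_v3_large(pretrained=False, pretrained_backbone=False).eval()
with torch.no_grad():
    features = model.backbone(torch.rand(1, 3, 320, 320))
print([(k, tuple(v.shape)) for k, v in features.items()])
# e.g. [('0', (1, 672, 20, 20)), ('1', (1, 480, 10, 10)), ('2', (1, 512, 5, 5)), ...]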
def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int,
norm_layer: Callable[..., nn.Module], **kwargs: Any):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress,
norm_layer=norm_layer, **kwargs).features
def _mobilenet_extractor(
backbone_name: str,
progress: bool,
pretrained: bool,
trainable_layers: int,
norm_layer: Callable[..., nn.Module],
**kwargs: Any,
):
backbone = mobilenet.__dict__[backbone_name](
pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs
).features
if not pretrained:
# Change the default initialization scheme if not pretrained
_normal_init(backbone)
......@@ -156,10 +186,15 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)
def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91,
pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any):
def ssdlite320_mobilenet_v3_large(
pretrained: bool = False,
progress: bool = True,
num_classes: int = 91,
pretrained_backbone: bool = False,
trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
):
"""Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at
`"Searching for MobileNetV3"
<https://arxiv.org/abs/1905.02244>`_ and
......@@ -188,7 +223,8 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
warnings.warn("The size of the model is already fixed; ignoring the argument.")
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6)
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6
)
if pretrained:
pretrained_backbone = False
......@@ -199,8 +235,15 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers,
norm_layer, reduced_tail=reduce_tail, **kwargs)
backbone = _mobilenet_extractor(
"mobilenet_v3_large",
progress,
pretrained_backbone,
trainable_backbone_layers,
norm_layer,
reduced_tail=reduce_tail,
**kwargs,
)
size = (320, 320)
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
......@@ -219,11 +262,17 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
"image_std": [0.5, 0.5, 0.5],
}
kwargs = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, size, num_classes,
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs)
model = SSD(
backbone,
anchor_generator,
size,
num_classes,
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
**kwargs,
)
if pretrained:
weights_name = 'ssdlite320_mobilenet_v3_large_coco'
weights_name = "ssdlite320_mobilenet_v3_large_coco"
if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name))
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
......
import math
from typing import List, Tuple, Dict, Optional
import torch
import torchvision
from torch import nn, Tensor
from typing import List, Tuple, Dict, Optional
from .image_list import ImageList
from .roi_heads import paste_masks_in_image
......@@ -12,6 +12,7 @@ from .roi_heads import paste_masks_in_image
@torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor:
from torch.onnx import operators
return operators.shape_as_tensor(image)[-2:]
......@@ -21,10 +22,13 @@ def _fake_cast_onnx(v: Tensor) -> float:
return v
def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size: float,
target: Optional[Dict[str, Tensor]] = None,
fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
def _resize_image_and_masks(
image: Tensor,
self_min_size: float,
self_max_size: float,
target: Optional[Dict[str, Tensor]] = None,
fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torchvision._is_tracing():
im_shape = _get_shape_onnx(image)
else:
......@@ -46,16 +50,23 @@ def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size:
scale_factor = scale.item()
recompute_scale_factor = True
image = torch.nn.functional.interpolate(image[None], size=size, scale_factor=scale_factor, mode='bilinear',
recompute_scale_factor=recompute_scale_factor, align_corners=False)[0]
image = torch.nn.functional.interpolate(
image[None],
size=size,
scale_factor=scale_factor,
mode="bilinear",
recompute_scale_factor=recompute_scale_factor,
align_corners=False,
)[0]
if target is None:
return image, target
if "masks" in target:
mask = target["masks"]
mask = torch.nn.functional.interpolate(mask[:, None].float(), size=size, scale_factor=scale_factor,
recompute_scale_factor=recompute_scale_factor)[:, 0].byte()
mask = torch.nn.functional.interpolate(
mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
)[:, 0].byte()
target["masks"] = mask
return image, target
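# Worked example of the resizing rule implemented above (assuming the usual torchvision behaviour):
# the short side is scaled towards self_min_size, but capped so the long side stays <= self_max_size.
import torch

h, w = 400, 600
self_min_size, self_max_size = 800.0, 1333.0
scale = min(self_min_size / min(h, w), self_max_size / max(h, w))  # min(2.0, ~2.22) -> 2.0
new_h, new_w = int(round(h * scale)), int(round(w * scale))        # (800, 1200), long side still < 1333

image = torch.rand(3, h, w)
resized = torch.nn.functional.interpolate(
    image[None], size=(new_h, new_w), mode="bilinear", align_corners=False
)[0]
print(resized.shape)  # torch.Size([3, 800, 1200])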
......@@ -72,8 +83,15 @@ class GeneralizedRCNNTransform(nn.Module):
It returns an ImageList for the inputs, and a List[Dict[Tensor]] for the targets
"""
def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float],
size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None):
def __init__(
self,
min_size: int,
max_size: int,
image_mean: List[float],
image_std: List[float],
size_divisible: int = 32,
fixed_size: Optional[Tuple[int, int]] = None,
):
super(GeneralizedRCNNTransform, self).__init__()
if not isinstance(min_size, (list, tuple)):
min_size = (min_size,)
......@@ -84,10 +102,9 @@ class GeneralizedRCNNTransform(nn.Module):
self.size_divisible = size_divisible
self.fixed_size = fixed_size
def forward(self,
images: List[Tensor],
targets: Optional[List[Dict[str, Tensor]]] = None
) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
def forward(
self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
images = [img for img in images]
if targets is not None:
# make a copy of targets to avoid modifying it in-place
......@@ -106,8 +123,9 @@ class GeneralizedRCNNTransform(nn.Module):
target_index = targets[i] if targets is not None else None
if image.dim() != 3:
raise ValueError("images is expected to be a list of 3d tensors "
"of shape [C, H, W], got {}".format(image.shape))
raise ValueError(
"images is expected to be a list of 3d tensors " "of shape [C, H, W], got {}".format(image.shape)
)
image = self.normalize(image)
image, target_index = self.resize(image, target_index)
images[i] = image
......@@ -141,13 +159,14 @@ class GeneralizedRCNNTransform(nn.Module):
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
is fixed.
"""
index = int(torch.empty(1).uniform_(0., float(len(k))).item())
index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
return k[index]
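# Equivalent of random.choice on self.min_size, written so it also works under TorchScript
# (illustrative candidate sizes for multi-scale training):
import torch

min_size = (640, 672, 704, 736, 768, 800)
index = int(torch.empty(1).uniform_(0.0, float(len(min_size))).item())
print(min_size[index])  # one of the candidate sizes, picked uniformly at random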
def resize(self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
def resize(
self,
image: Tensor,
target: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
h, w = image.shape[-2:]
if self.training:
size = float(self.torch_choice(self.min_size))
......@@ -220,11 +239,12 @@ class GeneralizedRCNNTransform(nn.Module):
return batched_imgs
def postprocess(self,
result: List[Dict[str, Tensor]],
image_shapes: List[Tuple[int, int]],
original_image_sizes: List[Tuple[int, int]]
) -> List[Dict[str, Tensor]]:
def postprocess(
self,
result: List[Dict[str, Tensor]],
image_shapes: List[Tuple[int, int]],
original_image_sizes: List[Tuple[int, int]],
) -> List[Dict[str, Tensor]]:
if self.training:
return result
for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
......@@ -242,19 +262,20 @@ class GeneralizedRCNNTransform(nn.Module):
return result
def __repr__(self) -> str:
format_string = self.__class__.__name__ + '('
_indent = '\n '
format_string = self.__class__.__name__ + "("
_indent = "\n "
format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size,
self.max_size)
format_string += '\n)'
format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(
_indent, self.min_size, self.max_size
)
format_string += "\n)"
return format_string
def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
ratios = [
torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
torch.tensor(s, dtype=torch.float32, device=keypoints.device)
/ torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
for s, s_orig in zip(new_size, original_size)
]
ratio_h, ratio_w = ratios
......@@ -271,8 +292,8 @@ def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List
def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
ratios = [
torch.tensor(s, dtype=torch.float32, device=boxes.device) /
torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
torch.tensor(s, dtype=torch.float32, device=boxes.device)
/ torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
for s, s_orig in zip(new_size, original_size)
]
ratio_height, ratio_width = ratios
......
import copy
import math
import torch
from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, List, Optional, Sequence
import torch
from torch import nn, Tensor
from torchvision.ops import StochasticDepth
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation
from ._utils import _make_divisible
from torchvision.ops import StochasticDepth
__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3",
"efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"]
__all__ = [
"EfficientNet",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
"efficientnet_b3",
"efficientnet_b4",
"efficientnet_b5",
"efficientnet_b6",
"efficientnet_b7",
]
model_urls = {
......@@ -32,10 +41,17 @@ model_urls = {
class MBConvConfig:
# Stores information listed at Table 1 of the EfficientNet paper
def __init__(self,
expand_ratio: float, kernel: int, stride: int,
input_channels: int, out_channels: int, num_layers: int,
width_mult: float, depth_mult: float) -> None:
def __init__(
self,
expand_ratio: float,
kernel: int,
stride: int,
input_channels: int,
out_channels: int,
num_layers: int,
width_mult: float,
depth_mult: float,
) -> None:
self.expand_ratio = expand_ratio
self.kernel = kernel
self.stride = stride
......@@ -44,14 +60,14 @@ class MBConvConfig:
self.num_layers = self.adjust_depth(num_layers, depth_mult)
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'expand_ratio={expand_ratio}'
s += ', kernel={kernel}'
s += ', stride={stride}'
s += ', input_channels={input_channels}'
s += ', out_channels={out_channels}'
s += ', num_layers={num_layers}'
s += ')'
s = self.__class__.__name__ + "("
s += "expand_ratio={expand_ratio}"
s += ", kernel={kernel}"
s += ", stride={stride}"
s += ", input_channels={input_channels}"
s += ", out_channels={out_channels}"
s += ", num_layers={num_layers}"
s += ")"
return s.format(**self.__dict__)
@staticmethod
......@@ -64,12 +80,17 @@ class MBConvConfig:
class MBConv(nn.Module):
def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None:
def __init__(
self,
cnf: MBConvConfig,
stochastic_depth_prob: float,
norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = SqueezeExcitation,
) -> None:
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError('illegal stride value')
raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
......@@ -79,21 +100,39 @@ class MBConv(nn.Module):
# expand
expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
if expanded_channels != cnf.input_channels:
layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(
ConvNormActivation(
cnf.input_channels,
expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# depthwise
layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel,
stride=cnf.stride, groups=expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(
ConvNormActivation(
expanded_channels,
expanded_channels,
kernel_size=cnf.kernel,
stride=cnf.stride,
groups=expanded_channels,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# squeeze and excitation
squeeze_channels = max(1, cnf.input_channels // 4)
layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
# project
layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=None))
layers.append(
ConvNormActivation(
expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
self.block = nn.Sequential(*layers)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
......@@ -109,14 +148,14 @@ class MBConv(nn.Module):
class EfficientNet(nn.Module):
def __init__(
self,
inverted_residual_setting: List[MBConvConfig],
dropout: float,
stochastic_depth_prob: float = 0.2,
num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any
self,
inverted_residual_setting: List[MBConvConfig],
dropout: float,
stochastic_depth_prob: float = 0.2,
num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
) -> None:
"""
EfficientNet main class
......@@ -133,8 +172,10 @@ class EfficientNet(nn.Module):
if not inverted_residual_setting:
raise ValueError("The inverted_residual_setting should not be empty")
elif not (isinstance(inverted_residual_setting, Sequence) and
all([isinstance(s, MBConvConfig) for s in inverted_residual_setting])):
elif not (
isinstance(inverted_residual_setting, Sequence)
and all([isinstance(s, MBConvConfig) for s in inverted_residual_setting])
):
raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
if block is None:
......@@ -147,8 +188,11 @@ class EfficientNet(nn.Module):
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.SiLU))
layers.append(
ConvNormActivation(
3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
)
)
# building inverted residual blocks
total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting])
......@@ -175,8 +219,15 @@ class EfficientNet(nn.Module):
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 4 * lastconv_input_channels
layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.SiLU))
layers.append(
ConvNormActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.SiLU,
)
)
self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
......@@ -187,7 +238,7 @@ class EfficientNet(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
......@@ -232,7 +283,7 @@ def _efficientnet_model(
dropout: float,
pretrained: bool,
progress: bool,
**kwargs: Any
**kwargs: Any,
) -> EfficientNet:
model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
if pretrained:
......@@ -318,8 +369,15 @@ def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: A
progress (bool): If True, displays a progress bar of the download to stderr
"""
inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs)
return _efficientnet_model("efficientnet_b5", inverted_residual_setting, 0.4, pretrained, progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)
return _efficientnet_model(
"efficientnet_b5",
inverted_residual_setting,
0.4,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
......@@ -332,8 +390,15 @@ def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: A
progress (bool): If True, displays a progress bar of the download to stderr
"""
inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs)
return _efficientnet_model("efficientnet_b6", inverted_residual_setting, 0.5, pretrained, progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)
return _efficientnet_model(
"efficientnet_b6",
inverted_residual_setting,
0.5,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
......@@ -346,5 +411,12 @@ def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: A
progress (bool): If True, displays a progress bar of the download to stderr
"""
inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs)
return _efficientnet_model("efficientnet_b7", inverted_residual_setting, 0.5, pretrained, progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs)
return _efficientnet_model(
"efficientnet_b7",
inverted_residual_setting,
0.5,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
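# Illustrative sketch of how the width/depth multipliers above scale a stage, assuming the usual
# EfficientNet rules: channels are multiplied by width_mult and rounded to a multiple of 8
# (_make_divisible-style), layer counts are multiplied by depth_mult and rounded up.
import math

def scale_channels(channels: int, width_mult: float, divisor: int = 8) -> int:
    v = channels * width_mult
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # never round down by more than ~10%
        new_v += divisor
    return new_v

def scale_depth(num_layers: int, depth_mult: float) -> int:
    return int(math.ceil(num_layers * depth_mult))

# efficientnet_b7 above uses width_mult=2.0 and depth_mult=3.1, so a 40-channel, 2-layer stage becomes:
print(scale_channels(40, 2.0), scale_depth(2, 3.1))  # 80 7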
from typing import Dict, Callable, List, Union, Optional, Tuple
from collections import OrderedDict
import warnings
import re
import warnings
from collections import OrderedDict
from copy import deepcopy
from itertools import chain
from typing import Dict, Callable, List, Union, Optional, Tuple
import torch
from torch import nn
from torch import fx
from torch import nn
from torch.fx.graph_module import _copy_attr
__all__ = ['create_feature_extractor', 'get_graph_node_names']
__all__ = ["create_feature_extractor", "get_graph_node_names"]
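# Typical usage of the two helpers exported above (node names and shapes depend on the model):
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

model = torchvision.models.resnet18()
train_nodes, eval_nodes = get_graph_node_names(model)
print(train_nodes[:5])  # qualified node names in execution order, e.g. ['x', 'conv1', 'bn1', ...]

extractor = create_feature_extractor(model, return_nodes={"layer1": "feat1", "layer4": "feat4"})
out = extractor(torch.rand(1, 3, 224, 224))
print({name: tuple(feat.shape) for name, feat in out.items()})  # {'feat1': (1, 64, 56, 56), 'feat4': (1, 512, 7, 7)}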
class LeafModuleAwareTracer(fx.Tracer):
......@@ -20,10 +20,11 @@ class LeafModuleAwareTracer(fx.Tracer):
modules that are not to be traced through. The resulting graph ends up
having single nodes referencing calls to the leaf modules' forward methods.
"""
def __init__(self, *args, **kwargs):
self.leaf_modules = {}
if 'leaf_modules' in kwargs:
leaf_modules = kwargs.pop('leaf_modules')
if "leaf_modules" in kwargs:
leaf_modules = kwargs.pop("leaf_modules")
self.leaf_modules = leaf_modules
super(LeafModuleAwareTracer, self).__init__(*args, **kwargs)
......@@ -51,10 +52,11 @@ class NodePathTracer(LeafModuleAwareTracer):
- When a duplicate node name is encountered, a suffix of the form
_{int} is added. The counter starts from 1.
"""
def __init__(self, *args, **kwargs):
super(NodePathTracer, self).__init__(*args, **kwargs)
# Track the qualified name of the Node being traced
self.current_module_qualname = ''
self.current_module_qualname = ""
# A map from FX Node to the qualified name
# NOTE: This is loosely like the "qualified name" mentioned in the
# torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted
......@@ -78,32 +80,31 @@ class NodePathTracer(LeafModuleAwareTracer):
if not self.is_leaf_module(m, module_qualname):
out = forward(*args, **kwargs)
return out
return self.create_proxy('call_module', module_qualname, args, kwargs)
return self.create_proxy("call_module", module_qualname, args, kwargs)
finally:
self.current_module_qualname = old_qualname
def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs,
name=None, type_expr=None, *_) -> fx.proxy.Proxy:
def create_proxy(
self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None, *_
) -> fx.proxy.Proxy:
"""
Override of `Tracer.create_proxy`. This override intercepts the recording
of every operation and stores away the current traced module's qualified
name in `node_to_qualname`
"""
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
self.node_to_qualname[proxy.node] = self._get_node_qualname(
self.current_module_qualname, proxy.node)
self.node_to_qualname[proxy.node] = self._get_node_qualname(self.current_module_qualname, proxy.node)
return proxy
def _get_node_qualname(
self, module_qualname: str, node: fx.node.Node) -> str:
def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str:
node_qualname = module_qualname
if node.op != 'call_module':
if node.op != "call_module":
# In this case module_qualname from torch.fx doesn't go all the
# way to the leaf function/op so we need to append it
if len(node_qualname) > 0:
# Only append '.' if we are deeper than the top level module
node_qualname += '.'
node_qualname += "."
node_qualname += str(node)
# Now we need to add an _{index} postfix on any repeated node names
......@@ -111,23 +112,22 @@ class NodePathTracer(LeafModuleAwareTracer):
# But for anything else, torch.fx already has a globally scoped
# _{index} postfix. But we want it locally (relative to direct parent)
# scoped. So first we need to undo the torch.fx postfix
if re.match(r'.+_[0-9]+$', node_qualname) is not None:
node_qualname = node_qualname.rsplit('_', 1)[0]
if re.match(r".+_[0-9]+$", node_qualname) is not None:
node_qualname = node_qualname.rsplit("_", 1)[0]
# ... and now we add on our own postfix
for existing_qualname in reversed(self.node_to_qualname.values()):
# Check to see if existing_qualname is of the form
# {node_qualname} or {node_qualname}_{int}
if re.match(rf'{node_qualname}(_[0-9]+)?$',
existing_qualname) is not None:
postfix = existing_qualname.replace(node_qualname, '')
if re.match(rf"{node_qualname}(_[0-9]+)?$", existing_qualname) is not None:
postfix = existing_qualname.replace(node_qualname, "")
if len(postfix):
# existing_qualname is of the form {node_qualname}_{int}
next_index = int(postfix[1:]) + 1
else:
# existing_qualname is of the form {node_qualname}
next_index = 1
node_qualname += f'_{next_index}'
node_qualname += f"_{next_index}"
break
return node_qualname
......@@ -141,8 +141,7 @@ def _is_subseq(x, y):
return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
def _warn_graph_differences(
train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
"""
Utility function for warning the user if there are differences between
the train graph nodes and the eval graph nodes.
......@@ -150,29 +149,32 @@ def _warn_graph_differences(
train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values())
if len(train_nodes) == len(eval_nodes) and all(
t == e for t, e in zip(train_nodes, eval_nodes)):
if len(train_nodes) == len(eval_nodes) and all(t == e for t, e in zip(train_nodes, eval_nodes)):
return
suggestion_msg = (
"When choosing nodes for feature extraction, you may need to specify "
"output nodes for train and eval mode separately.")
"output nodes for train and eval mode separately."
)
if _is_subseq(train_nodes, eval_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in eval mode "
"are a subsequence of those obtained in train mode. ")
msg = (
"NOTE: The nodes obtained by tracing the model in eval mode "
"are a subsequence of those obtained in train mode. "
)
elif _is_subseq(eval_nodes, train_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in train mode "
"are a subsequence of those obtained in eval mode. ")
msg = (
"NOTE: The nodes obtained by tracing the model in train mode "
"are a subsequence of those obtained in eval mode. "
)
else:
msg = ("The nodes obtained by tracing the model in train mode "
"are different to those obtained in eval mode. ")
msg = "The nodes obtained by tracing the model in train mode " "are different to those obtained in eval mode. "
warnings.warn(msg + suggestion_msg)
def get_graph_node_names(
model: nn.Module, tracer_kwargs: Dict = {},
suppress_diff_warning: bool = False) -> Tuple[List[str], List[str]]:
model: nn.Module, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False
) -> Tuple[List[str], List[str]]:
"""
Dev utility to return node names in order of execution. See note on node
names under :func:`create_feature_extractor`. Useful for seeing which node
......@@ -230,11 +232,10 @@ class DualGraphModule(fx.GraphModule):
- Copies submodules according to the nodes of both train and eval graphs.
- Calling train(mode) switches between train graph and eval graph.
"""
def __init__(self,
root: torch.nn.Module,
train_graph: fx.Graph,
eval_graph: fx.Graph,
class_name: str = 'GraphModule'):
def __init__(
self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
):
"""
Args:
root (nn.Module): module from which the copied module hierarchy is
......@@ -252,7 +253,7 @@ class DualGraphModule(fx.GraphModule):
# Copy all get_attr and call_module ops (indicated by BOTH train and
# eval graphs)
for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
if node.op in ['get_attr', 'call_module']:
if node.op in ["get_attr", "call_module"]:
assert isinstance(node.target, str)
_copy_attr(root, self, node.target)
......@@ -266,10 +267,11 @@ class DualGraphModule(fx.GraphModule):
# Locally defined Tracers are not pickleable. This is needed because torch.package will
# serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
# to re-create the Graph during deserialization.
assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \
"Train mode and eval mode should use the same tracer class"
assert (
self.eval_graph._tracer_cls == self.train_graph._tracer_cls
), "Train mode and eval mode should use the same tracer class"
self._tracer_cls = None
if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
self._tracer_cls = self.graph._tracer_cls
def train(self, mode=True):
......@@ -288,12 +290,13 @@ class DualGraphModule(fx.GraphModule):
def create_feature_extractor(
model: nn.Module,
return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Dict = {},
suppress_diff_warning: bool = False) -> fx.GraphModule:
model: nn.Module,
return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Dict = {},
suppress_diff_warning: bool = False,
) -> fx.GraphModule:
"""
Creates a new graph module that returns intermediate nodes from a given
model as a dictionary with user-specified keys as strings, and the requested
......@@ -396,18 +399,17 @@ def create_feature_extractor(
"""
is_training = model.training
assert any(arg is not None for arg in [
return_nodes, train_return_nodes, eval_return_nodes]), (
"Either `return_nodes` or `train_return_nodes` and "
"`eval_return_nodes` together, should be specified")
assert any(arg is not None for arg in [return_nodes, train_return_nodes, eval_return_nodes]), (
"Either `return_nodes` or `train_return_nodes` and " "`eval_return_nodes` together, should be specified"
)
assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \
("If any of `train_return_nodes` and `eval_return_nodes` are "
"specified, then both should be specified")
assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), (
"If any of `train_return_nodes` and `eval_return_nodes` are " "specified, then both should be specified"
)
assert ((return_nodes is None) ^ (train_return_nodes is None)), \
("If `train_return_nodes` and `eval_return_nodes` are specified, "
"then both should be specified")
assert (return_nodes is None) ^ (train_return_nodes is None), (
"If `train_return_nodes` and `eval_return_nodes` are specified, " "then both should be specified"
)
# Put *_return_nodes into Dict[str, str] format
def to_strdict(n) -> Dict[str, str]:
......@@ -426,45 +428,42 @@ def create_feature_extractor(
# Repeat the tracing and graph rewriting for train and eval mode
tracers = {}
graphs = {}
mode_return_nodes: Dict[str, Dict[str, str]] = {
'train': train_return_nodes,
'eval': eval_return_nodes
}
for mode in ['train', 'eval']:
if mode == 'train':
mode_return_nodes: Dict[str, Dict[str, str]] = {"train": train_return_nodes, "eval": eval_return_nodes}
for mode in ["train", "eval"]:
if mode == "train":
model.train()
elif mode == 'eval':
elif mode == "eval":
model.eval()
# Instantiate our NodePathTracer and use that to trace the model
tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model)
name = model.__class__.__name__ if isinstance(
model, nn.Module) else model.__name__
name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
graph_module = fx.GraphModule(tracer.root, graph, name)
available_nodes = list(tracer.node_to_qualname.values())
# FIXME We don't know if we should expect this to happen
assert len(set(available_nodes)) == len(available_nodes), \
"There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
assert len(set(available_nodes)) == len(
available_nodes
), "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
# Check that all outputs in return_nodes are present in the model
for query in mode_return_nodes[mode].keys():
# To check if a query is available we need to check that at least
# one of the available names starts with it up to a .
if not any([re.match(rf'^{query}(\.|$)', n) is not None
for n in available_nodes]):
if not any([re.match(rf"^{query}(\.|$)", n) is not None for n in available_nodes]):
raise ValueError(
f"node: '{query}' is not present in model. Hint: use "
"`get_graph_node_names` to make sure the "
"`return_nodes` you specified are present. It may even "
"be that you need to specify `train_return_nodes` and "
"`eval_return_nodes` separately.")
"`eval_return_nodes` separately."
)
# Remove existing output nodes (train mode)
orig_output_nodes = []
for n in reversed(graph_module.graph.nodes):
if n.op == 'output':
if n.op == "output":
orig_output_nodes.append(n)
assert len(orig_output_nodes)
for n in orig_output_nodes:
......@@ -482,8 +481,8 @@ def create_feature_extractor(
# - When packing outputs into a named tuple like in InceptionV3
continue
for query in mode_return_nodes[mode]:
depth = query.count('.')
if '.'.join(module_qualname.split('.')[:depth + 1]) == query:
depth = query.count(".")
if ".".join(module_qualname.split(".")[: depth + 1]) == query:
output_nodes[mode_return_nodes[mode][query]] = n
mode_return_nodes[mode].pop(query)
break
......@@ -504,11 +503,10 @@ def create_feature_extractor(
# Warn user if there are any discrepancies between the graphs of the
# train and eval modes
if not suppress_diff_warning:
_warn_graph_differences(tracers['train'], tracers['eval'])
_warn_graph_differences(tracers["train"], tracers["eval"])
# Build the final graph module
graph_module = DualGraphModule(
model, graphs['train'], graphs['eval'], class_name=name)
graph_module = DualGraphModule(model, graphs["train"], graphs["eval"], class_name=name)
# Restore original training mode
model.train(is_training)
......
import warnings
from collections import namedtuple
from typing import Optional, Tuple, List, Callable, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from typing import Optional, Tuple, List, Callable, Any
__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]
__all__ = ["GoogLeNet", "googlenet", "GoogLeNetOutputs", "_GoogLeNetOutputs"]
model_urls = {
# GoogLeNet ported from TensorFlow
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
}
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
'aux_logits1': Optional[Tensor]}
GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
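# Illustrative usage of the outputs defined above: in train mode with aux_logits=True the model
# returns the GoogLeNetOutputs namedtuple, in eval mode a plain logits tensor (random weights here;
# init_weights=False simply skips the slow scipy-based initialization for this sketch).
import torch
from torchvision.models import googlenet

model = googlenet(aux_logits=True, init_weights=False, num_classes=10)
x = torch.rand(2, 3, 224, 224)

model.train()
out = model(x)
print(type(out).__name__, out.logits.shape, out.aux_logits1.shape)  # GoogLeNetOutputs torch.Size([2, 10]) torch.Size([2, 10])

model.eval()
with torch.no_grad():
    print(model(x).shape)  # torch.Size([2, 10])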
......@@ -37,19 +38,19 @@ def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
was trained on ImageNet. Default: *False*
"""
if pretrained:
if 'transform_input' not in kwargs:
kwargs['transform_input'] = True
if 'aux_logits' not in kwargs:
kwargs['aux_logits'] = False
if kwargs['aux_logits']:
warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
'so make sure to train them')
original_aux_logits = kwargs['aux_logits']
kwargs['aux_logits'] = True
kwargs['init_weights'] = False
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" not in kwargs:
kwargs["aux_logits"] = False
if kwargs["aux_logits"]:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, " "so make sure to train them"
)
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
model = GoogLeNet(**kwargs)
state_dict = load_state_dict_from_url(model_urls['googlenet'],
progress=progress)
state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
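# Note on the flow above: the model is always built with aux_logits=True so that
# load_state_dict finds a module for every key in the checkpoint; once the weights
# are loaded, the flag is restored to the caller's original choice. As the earlier
# warning explains, the aux heads are not pretrained even when they are kept.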
......@@ -61,7 +62,7 @@ def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']
__constants__ = ["aux_logits", "transform_input"]
def __init__(
self,
......@@ -69,15 +70,18 @@ class GoogLeNet(nn.Module):
aux_logits: bool = True,
transform_input: bool = False,
init_weights: Optional[bool] = None,
blocks: Optional[List[Callable[..., nn.Module]]] = None
blocks: Optional[List[Callable[..., nn.Module]]] = None,
) -> None:
super(GoogLeNet, self).__init__()
if blocks is None:
blocks = [BasicConv2d, Inception, InceptionAux]
if init_weights is None:
warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of '
'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
warnings.warn(
"The default weight initialization of GoogleNet will be changed in future releases of "
"torchvision. If you wish to keep the old behavior (which leads to long initialization times"
" due to scipy/scipy#11299), please set init_weights=True.",
FutureWarning,
)
init_weights = True
assert len(blocks) == 3
conv_block = blocks[0]
......@@ -197,7 +201,7 @@ class GoogLeNet(nn.Module):
if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1)
else:
return x # type: ignore[return-value]
return x # type: ignore[return-value]
def forward(self, x: Tensor) -> GoogLeNetOutputs:
x = self._transform_input(x)
......@@ -212,7 +216,6 @@ class GoogLeNet(nn.Module):
class Inception(nn.Module):
def __init__(
self,
in_channels: int,
......@@ -222,7 +225,7 @@ class Inception(nn.Module):
ch5x5red: int,
ch5x5: int,
pool_proj: int,
conv_block: Optional[Callable[..., nn.Module]] = None
conv_block: Optional[Callable[..., nn.Module]] = None,
) -> None:
super(Inception, self).__init__()
if conv_block is None:
......@@ -230,20 +233,19 @@ class Inception(nn.Module):
self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
self.branch2 = nn.Sequential(
conv_block(in_channels, ch3x3red, kernel_size=1),
conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
)
self.branch3 = nn.Sequential(
conv_block(in_channels, ch5x5red, kernel_size=1),
# Here, kernel_size=3 instead of kernel_size=5 is a known bug.
# Please see https://github.com/pytorch/vision/issues/906 for details.
conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1)
conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
)
self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
conv_block(in_channels, pool_proj, kernel_size=1)
conv_block(in_channels, pool_proj, kernel_size=1),
)
def _forward(self, x: Tensor) -> List[Tensor]:
......@@ -261,12 +263,8 @@ class Inception(nn.Module):
class InceptionAux(nn.Module):
def __init__(
self,
in_channels: int,
num_classes: int,
conv_block: Optional[Callable[..., nn.Module]] = None
self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionAux, self).__init__()
if conv_block is None:
......@@ -295,13 +293,7 @@ class InceptionAux(nn.Module):
class BasicConv2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
**kwargs: Any
) -> None:
def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
......
from collections import namedtuple
import warnings
from collections import namedtuple
from typing import Callable, Any, Optional, Tuple, List
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch import nn, Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from typing import Callable, Any, Optional, Tuple, List
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"]
model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth',
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
}
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}
InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
......@@ -41,17 +43,16 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any)
was trained on ImageNet. Default: *False*
"""
if pretrained:
if 'transform_input' not in kwargs:
kwargs['transform_input'] = True
if 'aux_logits' in kwargs:
original_aux_logits = kwargs['aux_logits']
kwargs['aux_logits'] = True
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" in kwargs:
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
else:
original_aux_logits = True
kwargs['init_weights'] = False # we are loading weights from a pretrained model
kwargs["init_weights"] = False # we are loading weights from a pretrained model
model = Inception3(**kwargs)
state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
progress=progress)
state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
model.load_state_dict(state_dict)
if not original_aux_logits:
model.aux_logits = False
......@@ -62,25 +63,24 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any)
class Inception3(nn.Module):
def __init__(
self,
num_classes: int = 1000,
aux_logits: bool = True,
transform_input: bool = False,
inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
init_weights: Optional[bool] = None
init_weights: Optional[bool] = None,
) -> None:
super(Inception3, self).__init__()
if inception_blocks is None:
inception_blocks = [
BasicConv2d, InceptionA, InceptionB, InceptionC,
InceptionD, InceptionE, InceptionAux
]
inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
if init_weights is None:
warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
warnings.warn(
"The default weight initialization of inception_v3 will be changed in future releases of "
"torchvision. If you wish to keep the old behavior (which leads to long initialization times"
" due to scipy/scipy#11299), please set init_weights=True.",
FutureWarning,
)
init_weights = True
assert len(inception_blocks) == 7
conv_block = inception_blocks[0]
......@@ -120,7 +120,7 @@ class Inception3(nn.Module):
if init_weights:
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
stddev = float(m.stddev) if hasattr(m, 'stddev') else 0.1 # type: ignore
stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore
torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
......@@ -208,12 +208,8 @@ class Inception3(nn.Module):
class InceptionA(nn.Module):
def __init__(
self,
in_channels: int,
pool_features: int,
conv_block: Optional[Callable[..., nn.Module]] = None
self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionA, self).__init__()
if conv_block is None:
......@@ -251,12 +247,7 @@ class InceptionA(nn.Module):
class InceptionB(nn.Module):
def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
super(InceptionB, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
......@@ -284,12 +275,8 @@ class InceptionB(nn.Module):
class InceptionC(nn.Module):
def __init__(
self,
in_channels: int,
channels_7x7: int,
conv_block: Optional[Callable[..., nn.Module]] = None
self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionC, self).__init__()
if conv_block is None:
......@@ -334,12 +321,7 @@ class InceptionC(nn.Module):
class InceptionD(nn.Module):
def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
super(InceptionD, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
......@@ -370,12 +352,7 @@ class InceptionD(nn.Module):
class InceptionE(nn.Module):
def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
super(InceptionE, self).__init__()
if conv_block is None:
conv_block = BasicConv2d
......@@ -422,12 +399,8 @@ class InceptionE(nn.Module):
class InceptionAux(nn.Module):
def __init__(
self,
in_channels: int,
num_classes: int,
conv_block: Optional[Callable[..., nn.Module]] = None
self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionAux, self).__init__()
if conv_block is None:
......@@ -457,13 +430,7 @@ class InceptionAux(nn.Module):
class BasicConv2d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
**kwargs: Any
) -> None:
def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
......
import warnings
from typing import Any, Dict, List
import torch
from torch import Tensor
import torch.nn as nn
from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from typing import Any, Dict, List
__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']
__all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"]
_MODEL_URLS = {
"mnasnet0_5":
"https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
"mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
"mnasnet0_75": None,
"mnasnet1_0":
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
"mnasnet1_3": None
"mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
"mnasnet1_3": None,
}
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
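# Clarifying note: TensorFlow's batch-norm momentum weights the *old* running
# statistic, whereas PyTorch's `momentum` weights the *new* batch statistic
# (running = (1 - momentum) * running + momentum * observed), so the equivalent
# PyTorch value is 1 - 0.9997 = 0.0003, which is what _BN_MOMENTUM below encodes.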
......@@ -23,34 +22,27 @@ _BN_MOMENTUM = 1 - 0.9997
class _InvertedResidual(nn.Module):
def __init__(
self,
in_ch: int,
out_ch: int,
kernel_size: int,
stride: int,
expansion_factor: int,
bn_momentum: float = 0.1
self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
) -> None:
super(_InvertedResidual, self).__init__()
assert stride in [1, 2]
assert kernel_size in [3, 5]
mid_ch = in_ch * expansion_factor
self.apply_residual = (in_ch == out_ch and stride == 1)
self.apply_residual = in_ch == out_ch and stride == 1
self.layers = nn.Sequential(
# Pointwise
nn.Conv2d(in_ch, mid_ch, 1, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
# Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
stride=stride, groups=mid_ch, bias=False),
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True),
# Linear pointwise. Note that there's no activation.
nn.Conv2d(mid_ch, out_ch, 1, bias=False),
nn.BatchNorm2d(out_ch, momentum=bn_momentum))
nn.BatchNorm2d(out_ch, momentum=bn_momentum),
)
def forward(self, input: Tensor) -> Tensor:
if self.apply_residual:
......@@ -59,39 +51,37 @@ class _InvertedResidual(nn.Module):
return self.layers(input)
def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int,
bn_momentum: float) -> nn.Sequential:
""" Creates a stack of inverted residuals. """
def _stack(
in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
"""Creates a stack of inverted residuals."""
assert repeats >= 1
# First one has no skip, because feature map size changes.
first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
bn_momentum=bn_momentum)
first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
remaining = []
for _ in range(1, repeats):
remaining.append(
_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
bn_momentum=bn_momentum))
remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
return nn.Sequential(first, *remaining)
def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
""" Asymmetric rounding to make `val` divisible by `divisor`. With default
"""Asymmetric rounding to make `val` divisible by `divisor`. With default
bias, will round up, unless the number is no more than 10% greater than the
smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
assert 0.0 < round_up_bias < 1.0
new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
return new_val if new_val >= round_up_bias * val else new_val + divisor
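# Worked example of the asymmetric rounding (the cases cited in the docstring):
#     _round_to_multiple_of(83, 8): int(83 + 4) // 8 * 8 = 80; 80 >= 0.9 * 83 -> returns 80
#     _round_to_multiple_of(84, 8): int(84 + 4) // 8 * 8 = 88; 88 >= 0.9 * 84 -> returns 88
# The `new_val + divisor` branch only fires when rounding down would lose more than
# the round_up_bias fraction of the original value.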
def _get_depths(alpha: float) -> List[int]:
""" Scales tensor depths as in reference MobileNet code, prefers rouding up
rather than down. """
"""Scales tensor depths as in reference MobileNet code, prefers rouding up
rather than down."""
depths = [32, 16, 24, 40, 80, 96, 192, 320]
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
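# For alpha == 1.0 every base depth is already a multiple of 8, so
# _get_depths(1.0) simply returns [32, 16, 24, 40, 80, 96, 192, 320];
# other width multipliers rescale each entry and re-round it to a multiple of 8.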
class MNASNet(torch.nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
"""MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
>>> model = MNASNet(1.0, num_classes=1000)
>>> x = torch.rand(1, 3, 224, 224)
......@@ -101,15 +91,11 @@ class MNASNet(torch.nn.Module):
>>> y.nelement()
1000
"""
# Version 2 adds depth scaling in the initial stages of the network.
_version = 2
def __init__(
self,
alpha: float,
num_classes: int = 1000,
dropout: float = 0.2
) -> None:
def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
super(MNASNet, self).__init__()
assert alpha > 0.0
self.alpha = alpha
......@@ -121,8 +107,7 @@ class MNASNet(torch.nn.Module):
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
# Depthwise separable, no skip.
nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
groups=depths[0], bias=False),
nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
......@@ -140,8 +125,7 @@ class MNASNet(torch.nn.Module):
nn.ReLU(inplace=True),
]
self.layers = nn.Sequential(*layers)
self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
nn.Linear(1280, num_classes))
self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))
self._initialize_weights()
def forward(self, x: Tensor) -> Tensor:
......@@ -153,20 +137,26 @@ class MNASNet(torch.nn.Module):
def _initialize_weights(self) -> None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out",
nonlinearity="relu")
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight, mode="fan_out",
nonlinearity="sigmoid")
nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
nn.init.zeros_(m.bias)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool,
missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None:
def _load_from_state_dict(
self,
state_dict: Dict,
prefix: str,
local_metadata: Dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
version = local_metadata.get("version", None)
assert version in [1, 2]
......@@ -180,8 +170,7 @@ class MNASNet(torch.nn.Module):
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32,
bias=False),
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
......@@ -199,20 +188,19 @@ class MNASNet(torch.nn.Module):
"This checkpoint will load and work as before, but "
"you may want to upgrade by training a newer model or "
"transfer learning from an updated ImageNet checkpoint.",
UserWarning)
UserWarning,
)
super(MNASNet, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs)
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
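# Context for the branch above: version-1 checkpoints predate depth scaling of the
# stem, so the first layers are rebuilt with the fixed 32/16-channel layout shown
# earlier before deferring to the parent loader; the UserWarning flags that such
# checkpoints still load but correspond to the old architecture.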
def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
raise ValueError(
"No checkpoint is available for model type {}".format(model_name))
raise ValueError("No checkpoint is available for model type {}".format(model_name))
checkpoint_url = _MODEL_URLS[model_name]
model.load_state_dict(
load_state_dict_from_url(checkpoint_url, progress=progress))
model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
......
import torch
import warnings
from functools import partial
from torch import nn
from typing import Callable, Any, Optional, List
import torch
from torch import Tensor
from torch import nn
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation
from ._utils import _make_divisible
from typing import Callable, Any, Optional, List
__all__ = ['MobileNetV2', 'mobilenet_v2']
__all__ = ["MobileNetV2", "mobilenet_v2"]
model_urls = {
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
}
......@@ -23,7 +24,9 @@ class _DeprecatedConvBNAct(ConvNormActivation):
def __init__(self, *args, **kwargs):
warnings.warn(
"The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning)
"Use torchvision.ops.misc.ConvNormActivation instead.",
FutureWarning,
)
if kwargs.get("norm_layer", None) is None:
kwargs["norm_layer"] = nn.BatchNorm2d
if kwargs.get("activation_layer", None) is None:
......@@ -37,12 +40,7 @@ ConvBNActivation = _DeprecatedConvBNAct
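# Migration sketch for the deprecation warning above (hypothetical call sites;
# ReLU6 is assumed as the activation default, matching the old ConvBNReLU name):
#     old: ConvBNReLU(32, 64, kernel_size=3, stride=2)
#     new: ConvNormActivation(32, 64, kernel_size=3, stride=2,
#                             norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU6)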
class InvertedResidual(nn.Module):
def __init__(
self,
inp: int,
oup: int,
stride: int,
expand_ratio: int,
norm_layer: Optional[Callable[..., nn.Module]] = None
self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InvertedResidual, self).__init__()
self.stride = stride
......@@ -57,16 +55,25 @@ class InvertedResidual(nn.Module):
layers: List[nn.Module] = []
if expand_ratio != 1:
# pw
layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.ReLU6))
layers.extend([
# dw
ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer,
activation_layer=nn.ReLU6),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
norm_layer(oup),
])
layers.append(
ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
)
layers.extend(
[
# dw
ConvNormActivation(
hidden_dim,
hidden_dim,
stride=stride,
groups=hidden_dim,
norm_layer=norm_layer,
activation_layer=nn.ReLU6,
),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
norm_layer(oup),
]
)
self.conv = nn.Sequential(*layers)
self.out_channels = oup
self._is_cn = stride > 1
......@@ -86,7 +93,7 @@ class MobileNetV2(nn.Module):
inverted_residual_setting: Optional[List[List[int]]] = None,
round_nearest: int = 8,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
"""
MobileNet V2 main class
......@@ -126,14 +133,17 @@ class MobileNetV2(nn.Module):
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
raise ValueError(
"inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting)
)
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer,
activation_layer=nn.ReLU6)]
features: List[nn.Module] = [
ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
......@@ -142,8 +152,11 @@ class MobileNetV2(nn.Module):
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
input_channel = output_channel
# building last several layers
features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer,
activation_layer=nn.ReLU6))
features.append(
ConvNormActivation(
input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
)
)
# make it nn.Sequential
self.features = nn.Sequential(*features)
......@@ -156,7 +169,7 @@ class MobileNetV2(nn.Module):
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
......@@ -191,7 +204,6 @@ def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any)
"""
model = MobileNetV2(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
progress=progress)
state_dict = load_state_dict_from_url(model_urls["mobilenet_v2"], progress=progress)
model.load_state_dict(state_dict)
return model
import warnings
import torch
from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, List, Optional, Sequence
import torch
from torch import nn, Tensor
from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer
from ._utils import _make_divisible
......@@ -20,22 +20,34 @@ model_urls = {
class SqueezeExcitation(SElayer):
"""DEPRECATED
"""
"""DEPRECATED"""
def __init__(self, input_channels: int, squeeze_factor: int = 4):
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
self.relu = self.activation
delattr(self, 'activation')
delattr(self, "activation")
warnings.warn(
"This SqueezeExcitation class is deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning)
"Use torchvision.ops.misc.SqueezeExcitation instead.",
FutureWarning,
)
class InvertedResidualConfig:
# Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool,
activation: str, stride: int, dilation: int, width_mult: float):
def __init__(
self,
input_channels: int,
kernel: int,
expanded_channels: int,
out_channels: int,
use_se: bool,
activation: str,
stride: int,
dilation: int,
width_mult: float,
):
self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
......@@ -52,11 +64,15 @@ class InvertedResidualConfig:
class InvertedResidual(nn.Module):
# Implemented as described at section 5 of MobileNetV3 paper
def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)):
def __init__(
self,
cnf: InvertedResidualConfig,
norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid),
):
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError('illegal stride value')
raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
......@@ -65,21 +81,40 @@ class InvertedResidual(nn.Module):
# expand
if cnf.expanded_channels != cnf.input_channels:
layers.append(ConvNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(
ConvNormActivation(
cnf.input_channels,
cnf.expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
norm_layer=norm_layer, activation_layer=activation_layer))
layers.append(
ConvNormActivation(
cnf.expanded_channels,
cnf.expanded_channels,
kernel_size=cnf.kernel,
stride=stride,
dilation=cnf.dilation,
groups=cnf.expanded_channels,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
if cnf.use_se:
squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
# project
layers.append(ConvNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
activation_layer=None))
layers.append(
ConvNormActivation(
cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels
......@@ -93,15 +128,14 @@ class InvertedResidual(nn.Module):
class MobileNetV3(nn.Module):
def __init__(
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any
self,
inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int,
num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
) -> None:
"""
MobileNet V3 main class
......@@ -117,8 +151,10 @@ class MobileNetV3(nn.Module):
if not inverted_residual_setting:
raise ValueError("The inverted_residual_setting should not be empty")
elif not (isinstance(inverted_residual_setting, Sequence) and
all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
elif not (
isinstance(inverted_residual_setting, Sequence)
and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
):
raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
if block is None:
......@@ -131,8 +167,16 @@ class MobileNetV3(nn.Module):
# building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer,
activation_layer=nn.Hardswish))
layers.append(
ConvNormActivation(
3,
firstconv_output_channels,
kernel_size=3,
stride=2,
norm_layer=norm_layer,
activation_layer=nn.Hardswish,
)
)
# building inverted residual blocks
for cnf in inverted_residual_setting:
......@@ -141,8 +185,15 @@ class MobileNetV3(nn.Module):
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels
layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
norm_layer=norm_layer, activation_layer=nn.Hardswish))
layers.append(
ConvNormActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.Hardswish,
)
)
self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1)
......@@ -155,7 +206,7 @@ class MobileNetV3(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
......@@ -179,8 +230,9 @@ class MobileNetV3(nn.Module):
return self._forward_impl(x)
def _mobilenet_v3_conf(arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False,
**kwargs: Any):
def _mobilenet_v3_conf(
arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any
):
reduce_divider = 2 if reduced_tail else 1
dilation = 2 if dilated else 1
......@@ -233,7 +285,7 @@ def _mobilenet_v3_model(
last_channel: int,
pretrained: bool,
progress: bool,
**kwargs: Any
**kwargs: Any,
):
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained:
......
import warnings
from typing import Any
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Any
from torch import Tensor
from torch.nn import functional as F
from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls
from ..._internally_replaced_utils import load_state_dict_from_url
from torchvision.models.googlenet import (
GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls)
from .utils import _replace_relu, quantize_model
__all__ = ['QuantizableGoogLeNet', 'googlenet']
__all__ = ["QuantizableGoogLeNet", "googlenet"]
quant_model_urls = {
# fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch
'googlenet_fbgemm': 'https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth',
"googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
}
......@@ -44,35 +43,35 @@ def googlenet(
was trained on ImageNet. Default: *False*
"""
if pretrained:
if 'transform_input' not in kwargs:
kwargs['transform_input'] = True
if 'aux_logits' not in kwargs:
kwargs['aux_logits'] = False
if kwargs['aux_logits']:
warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
'so make sure to train them')
original_aux_logits = kwargs['aux_logits']
kwargs['aux_logits'] = True
kwargs['init_weights'] = False
if "transform_input" not in kwargs:
kwargs["transform_input"] = True
if "aux_logits" not in kwargs:
kwargs["aux_logits"] = False
if kwargs["aux_logits"]:
warnings.warn(
"auxiliary heads in the pretrained googlenet model are NOT pretrained, " "so make sure to train them"
)
original_aux_logits = kwargs["aux_logits"]
kwargs["aux_logits"] = True
kwargs["init_weights"] = False
model = QuantizableGoogLeNet(**kwargs)
_replace_relu(model)
if quantize:
# TODO use pretrained as a string to specify the backend
backend = 'fbgemm'
backend = "fbgemm"
quantize_model(model, backend)
else:
assert pretrained in [True, False]
if pretrained:
if quantize:
model_url = quant_model_urls['googlenet' + '_' + backend]
model_url = quant_model_urls["googlenet" + "_" + backend]
else:
model_url = model_urls['googlenet']
model_url = model_urls["googlenet"]
state_dict = load_state_dict_from_url(model_url,
progress=progress)
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
......@@ -84,7 +83,6 @@ def googlenet(
class QuantizableBasicConv2d(BasicConv2d):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
self.relu = nn.ReLU()
......@@ -100,10 +98,10 @@ class QuantizableBasicConv2d(BasicConv2d):
class QuantizableInception(Inception):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInception, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, *args, **kwargs)
conv_block=QuantizableBasicConv2d, *args, **kwargs
)
self.cat = nn.quantized.FloatFunctional()
def forward(self, x: Tensor) -> Tensor:
......@@ -115,9 +113,7 @@ class QuantizableInceptionAux(InceptionAux):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionAux, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
conv_block=QuantizableBasicConv2d, *args, **kwargs
)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.7)
......@@ -144,9 +140,7 @@ class QuantizableGoogLeNet(GoogLeNet):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableGoogLeNet, self).__init__( # type: ignore[misc]
blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux],
*args,
**kwargs
blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs
)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
......
import warnings
from typing import Any, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Any, List
from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs
from ..._internally_replaced_utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model
......@@ -20,8 +20,7 @@ __all__ = [
quant_model_urls = {
# fp32 weights ported from TensorFlow, quantized in PyTorch
"inception_v3_google_fbgemm":
"https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
"inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
}
......@@ -66,7 +65,7 @@ def inception_v3(
if quantize:
# TODO use pretrained as a string to specify the backend
backend = 'fbgemm'
backend = "fbgemm"
quantize_model(model, backend)
else:
assert pretrained in [True, False]
......@@ -76,12 +75,11 @@ def inception_v3(
if not original_aux_logits:
model.aux_logits = False
model.AuxLogits = None
model_url = quant_model_urls['inception_v3_google' + '_' + backend]
model_url = quant_model_urls["inception_v3_google" + "_" + backend]
else:
model_url = inception_module.model_urls['inception_v3_google']
model_url = inception_module.model_urls["inception_v3_google"]
state_dict = load_state_dict_from_url(model_url,
progress=progress)
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
......@@ -111,9 +109,7 @@ class QuantizableInceptionA(inception_module.InceptionA):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionA, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
conv_block=QuantizableBasicConv2d, *args, **kwargs
)
self.myop = nn.quantized.FloatFunctional()
......@@ -126,9 +122,7 @@ class QuantizableInceptionB(inception_module.InceptionB):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionB, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
conv_block=QuantizableBasicConv2d, *args, **kwargs
)
self.myop = nn.quantized.FloatFunctional()
......@@ -141,9 +135,7 @@ class QuantizableInceptionC(inception_module.InceptionC):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionC, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
conv_block=QuantizableBasicConv2d, *args, **kwargs
)
self.myop = nn.quantized.FloatFunctional()
......@@ -156,9 +148,7 @@ class QuantizableInceptionD(inception_module.InceptionD):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionD, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
conv_block=QuantizableBasicConv2d, *args, **kwargs
)
self.myop = nn.quantized.FloatFunctional()
......@@ -171,9 +161,7 @@ class QuantizableInceptionE(inception_module.InceptionE):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionE, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
conv_block=QuantizableBasicConv2d, *args, **kwargs
)
self.myop1 = nn.quantized.FloatFunctional()
self.myop2 = nn.quantized.FloatFunctional()
......@@ -209,9 +197,7 @@ class QuantizableInceptionAux(inception_module.InceptionAux):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionAux, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d,
*args,
**kwargs
conv_block=QuantizableBasicConv2d, *args, **kwargs
)
......@@ -233,8 +219,8 @@ class QuantizableInception3(inception_module.Inception3):
QuantizableInceptionC,
QuantizableInceptionD,
QuantizableInceptionE,
QuantizableInceptionAux
]
QuantizableInceptionAux,
],
)
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
......
from torch import nn
from torch import Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from typing import Any
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
from torch import Tensor
from torch import nn
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from .utils import _replace_relu, quantize_model
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
from .utils import _replace_relu, quantize_model
__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2']
__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"]
quant_model_urls = {
'mobilenet_v2_qnnpack':
'https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth'
"mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth"
}
......@@ -57,7 +55,7 @@ class QuantizableMobileNetV2(MobileNetV2):
def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvNormActivation:
fuse_modules(m, ['0', '1', '2'], inplace=True)
fuse_modules(m, ["0", "1", "2"], inplace=True)
if type(m) == QuantizableInvertedResidual:
m.fuse_model()
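# What the fuse call above does: ConvNormActivation is an nn.Sequential whose
# children at indices "0", "1", "2" are the Conv2d, the BatchNorm2d and the
# activation (ReLU here, after _replace_relu swaps out ReLU6), so fusing them
# lets quantization treat conv + bn + relu as a single module.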
......@@ -87,19 +85,18 @@ def mobilenet_v2(
if quantize:
# TODO use pretrained as a string to specify the backend
backend = 'qnnpack'
backend = "qnnpack"
quantize_model(model, backend)
else:
assert pretrained in [True, False]
if pretrained:
if quantize:
model_url = quant_model_urls['mobilenet_v2_' + backend]
model_url = quant_model_urls["mobilenet_v2_" + backend]
else:
model_url = model_urls['mobilenet_v2']
model_url = model_urls["mobilenet_v2"]
state_dict = load_state_dict_from_url(model_url,
progress=progress)
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
from typing import Any, List, Optional
import torch
from torch import nn, Tensor
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation, SqueezeExcitation
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3,\
model_urls, _mobilenet_v3_conf
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from typing import Any, List, Optional
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
from .utils import _replace_relu
__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large']
__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"]
quant_model_urls = {
'mobilenet_v3_large_qnnpack':
"https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
"mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
}
......@@ -29,7 +29,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation):
return self.skip_mul.mul(self._scale(input), input)
def fuse_model(self) -> None:
fuse_modules(self, ['fc1', 'activation'], inplace=True)
fuse_modules(self, ["fc1", "activation"], inplace=True)
def _load_from_state_dict(
self,
......@@ -45,7 +45,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation):
if version is None or version < 2:
default_state_dict = {
"scale_activation.activation_post_process.scale": torch.tensor([1.]),
"scale_activation.activation_post_process.scale": torch.tensor([1.0]),
"scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
"scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
"scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
......@@ -69,11 +69,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation):
class QuantizableInvertedResidual(InvertedResidual):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__( # type: ignore[misc]
se_layer=QuantizableSqueezeExcitation,
*args,
**kwargs
)
super().__init__(se_layer=QuantizableSqueezeExcitation, *args, **kwargs) # type: ignore[misc]
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x: Tensor) -> Tensor:
......@@ -104,20 +100,15 @@ class QuantizableMobileNetV3(MobileNetV3):
def fuse_model(self) -> None:
for m in self.modules():
if type(m) == ConvNormActivation:
modules_to_fuse = ['0', '1']
modules_to_fuse = ["0", "1"]
if len(m) == 3 and type(m[2]) == nn.ReLU:
modules_to_fuse.append('2')
modules_to_fuse.append("2")
fuse_modules(m, modules_to_fuse, inplace=True)
elif type(m) == QuantizableSqueezeExcitation:
m.fuse_model()
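# Unlike the MobileNetV2 variant, the "2" index is appended only when the third
# child is an nn.ReLU: torch.quantization can fuse conv+bn or conv+bn+relu, but
# not the Hardswish activations used elsewhere in MobileNetV3, so those blocks
# fuse just the conv and batch-norm pair.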
def _load_weights(
arch: str,
model: QuantizableMobileNetV3,
model_url: Optional[str],
progress: bool
) -> None:
def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None:
if model_url is None:
raise ValueError("No checkpoint is available for {}".format(arch))
state_dict = load_state_dict_from_url(model_url, progress=progress)
......@@ -138,14 +129,14 @@ def _mobilenet_v3_model(
_replace_relu(model)
if quantize:
backend = 'qnnpack'
backend = "qnnpack"
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(model, inplace=True)
if pretrained:
_load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
_load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress)
torch.quantization.convert(model, inplace=True)
model.eval()
......
from typing import Any, Type, Union, List
import torch
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
import torch.nn as nn
from torch import Tensor
from typing import Any, Type, Union, List
from torch.quantization import fuse_modules
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
from ..._internally_replaced_utils import load_state_dict_from_url
from torch.quantization import fuse_modules
from .utils import _replace_relu, quantize_model
__all__ = ['QuantizableResNet', 'resnet18', 'resnet50',
'resnext101_32x8d']
__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]
quant_model_urls = {
'resnet18_fbgemm':
'https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth',
'resnet50_fbgemm':
'https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth',
'resnext101_32x8d_fbgemm':
'https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth',
"resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
"resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
"resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
}
......@@ -45,10 +42,9 @@ class QuantizableBasicBlock(BasicBlock):
return out
def fuse_model(self) -> None:
torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'],
['conv2', 'bn2']], inplace=True)
torch.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
if self.downsample:
torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)
torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
class QuantizableBottleneck(Bottleneck):
......@@ -77,15 +73,12 @@ class QuantizableBottleneck(Bottleneck):
return out
def fuse_model(self) -> None:
fuse_modules(self, [['conv1', 'bn1', 'relu1'],
['conv2', 'bn2', 'relu2'],
['conv3', 'bn3']], inplace=True)
fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
if self.downsample:
torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)
torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
class QuantizableResNet(ResNet):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableResNet, self).__init__(*args, **kwargs)
......@@ -109,7 +102,7 @@ class QuantizableResNet(ResNet):
and the model after modification is in floating point
"""
fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)
fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
for m in self.modules():
if type(m) == QuantizableBottleneck or type(m) == QuantizableBasicBlock:
m.fuse_model()
......@@ -129,19 +122,18 @@ def _resnet(
_replace_relu(model)
if quantize:
# TODO use pretrained as a string to specify the backend
backend = 'fbgemm'
backend = "fbgemm"
quantize_model(model, backend)
else:
assert pretrained in [True, False]
if pretrained:
if quantize:
model_url = quant_model_urls[arch + '_' + backend]
model_url = quant_model_urls[arch + "_" + backend]
else:
model_url = model_urls[arch]
state_dict = load_state_dict_from_url(model_url,
progress=progress)
state_dict = load_state_dict_from_url(model_url, progress=progress)
model.load_state_dict(state_dict)
return model
......@@ -161,8 +153,7 @@ def resnet18(
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, return a quantized version of the model
"""
return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress,
quantize, **kwargs)
return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs)
def resnet50(
......@@ -180,8 +171,7 @@ def resnet50(
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, return a quantized version of the model
"""
return _resnet('resnet50', QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress,
quantize, **kwargs)
return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs)
def resnext101_32x8d(
......@@ -198,7 +188,6 @@ def resnext101_32x8d(
progress (bool): If True, displays a progress bar of the download to stderr
quantize (bool): If True, return a quantized version of the model
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', QuantizableBottleneck, [3, 4, 23, 3],
pretrained, progress, quantize, **kwargs)
kwargs["groups"] = 32
kwargs["width_per_group"] = 8
return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs)