Unverified commit 5f0edb97, authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* fix a double import introduced in the merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
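ufmt wraps usort (import sorting) and black (code formatting) behind a single command and a single pre-commit hook, with both tools pinned so every contributor produces identical formatting. The configuration landed by this PR is not reproduced in this excerpt; the following is only a minimal sketch of what such a hook entry typically looks like, with the rev and version pins as placeholders rather than the values used here:

# .pre-commit-config.yaml -- illustrative sketch only; rev and pins are placeholders
- repo: https://github.com/omnilib/ufmt
  rev: v1.3.0                # placeholder revision, not the rev pinned by this PR
  hooks:
    - id: ufmt
      additional_dependencies:
        - black == 21.9b0    # placeholder pins; the commit message only states
        - usort == 0.6.4     # that black and usort are pinned

With a hook like this in place, `pre-commit run --all-files` applies the formatting; the same result can be produced directly with `ufmt format <paths>` and verified with `ufmt check <paths>`.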
torchvision/models/detection/keypoint_rcnn.py

import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ..._internally_replaced_utils import load_state_dict_from_url
from ._utils import overwrite_eps
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from .faster_rcnn import FasterRCNN


__all__ = ["KeypointRCNN", "keypointrcnn_resnet50_fpn"]


class KeypointRCNN(FasterRCNN):

@@ -151,27 +147,47 @@ class KeypointRCNN(FasterRCNN):
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=None,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # keypoint parameters
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        num_keypoints=17,
    ):
        assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
        if min_size is None:

@@ -184,10 +200,7 @@ class KeypointRCNN(FasterRCNN):
        out_channels = backbone.out_channels

        if keypoint_roi_pool is None:
            keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if keypoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))

@@ -198,24 +211,39 @@ class KeypointRCNN(FasterRCNN):
            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)
        super(KeypointRCNN, self).__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
        )

        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
        self.roi_heads.keypoint_head = keypoint_head

@@ -249,9 +277,7 @@ class KeypointRCNNPredictor(nn.Module):
            stride=2,
            padding=deconv_kernel // 2 - 1,
        )
        nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(self.kps_score_lowres.bias, 0)
        self.up_scale = 2
        self.out_channels = num_keypoints
@@ -265,16 +291,20 @@ class KeypointRCNNPredictor(nn.Module):
model_urls = {
    # legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606
    "keypointrcnn_resnet50_fpn_coco_legacy": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
    "keypointrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
}


def keypointrcnn_resnet50_fpn(
    pretrained=False,
    progress=True,
    num_classes=2,
    num_keypoints=17,
    pretrained_backbone=True,
    trainable_backbone_layers=None,
    **kwargs,
):
    """
    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

@@ -331,19 +361,19 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
        Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
    """
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
    )

    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False

    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
    if pretrained:
        key = "keypointrcnn_resnet50_fpn_coco"
        if pretrained == "legacy":
            key += "_legacy"
        state_dict = load_state_dict_from_url(model_urls[key], progress=progress)
        model.load_state_dict(state_dict)
        overwrite_eps(model, 0.0)
    return model
torchvision/models/detection/mask_rcnn.py

from collections import OrderedDict

from torch import nn
from torchvision.ops import MultiScaleRoIAlign

from ..._internally_replaced_utils import load_state_dict_from_url
from ._utils import overwrite_eps
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from .faster_rcnn import FasterRCNN


__all__ = [
    "MaskRCNN",
    "maskrcnn_resnet50_fpn",
]

@@ -149,26 +148,46 @@ class MaskRCNN(FasterRCNN):
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # Mask parameters
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
    ):
        assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))

@@ -179,10 +198,7 @@ class MaskRCNN(FasterRCNN):
        out_channels = backbone.out_channels

        if mask_roi_pool is None:
            mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if mask_head is None:
            mask_layers = (256, 256, 256, 256)

@@ -192,28 +208,42 @@ class MaskRCNN(FasterRCNN):
        if mask_predictor is None:
            mask_predictor_in_channels = 256  # == mask_layers[-1]
            mask_dim_reduced = 256
            mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes)
        super(MaskRCNN, self).__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
        )

        self.roi_heads.mask_roi_pool = mask_roi_pool
        self.roi_heads.mask_head = mask_head

@@ -232,8 +262,8 @@ class MaskRCNNHeads(nn.Sequential):
        next_feature = in_channels
        for layer_idx, layer_features in enumerate(layers, 1):
            d["mask_fcn{}".format(layer_idx)] = nn.Conv2d(
                next_feature, layer_features, kernel_size=3, stride=1, padding=dilation, dilation=dilation
            )
            d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True)
            next_feature = layer_features

@@ -247,11 +277,15 @@ class MaskRCNNHeads(nn.Sequential):
class MaskRCNNPredictor(nn.Sequential):
    def __init__(self, in_channels, dim_reduced, num_classes):
        super(MaskRCNNPredictor, self).__init__(
            OrderedDict(
                [
                    ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)),
                    ("relu", nn.ReLU(inplace=True)),
                    ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)),
                ]
            )
        )

        for name, param in self.named_parameters():
            if "weight" in name:

@@ -261,13 +295,13 @@ class MaskRCNNPredictor(nn.Sequential):
model_urls = {
    "maskrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
}


def maskrcnn_resnet50_fpn(
    pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
    """
    Constructs a Mask R-CNN model with a ResNet-50-FPN backbone.

@@ -324,16 +358,16 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
        Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
    """
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
    )

    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False

    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
    model = MaskRCNN(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress)
        model.load_state_dict(state_dict)
        overwrite_eps(model, 0.0)
    return model
torchvision/models/detection/retinanet.py

import math
import warnings
from collections import OrderedDict
from typing import Dict, List, Tuple, Optional

import torch
from torch import nn, Tensor

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import sigmoid_focal_loss
from ...ops import boxes as box_ops
from ...ops.feature_pyramid_network import LastLevelP6P7
from . import _utils as det_utils
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
from .transform import GeneralizedRCNNTransform


__all__ = ["RetinaNet", "retinanet_resnet50_fpn"]


def _sum(x: List[Tensor]) -> Tensor:

@@ -48,16 +45,13 @@ class RetinaNetHead(nn.Module):
    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
        return {
            "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
            "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
        }

    def forward(self, x):
        # type: (List[Tensor]) -> Dict[str, Tensor]
        return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}


class RetinaNetClassificationHead(nn.Module):

@@ -100,7 +94,7 @@ class RetinaNetClassificationHead(nn.Module):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
        losses = []

        cls_logits = head_outputs["cls_logits"]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground

@@ -111,18 +105,21 @@ class RetinaNetClassificationHead(nn.Module):
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
            ] = 1.0

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction="sum",
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / len(targets)

@@ -153,8 +150,9 @@ class RetinaNetRegressionHead(nn.Module):
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
    }

    def __init__(self, in_channels, num_anchors):

@@ -181,16 +179,17 @@ class RetinaNetRegressionHead(nn.Module):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
        losses = []

        bbox_regression = head_outputs["bbox_regression"]

        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
            targets, bbox_regression, anchors, matched_idxs
        ):
            # determine only the foreground indices, ignore the rest
            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
            num_foreground = foreground_idxs_per_image.numel()

            # select only the foreground boxes
            matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

@@ -198,11 +197,10 @@ class RetinaNetRegressionHead(nn.Module):
            target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)

            # compute the loss
            losses.append(
                torch.nn.functional.l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
                / max(1, num_foreground)
            )

        return _sum(losses) / max(1, len(targets))
@@ -309,30 +307,40 @@ class RetinaNet(nn.Module):
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
    }

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # Anchor parameters
        anchor_generator=None,
        head=None,
        proposal_matcher=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=300,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.4,
        topk_candidates=1000,
    ):
        super().__init__()

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))

@@ -340,9 +348,7 @@ class RetinaNet(nn.Module):
        if anchor_generator is None:
            anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        self.anchor_generator = anchor_generator

        if head is None:

@@ -385,20 +391,21 @@ class RetinaNet(nn.Module):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]

        num_images = len(image_shapes)

@@ -413,8 +420,9 @@ class RetinaNet(nn.Module):
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes

@@ -428,11 +436,12 @@ class RetinaNet(nn.Module):
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)

@@ -445,13 +454,15 @@ class RetinaNet(nn.Module):
            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections
@@ -478,12 +489,11 @@ class RetinaNet(nn.Module):
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError(
                            "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape)
                        )
                else:
                    raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes)))

        # get the original image sizes
        original_image_sizes: List[Tuple[int, int]] = []

@@ -505,14 +515,15 @@ class RetinaNet(nn.Module):
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError(
                        "All bounding boxes should have positive height and width."
                        " Found invalid box {} for target at index {}.".format(degen_bb, target_idx)
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

@@ -536,7 +547,7 @@ class RetinaNet(nn.Module):
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs["cls_logits"].size(1)
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]

@@ -559,13 +570,13 @@ class RetinaNet(nn.Module):
model_urls = {
    "retinanet_resnet50_fpn_coco": "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
}


def retinanet_resnet50_fpn(
    pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

@@ -613,18 +624,23 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
        Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
    """
    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
    )

    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False

    # skip P2 because it generates too many anchors (according to their paper)
    backbone = resnet_fpn_backbone(
        "resnet50",
        pretrained_backbone,
        returned_layers=[2, 3, 4],
        extra_blocks=LastLevelP6P7(256, 256),
        trainable_layers=trainable_backbone_layers,
    )
    model = RetinaNet(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress)
        model.load_state_dict(state_dict)
        overwrite_eps(model, 0.0)
    return model
torchvision/models/detection/roi_heads.py

from typing import Optional, List, Dict, Tuple

import torch
import torch.nn.functional as F
import torchvision
from torch import nn, Tensor
from torchvision.ops import boxes as box_ops
from torchvision.ops import roi_align

from . import _utils as det_utils


def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
    # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]

@@ -46,7 +43,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    box_loss = box_loss / labels.numel()

@@ -95,7 +92,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
    matched_idxs = matched_idxs.to(boxes)
    rois = torch.cat([matched_idxs[:, None], boxes], dim=1)
    gt_masks = gt_masks[:, None].to(rois)
    return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0]


def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):

@@ -113,8 +110,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
    discretization_size = mask_logits.shape[-1]
    labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)]
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    labels = torch.cat(labels, dim=0)

@@ -167,59 +163,72 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size):
    return heatmaps, valid


def _onnx_heatmaps_to_keypoints(
    maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i
):
    num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64)

    width_correction = widths_i / roi_map_width
    height_correction = heights_i / roi_map_height

    roi_map = F.interpolate(
        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False
    )[:, 0]

    w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
    pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

    x_int = pos % w
    y_int = (pos - x_int) // w

    x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to(
        dtype=torch.float32
    )
    y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to(
        dtype=torch.float32
    )

    xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32)
    xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32)
    xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32)
    xy_preds_i = torch.stack(
        [
            xy_preds_i_0.to(dtype=torch.float32),
            xy_preds_i_1.to(dtype=torch.float32),
            xy_preds_i_2.to(dtype=torch.float32),
        ],
        0,
    )

    # TODO: simplify when indexing without rank will be supported by ONNX
    base = num_keypoints * num_keypoints + num_keypoints + 1
    ind = torch.arange(num_keypoints)
    ind = ind.to(dtype=torch.int64) * base
    end_scores_i = (
        roi_map.index_select(1, y_int.to(dtype=torch.int64))
        .index_select(2, x_int.to(dtype=torch.int64))
        .view(-1)
        .index_select(0, ind.to(dtype=torch.int64))
    )
    return xy_preds_i, end_scores_i


@torch.jit._script_if_tracing
def _onnx_heatmaps_to_keypoints_loop(
    maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints
):
    xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device)
    end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device)

    for i in range(int(rois.size(0))):
        xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(
            maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i]
        )
        xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0)
        end_scores = torch.cat(
            (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0
        )
    return xy_preds, end_scores
@@ -246,10 +255,17 @@ def heatmaps_to_keypoints(maps, rois):
    num_keypoints = maps.shape[1]

    if torchvision._is_tracing():
        xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(
            maps,
            rois,
            widths_ceil,
            heights_ceil,
            widths,
            heights,
            offset_x,
            offset_y,
            torch.scalar_tensor(num_keypoints, dtype=torch.int64),
        )
        return xy_preds.permute(0, 2, 1), end_scores

    xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device)

@@ -260,13 +276,14 @@ def heatmaps_to_keypoints(maps, rois):
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height
        roi_map = F.interpolate(
            maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False
        )[:, 0]
        # roi_map_probs = scores_to_probs(roi_map.copy())
        w = roi_map.shape[2]
        pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)

        x_int = pos % w
        y_int = torch.div(pos - x_int, w, rounding_mode="floor")
        # assert (roi_map_probs[k, y_int, x_int] ==
        #         roi_map_probs[k, :, :].max())
        x = (x_int.float() + 0.5) * width_correction

@@ -288,9 +305,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
    valid = []
    for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs):
        kp = gt_kp_in_image[midx]
        heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size)
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

@@ -327,10 +342,10 @@ def keypointrcnn_inference(x, boxes):
def _onnx_expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half = w_half.to(dtype=torch.float32) * scale
    h_half = h_half.to(dtype=torch.float32) * scale

@@ -350,10 +365,10 @@ def expand_boxes(boxes, scale):
    # type: (Tensor, float) -> Tensor
    if torchvision._is_tracing():
        return _onnx_expand_boxes(boxes, scale)
    w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
    h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
    x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
    y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5

    w_half *= scale
    h_half *= scale

@@ -395,7 +410,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
    mask = mask.expand((1, 1, -1, -1))

    # Resize mask
    mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)

@@ -404,9 +419,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
    y_0 = max(box[1], 0)
    y_1 = min(box[3] + 1, im_h)

    im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]
    return im_mask

@@ -414,8 +427,8 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    one = torch.ones(1, dtype=torch.int64)
    zero = torch.zeros(1, dtype=torch.int64)

    w = box[2] - box[0] + one
    h = box[3] - box[1] + one
    w = torch.max(torch.cat((w, one)))
    h = torch.max(torch.cat((h, one)))

@@ -423,7 +436,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

    # Resize mask
    mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False)
    mask = mask[0][0]

    x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))

@@ -431,23 +444,18 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
    y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero)))
    y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0))))

    unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])]

    # TODO : replace below with a dynamic padding when support is added in ONNX
    # pad y
    zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1))
    zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1))
    concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :]

    # pad x
    zeros_x0 = torch.zeros(concat_0.size(0), x_0)
    zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1)
    im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w]
    return im_mask
@@ -468,13 +476,10 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
    im_h, im_w = img_shape

    if torchvision._is_tracing():
        return _onnx_paste_masks_in_image_loop(
            masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64)
        )[:, None]
    res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)]
    if len(res) > 0:
        ret = torch.stack(res, dim=0)[:, None]
    else:

@@ -484,46 +489,44 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1):
class RoIHeads(nn.Module):
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
        # Mask
        mask_roi_pool=None,
        mask_head=None,
        mask_predictor=None,
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
    ):
        super(RoIHeads, self).__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool

@@ -572,9 +575,7 @@ class RoIHeads(nn.Module):
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)

@@ -601,19 +602,14 @@ class RoIHeads(nn.Module):
        # type: (List[Tensor]) -> List[Tensor]
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)):
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]
        return proposals

@@ -625,10 +621,11 @@ class RoIHeads(nn.Module):
        if self.has_mask():
            assert all(["masks" in t for t in targets])

    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        self.check_targets(targets)
        assert targets is not None

@@ -661,12 +658,13 @@ class RoIHeads(nn.Module):
        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        return proposals, matched_idxs, labels, regression_targets

    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
):
# type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]] # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
device = class_logits.device device = class_logits.device
num_classes = class_logits.shape[-1] num_classes = class_logits.shape[-1]
...@@ -710,7 +708,7 @@ class RoIHeads(nn.Module): ...@@ -710,7 +708,7 @@ class RoIHeads(nn.Module):
# non-maximum suppression, independently done per class # non-maximum suppression, independently done per class
keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
# keep only topk scoring predictions # keep only topk scoring predictions
keep = keep[:self.detections_per_img] keep = keep[: self.detections_per_img]
boxes, scores, labels = boxes[keep], scores[keep], labels[keep] boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
all_boxes.append(boxes) all_boxes.append(boxes)
...@@ -719,12 +717,13 @@ class RoIHeads(nn.Module): ...@@ -719,12 +717,13 @@ class RoIHeads(nn.Module):
return all_boxes, all_scores, all_labels return all_boxes, all_scores, all_labels
def forward(self, def forward(
features, # type: Dict[str, Tensor] self,
proposals, # type: List[Tensor] features, # type: Dict[str, Tensor]
image_shapes, # type: List[Tuple[int, int]] proposals, # type: List[Tensor]
targets=None # type: Optional[List[Dict[str, Tensor]]] image_shapes, # type: List[Tuple[int, int]]
): targets=None, # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]] # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
""" """
Args: Args:
...@@ -737,10 +736,10 @@ class RoIHeads(nn.Module): ...@@ -737,10 +736,10 @@ class RoIHeads(nn.Module):
for t in targets: for t in targets:
# TODO: https://github.com/pytorch/pytorch/issues/26731 # TODO: https://github.com/pytorch/pytorch/issues/26731
floating_point_types = (torch.float, torch.double, torch.half) floating_point_types = (torch.float, torch.double, torch.half)
assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type' assert t["boxes"].dtype in floating_point_types, "target boxes must of float type"
assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' assert t["labels"].dtype == torch.int64, "target labels must of int64 type"
if self.has_keypoint(): if self.has_keypoint():
assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' assert t["keypoints"].dtype == torch.float32, "target keypoints must of float type"
if self.training: if self.training:
proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
...@@ -757,12 +756,8 @@ class RoIHeads(nn.Module): ...@@ -757,12 +756,8 @@ class RoIHeads(nn.Module):
losses = {} losses = {}
if self.training: if self.training:
assert labels is not None and regression_targets is not None assert labels is not None and regression_targets is not None
loss_classifier, loss_box_reg = fastrcnn_loss( loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
class_logits, box_regression, labels, regression_targets) losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
losses = {
"loss_classifier": loss_classifier,
"loss_box_reg": loss_box_reg
}
else: else:
boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes)
num_images = len(boxes) num_images = len(boxes)
...@@ -805,12 +800,8 @@ class RoIHeads(nn.Module): ...@@ -805,12 +800,8 @@ class RoIHeads(nn.Module):
gt_masks = [t["masks"] for t in targets] gt_masks = [t["masks"] for t in targets]
gt_labels = [t["labels"] for t in targets] gt_labels = [t["labels"] for t in targets]
rcnn_loss_mask = maskrcnn_loss( rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs)
mask_logits, mask_proposals, loss_mask = {"loss_mask": rcnn_loss_mask}
gt_masks, gt_labels, pos_matched_idxs)
loss_mask = {
"loss_mask": rcnn_loss_mask
}
else: else:
labels = [r["labels"] for r in result] labels = [r["labels"] for r in result]
masks_probs = maskrcnn_inference(mask_logits, labels) masks_probs = maskrcnn_inference(mask_logits, labels)
...@@ -821,8 +812,11 @@ class RoIHeads(nn.Module): ...@@ -821,8 +812,11 @@ class RoIHeads(nn.Module):
# keep none checks in if conditional so torchscript will conditionally # keep none checks in if conditional so torchscript will conditionally
# compile each branch # compile each branch
if self.keypoint_roi_pool is not None and self.keypoint_head is not None \ if (
and self.keypoint_predictor is not None: self.keypoint_roi_pool is not None
and self.keypoint_head is not None
and self.keypoint_predictor is not None
):
keypoint_proposals = [p["boxes"] for p in result] keypoint_proposals = [p["boxes"] for p in result]
if self.training: if self.training:
# during training, only focus on positive boxes # during training, only focus on positive boxes
...@@ -848,11 +842,9 @@ class RoIHeads(nn.Module): ...@@ -848,11 +842,9 @@ class RoIHeads(nn.Module):
gt_keypoints = [t["keypoints"] for t in targets] gt_keypoints = [t["keypoints"] for t in targets]
rcnn_loss_keypoint = keypointrcnn_loss( rcnn_loss_keypoint = keypointrcnn_loss(
keypoint_logits, keypoint_proposals, keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
gt_keypoints, pos_matched_idxs) )
loss_keypoint = { loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
"loss_keypoint": rcnn_loss_keypoint
}
else: else:
assert keypoint_logits is not None assert keypoint_logits is not None
assert keypoint_proposals is not None assert keypoint_proposals is not None
......
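Below is a minimal usage sketch, not part of this diff, for the paste_masks_in_image helper reformatted above; the mask resolution, box coordinates, and output image size are illustrative assumptions rather than values taken from this commit.

import torch
from torchvision.models.detection.roi_heads import paste_masks_in_image

# Two fixed-resolution instance masks (N, 1, M, M) and their boxes (N, 4) in image coordinates.
masks = torch.rand(2, 1, 28, 28)
boxes = torch.tensor([[10.0, 20.0, 60.0, 90.0], [30.0, 40.0, 120.0, 150.0]])
# Paste each mask into a blank (im_h, im_w) canvas at its box location.
pasted = paste_masks_in_image(masks, boxes, img_shape=(200, 300), padding=1)
assert pasted.shape == (2, 1, 200, 300)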
import torch from typing import List, Optional, Dict, Tuple
from torch.nn import functional as F
from torch import nn, Tensor
import torch
import torchvision import torchvision
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops from torchvision.ops import boxes as box_ops
from . import _utils as det_utils from . import _utils as det_utils
from .image_list import ImageList
from typing import List, Optional, Dict, Tuple
# Import AnchorGenerator to keep compatibility. # Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator from .anchor_utils import AnchorGenerator
from .image_list import ImageList
@torch.jit.unused @torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
# type: (Tensor, int) -> Tuple[int, int] # type: (Tensor, int) -> Tuple[int, int]
from torch.onnx import operators from torch.onnx import operators
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
pre_nms_top_n = torch.min(torch.cat( pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))
(torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
num_anchors), 0))
return num_anchors, pre_nms_top_n return num_anchors, pre_nms_top_n
...@@ -37,13 +35,9 @@ class RPNHead(nn.Module): ...@@ -37,13 +35,9 @@ class RPNHead(nn.Module):
def __init__(self, in_channels, num_anchors): def __init__(self, in_channels, num_anchors):
super(RPNHead, self).__init__() super(RPNHead, self).__init__()
self.conv = nn.Conv2d( self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
self.bbox_pred = nn.Conv2d( self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
in_channels, num_anchors * 4, kernel_size=1, stride=1
)
for layer in self.children(): for layer in self.children():
torch.nn.init.normal_(layer.weight, std=0.01) torch.nn.init.normal_(layer.weight, std=0.01)
...@@ -76,21 +70,15 @@ def concat_box_prediction_layers(box_cls, box_regression): ...@@ -76,21 +70,15 @@ def concat_box_prediction_layers(box_cls, box_regression):
# same format as the labels. Note that the labels are computed for # same format as the labels. Note that the labels are computed for
# all feature levels concatenated, so we keep the same representation # all feature levels concatenated, so we keep the same representation
# for the objectness and the box_regression # for the objectness and the box_regression
for box_cls_per_level, box_regression_per_level in zip( for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
box_cls, box_regression
):
N, AxC, H, W = box_cls_per_level.shape N, AxC, H, W = box_cls_per_level.shape
Ax4 = box_regression_per_level.shape[1] Ax4 = box_regression_per_level.shape[1]
A = Ax4 // 4 A = Ax4 // 4
C = AxC // A C = AxC // A
box_cls_per_level = permute_and_flatten( box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
box_cls_per_level, N, A, C, H, W
)
box_cls_flattened.append(box_cls_per_level) box_cls_flattened.append(box_cls_per_level)
box_regression_per_level = permute_and_flatten( box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
box_regression_per_level, N, A, 4, H, W
)
box_regression_flattened.append(box_regression_per_level) box_regression_flattened.append(box_regression_per_level)
# concatenate on the first dimension (representing the feature levels), to # concatenate on the first dimension (representing the feature levels), to
# take into account the way the labels were generated (with all feature maps # take into account the way the labels were generated (with all feature maps
...@@ -125,22 +113,30 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -125,22 +113,30 @@ class RegionProposalNetwork(torch.nn.Module):
nms_thresh (float): NMS threshold used for postprocessing the RPN proposals nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
""" """
__annotations__ = { __annotations__ = {
'box_coder': det_utils.BoxCoder, "box_coder": det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher, "proposal_matcher": det_utils.Matcher,
'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
'pre_nms_top_n': Dict[str, int], "pre_nms_top_n": Dict[str, int],
'post_nms_top_n': Dict[str, int], "post_nms_top_n": Dict[str, int],
} }
def __init__(self, def __init__(
anchor_generator, self,
head, anchor_generator,
# head,
fg_iou_thresh, bg_iou_thresh, #
batch_size_per_image, positive_fraction, fg_iou_thresh,
# bg_iou_thresh,
pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0): batch_size_per_image,
positive_fraction,
#
pre_nms_top_n,
post_nms_top_n,
nms_thresh,
score_thresh=0.0,
):
super(RegionProposalNetwork, self).__init__() super(RegionProposalNetwork, self).__init__()
self.anchor_generator = anchor_generator self.anchor_generator = anchor_generator
self.head = head self.head = head
...@@ -155,9 +151,7 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -155,9 +151,7 @@ class RegionProposalNetwork(torch.nn.Module):
allow_low_quality_matches=True, allow_low_quality_matches=True,
) )
self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler( self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
batch_size_per_image, positive_fraction
)
# used during testing # used during testing
self._pre_nms_top_n = pre_nms_top_n self._pre_nms_top_n = pre_nms_top_n
self._post_nms_top_n = post_nms_top_n self._post_nms_top_n = post_nms_top_n
...@@ -167,13 +161,13 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -167,13 +161,13 @@ class RegionProposalNetwork(torch.nn.Module):
def pre_nms_top_n(self): def pre_nms_top_n(self):
if self.training: if self.training:
return self._pre_nms_top_n['training'] return self._pre_nms_top_n["training"]
return self._pre_nms_top_n['testing'] return self._pre_nms_top_n["testing"]
def post_nms_top_n(self): def post_nms_top_n(self):
if self.training: if self.training:
return self._post_nms_top_n['training'] return self._post_nms_top_n["training"]
return self._post_nms_top_n['testing'] return self._post_nms_top_n["testing"]
def assign_targets_to_anchors(self, anchors, targets): def assign_targets_to_anchors(self, anchors, targets):
# type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]] # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
...@@ -235,8 +229,7 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -235,8 +229,7 @@ class RegionProposalNetwork(torch.nn.Module):
objectness = objectness.reshape(num_images, -1) objectness = objectness.reshape(num_images, -1)
levels = [ levels = [
torch.full((n,), idx, dtype=torch.int64, device=device) torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
for idx, n in enumerate(num_anchors_per_level)
] ]
levels = torch.cat(levels, 0) levels = torch.cat(levels, 0)
levels = levels.reshape(1, -1).expand_as(objectness) levels = levels.reshape(1, -1).expand_as(objectness)
...@@ -271,7 +264,7 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -271,7 +264,7 @@ class RegionProposalNetwork(torch.nn.Module):
keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh) keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
# keep only topk scoring predictions # keep only topk scoring predictions
keep = keep[:self.post_nms_top_n()] keep = keep[: self.post_nms_top_n()]
boxes, scores = boxes[keep], scores[keep] boxes, scores = boxes[keep], scores[keep]
final_boxes.append(boxes) final_boxes.append(boxes)
...@@ -303,24 +296,26 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -303,24 +296,26 @@ class RegionProposalNetwork(torch.nn.Module):
labels = torch.cat(labels, dim=0) labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0) regression_targets = torch.cat(regression_targets, dim=0)
box_loss = F.smooth_l1_loss( box_loss = (
pred_bbox_deltas[sampled_pos_inds], F.smooth_l1_loss(
regression_targets[sampled_pos_inds], pred_bbox_deltas[sampled_pos_inds],
beta=1 / 9, regression_targets[sampled_pos_inds],
reduction='sum', beta=1 / 9,
) / (sampled_inds.numel()) reduction="sum",
)
objectness_loss = F.binary_cross_entropy_with_logits( / (sampled_inds.numel())
objectness[sampled_inds], labels[sampled_inds]
) )
objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])
return objectness_loss, box_loss return objectness_loss, box_loss
def forward(self, def forward(
images, # type: ImageList self,
features, # type: Dict[str, Tensor] images, # type: ImageList
targets=None # type: Optional[List[Dict[str, Tensor]]] features, # type: Dict[str, Tensor]
): targets=None, # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]] # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
""" """
Args: Args:
...@@ -346,8 +341,7 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -346,8 +341,7 @@ class RegionProposalNetwork(torch.nn.Module):
num_images = len(anchors) num_images = len(anchors)
num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
objectness, pred_bbox_deltas = \ objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
concat_box_prediction_layers(objectness, pred_bbox_deltas)
# apply pred_bbox_deltas to anchors to obtain the decoded proposals # apply pred_bbox_deltas to anchors to obtain the decoded proposals
# note that we detach the deltas because Faster R-CNN does not backprop through # note that we detach the deltas because Faster R-CNN does not backprop through
# the proposals # the proposals
...@@ -361,7 +355,8 @@ class RegionProposalNetwork(torch.nn.Module): ...@@ -361,7 +355,8 @@ class RegionProposalNetwork(torch.nn.Module):
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
loss_objectness, loss_rpn_box_reg = self.compute_loss( loss_objectness, loss_rpn_box_reg = self.compute_loss(
objectness, pred_bbox_deltas, labels, regression_targets) objectness, pred_bbox_deltas, labels, regression_targets
)
losses = { losses = {
"loss_objectness": loss_objectness, "loss_objectness": loss_objectness,
"loss_rpn_box_reg": loss_rpn_box_reg, "loss_rpn_box_reg": loss_rpn_box_reg,
......
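For orientation, a minimal construction sketch, not part of this diff, of the RegionProposalNetwork defined above; the channel count, anchor sizes, IoU thresholds, and top-N budgets are illustrative assumptions rather than values taken from this commit.

from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork

anchor_generator = AnchorGenerator(sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),))
head = RPNHead(in_channels=256, num_anchors=anchor_generator.num_anchors_per_location()[0])
rpn = RegionProposalNetwork(
    anchor_generator,
    head,
    fg_iou_thresh=0.7,
    bg_iou_thresh=0.3,
    batch_size_per_image=256,
    positive_fraction=0.5,
    # separate train/eval budgets, read back by pre_nms_top_n() / post_nms_top_n() above
    pre_nms_top_n={"training": 2000, "testing": 1000},
    post_nms_top_n={"training": 2000, "testing": 1000},
    nms_thresh=0.7,
)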
import torch
import torch.nn.functional as F
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from torch import nn, Tensor
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import boxes as box_ops
from .. import vgg
from . import _utils as det_utils from . import _utils as det_utils
from .anchor_utils import DefaultBoxGenerator from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers from .backbone_utils import _validate_trainable_layers
from .transform import GeneralizedRCNNTransform from .transform import GeneralizedRCNNTransform
from .. import vgg
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops import boxes as box_ops
__all__ = ['SSD', 'ssd300_vgg16'] __all__ = ["SSD", "ssd300_vgg16"]
model_urls = { model_urls = {
'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth', "ssd300_vgg16_coco": "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
} }
backbone_urls = { backbone_urls = {
# We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
# same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
'vgg16_features': 'https://download.pytorch.org/models/vgg16_features-amdegroot.pth' "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth"
} }
...@@ -43,8 +43,8 @@ class SSDHead(nn.Module): ...@@ -43,8 +43,8 @@ class SSDHead(nn.Module):
def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
return { return {
'bbox_regression': self.regression_head(x), "bbox_regression": self.regression_head(x),
'cls_logits': self.classification_head(x), "cls_logits": self.classification_head(x),
} }
...@@ -159,31 +159,38 @@ class SSD(nn.Module): ...@@ -159,31 +159,38 @@ class SSD(nn.Module):
proposals used during the training of the classification head. It is used to estimate the negative to proposals used during the training of the classification head. It is used to estimate the negative to
positive ratio. positive ratio.
""" """
__annotations__ = { __annotations__ = {
'box_coder': det_utils.BoxCoder, "box_coder": det_utils.BoxCoder,
'proposal_matcher': det_utils.Matcher, "proposal_matcher": det_utils.Matcher,
} }
def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator, def __init__(
size: Tuple[int, int], num_classes: int, self,
image_mean: Optional[List[float]] = None, image_std: Optional[List[float]] = None, backbone: nn.Module,
head: Optional[nn.Module] = None, anchor_generator: DefaultBoxGenerator,
score_thresh: float = 0.01, size: Tuple[int, int],
nms_thresh: float = 0.45, num_classes: int,
detections_per_img: int = 200, image_mean: Optional[List[float]] = None,
iou_thresh: float = 0.5, image_std: Optional[List[float]] = None,
topk_candidates: int = 400, head: Optional[nn.Module] = None,
positive_fraction: float = 0.25): score_thresh: float = 0.01,
nms_thresh: float = 0.45,
detections_per_img: int = 200,
iou_thresh: float = 0.5,
topk_candidates: int = 400,
positive_fraction: float = 0.25,
):
super().__init__() super().__init__()
self.backbone = backbone self.backbone = backbone
self.anchor_generator = anchor_generator self.anchor_generator = anchor_generator
self.box_coder = det_utils.BoxCoder(weights=(10., 10., 5., 5.)) self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0))
if head is None: if head is None:
if hasattr(backbone, 'out_channels'): if hasattr(backbone, "out_channels"):
out_channels = backbone.out_channels out_channels = backbone.out_channels
else: else:
out_channels = det_utils.retrieve_out_channels(backbone, size) out_channels = det_utils.retrieve_out_channels(backbone, size)
...@@ -200,8 +207,9 @@ class SSD(nn.Module): ...@@ -200,8 +207,9 @@ class SSD(nn.Module):
image_mean = [0.485, 0.456, 0.406] image_mean = [0.485, 0.456, 0.406]
if image_std is None: if image_std is None:
image_std = [0.229, 0.224, 0.225] image_std = [0.229, 0.224, 0.225]
self.transform = GeneralizedRCNNTransform(min(size), max(size), image_mean, image_std, self.transform = GeneralizedRCNNTransform(
size_divisible=1, fixed_size=size) min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size
)
self.score_thresh = score_thresh self.score_thresh = score_thresh
self.nms_thresh = nms_thresh self.nms_thresh = nms_thresh
...@@ -213,45 +221,58 @@ class SSD(nn.Module): ...@@ -213,45 +221,58 @@ class SSD(nn.Module):
self._has_warned = False self._has_warned = False
@torch.jit.unused @torch.jit.unused
def eager_outputs(self, losses: Dict[str, Tensor], def eager_outputs(
detections: List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]]
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training: if self.training:
return losses return losses
return detections return detections
def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor], anchors: List[Tensor], def compute_loss(
matched_idxs: List[Tensor]) -> Dict[str, Tensor]: self,
bbox_regression = head_outputs['bbox_regression'] targets: List[Dict[str, Tensor]],
cls_logits = head_outputs['cls_logits'] head_outputs: Dict[str, Tensor],
anchors: List[Tensor],
matched_idxs: List[Tensor],
) -> Dict[str, Tensor]:
bbox_regression = head_outputs["bbox_regression"]
cls_logits = head_outputs["cls_logits"]
# Match original targets with default boxes # Match original targets with default boxes
num_foreground = 0 num_foreground = 0
bbox_loss = [] bbox_loss = []
cls_targets = [] cls_targets = []
for (targets_per_image, bbox_regression_per_image, cls_logits_per_image, anchors_per_image, for (
matched_idxs_per_image) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs): targets_per_image,
bbox_regression_per_image,
cls_logits_per_image,
anchors_per_image,
matched_idxs_per_image,
) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs):
# produce the matching between boxes and targets # produce the matching between boxes and targets
foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image] foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image]
num_foreground += foreground_matched_idxs_per_image.numel() num_foreground += foreground_matched_idxs_per_image.numel()
# Calculate regression loss # Calculate regression loss
matched_gt_boxes_per_image = targets_per_image['boxes'][foreground_matched_idxs_per_image] matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image]
bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]
target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image)
bbox_loss.append(torch.nn.functional.smooth_l1_loss( bbox_loss.append(
bbox_regression_per_image, torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum")
target_regression, )
reduction='sum'
))
# Estimate ground truth for class targets # Estimate ground truth for class targets
gt_classes_target = torch.zeros((cls_logits_per_image.size(0), ), dtype=targets_per_image['labels'].dtype, gt_classes_target = torch.zeros(
device=targets_per_image['labels'].device) (cls_logits_per_image.size(0),),
gt_classes_target[foreground_idxs_per_image] = \ dtype=targets_per_image["labels"].dtype,
targets_per_image['labels'][foreground_matched_idxs_per_image] device=targets_per_image["labels"].device,
)
gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][
foreground_matched_idxs_per_image
]
cls_targets.append(gt_classes_target) cls_targets.append(gt_classes_target)
bbox_loss = torch.stack(bbox_loss) bbox_loss = torch.stack(bbox_loss)
...@@ -259,30 +280,29 @@ class SSD(nn.Module): ...@@ -259,30 +280,29 @@ class SSD(nn.Module):
# Calculate classification loss # Calculate classification loss
num_classes = cls_logits.size(-1) num_classes = cls_logits.size(-1)
cls_loss = F.cross_entropy( cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view(
cls_logits.view(-1, num_classes), cls_targets.size()
cls_targets.view(-1), )
reduction='none'
).view(cls_targets.size())
# Hard Negative Sampling # Hard Negative Sampling
foreground_idxs = cls_targets > 0 foreground_idxs = cls_targets > 0
num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True) num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True)
# num_negative[num_negative < self.neg_to_pos_ratio] = self.neg_to_pos_ratio # num_negative[num_negative < self.neg_to_pos_ratio] = self.neg_to_pos_ratio
negative_loss = cls_loss.clone() negative_loss = cls_loss.clone()
negative_loss[foreground_idxs] = -float('inf') # use -inf to detect positive values that crept into the sample negative_loss[foreground_idxs] = -float("inf") # use -inf to detect positive values that crept into the sample
values, idx = negative_loss.sort(1, descending=True) values, idx = negative_loss.sort(1, descending=True)
# background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values)) # background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values))
background_idxs = idx.sort(1)[1] < num_negative background_idxs = idx.sort(1)[1] < num_negative
N = max(1, num_foreground) N = max(1, num_foreground)
return { return {
'bbox_regression': bbox_loss.sum() / N, "bbox_regression": bbox_loss.sum() / N,
'classification': (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N, "classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N,
} }
def forward(self, images: List[Tensor], def forward(
targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
if self.training and targets is None: if self.training and targets is None:
raise ValueError("In training mode, targets should be passed") raise ValueError("In training mode, targets should be passed")
...@@ -292,12 +312,11 @@ class SSD(nn.Module): ...@@ -292,12 +312,11 @@ class SSD(nn.Module):
boxes = target["boxes"] boxes = target["boxes"]
if isinstance(boxes, torch.Tensor): if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4: if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor" raise ValueError(
"of shape [N, 4], got {:}.".format( "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape)
boxes.shape)) )
else: else:
raise ValueError("Expected target boxes to be of type " raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes)))
"Tensor, got {:}.".format(type(boxes)))
# get the original image sizes # get the original image sizes
original_image_sizes: List[Tuple[int, int]] = [] original_image_sizes: List[Tuple[int, int]] = []
...@@ -317,14 +336,15 @@ class SSD(nn.Module): ...@@ -317,14 +336,15 @@ class SSD(nn.Module):
if degenerate_boxes.any(): if degenerate_boxes.any():
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist() degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width." raise ValueError(
" Found invalid box {} for target at index {}." "All bounding boxes should have positive height and width."
.format(degen_bb, target_idx)) " Found invalid box {} for target at index {}.".format(degen_bb, target_idx)
)
# get the features from the backbone # get the features from the backbone
features = self.backbone(images.tensors) features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor): if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)]) features = OrderedDict([("0", features)])
features = list(features.values()) features = list(features.values())
...@@ -341,12 +361,13 @@ class SSD(nn.Module): ...@@ -341,12 +361,13 @@ class SSD(nn.Module):
matched_idxs = [] matched_idxs = []
for anchors_per_image, targets_per_image in zip(anchors, targets): for anchors_per_image, targets_per_image in zip(anchors, targets):
if targets_per_image['boxes'].numel() == 0: if targets_per_image["boxes"].numel() == 0:
matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, matched_idxs.append(
device=anchors_per_image.device)) torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
)
continue continue
match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
matched_idxs.append(self.proposal_matcher(match_quality_matrix)) matched_idxs.append(self.proposal_matcher(match_quality_matrix))
losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs) losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs)
...@@ -361,10 +382,11 @@ class SSD(nn.Module): ...@@ -361,10 +382,11 @@ class SSD(nn.Module):
return losses, detections return losses, detections
return self.eager_outputs(losses, detections) return self.eager_outputs(losses, detections)
def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], def postprocess_detections(
image_shapes: List[Tuple[int, int]]) -> List[Dict[str, Tensor]]: self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]]
bbox_regression = head_outputs['bbox_regression'] ) -> List[Dict[str, Tensor]]:
pred_scores = F.softmax(head_outputs['cls_logits'], dim=-1) bbox_regression = head_outputs["bbox_regression"]
pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1)
num_classes = pred_scores.size(-1) num_classes = pred_scores.size(-1)
device = pred_scores.device device = pred_scores.device
...@@ -400,13 +422,15 @@ class SSD(nn.Module): ...@@ -400,13 +422,15 @@ class SSD(nn.Module):
# non-maximum suppression # non-maximum suppression
keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
keep = keep[:self.detections_per_img] keep = keep[: self.detections_per_img]
detections.append({ detections.append(
'boxes': image_boxes[keep], {
'scores': image_scores[keep], "boxes": image_boxes[keep],
'labels': image_labels[keep], "scores": image_scores[keep],
}) "labels": image_labels[keep],
}
)
return detections return detections
...@@ -423,45 +447,47 @@ class SSDFeatureExtractorVGG(nn.Module): ...@@ -423,45 +447,47 @@ class SSDFeatureExtractorVGG(nn.Module):
self.scale_weight = nn.Parameter(torch.ones(512) * 20) self.scale_weight = nn.Parameter(torch.ones(512) * 20)
# Multiple Feature maps - page 4, Fig 2 of SSD paper # Multiple Feature maps - page 4, Fig 2 of SSD paper
self.features = nn.Sequential( self.features = nn.Sequential(*backbone[:maxpool4_pos]) # until conv4_3
*backbone[:maxpool4_pos] # until conv4_3
)
# SSD300 case - page 4, Fig 2 of SSD paper # SSD300 case - page 4, Fig 2 of SSD paper
extra = nn.ModuleList([ extra = nn.ModuleList(
nn.Sequential( [
nn.Conv2d(1024, 256, kernel_size=1), nn.Sequential(
nn.ReLU(inplace=True), nn.Conv2d(1024, 256, kernel_size=1),
nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2 nn.ReLU(inplace=True),
nn.ReLU(inplace=True), nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2
), nn.ReLU(inplace=True),
nn.Sequential( ),
nn.Conv2d(512, 128, kernel_size=1), nn.Sequential(
nn.ReLU(inplace=True), nn.Conv2d(512, 128, kernel_size=1),
nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2 nn.ReLU(inplace=True),
nn.ReLU(inplace=True), nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2
), nn.ReLU(inplace=True),
nn.Sequential( ),
nn.Conv2d(256, 128, kernel_size=1), nn.Sequential(
nn.ReLU(inplace=True), nn.Conv2d(256, 128, kernel_size=1),
nn.Conv2d(128, 256, kernel_size=3), # conv10_2 nn.ReLU(inplace=True),
nn.ReLU(inplace=True), nn.Conv2d(128, 256, kernel_size=3), # conv10_2
), nn.ReLU(inplace=True),
nn.Sequential( ),
nn.Conv2d(256, 128, kernel_size=1), nn.Sequential(
nn.ReLU(inplace=True), nn.Conv2d(256, 128, kernel_size=1),
nn.Conv2d(128, 256, kernel_size=3), # conv11_2 nn.ReLU(inplace=True),
nn.ReLU(inplace=True), nn.Conv2d(128, 256, kernel_size=3), # conv11_2
) nn.ReLU(inplace=True),
]) ),
]
)
if highres: if highres:
# Additional layers for the SSD512 case. See page 11, footnote 5. # Additional layers for the SSD512 case. See page 11, footnote 5.
extra.append(nn.Sequential( extra.append(
nn.Conv2d(256, 128, kernel_size=1), nn.Sequential(
nn.ReLU(inplace=True), nn.Conv2d(256, 128, kernel_size=1),
nn.Conv2d(128, 256, kernel_size=4), # conv12_2 nn.ReLU(inplace=True),
nn.ReLU(inplace=True), nn.Conv2d(128, 256, kernel_size=4), # conv12_2
)) nn.ReLU(inplace=True),
)
)
_xavier_init(extra) _xavier_init(extra)
fc = nn.Sequential( fc = nn.Sequential(
...@@ -469,13 +495,16 @@ class SSDFeatureExtractorVGG(nn.Module): ...@@ -469,13 +495,16 @@ class SSDFeatureExtractorVGG(nn.Module):
nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7 nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7
nn.ReLU(inplace=True) nn.ReLU(inplace=True),
) )
_xavier_init(fc) _xavier_init(fc)
extra.insert(0, nn.Sequential( extra.insert(
*backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5 0,
fc, nn.Sequential(
)) *backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5
fc,
),
)
self.extra = extra self.extra = extra
def forward(self, x: Tensor) -> Dict[str, Tensor]: def forward(self, x: Tensor) -> Dict[str, Tensor]:
...@@ -495,7 +524,7 @@ class SSDFeatureExtractorVGG(nn.Module): ...@@ -495,7 +524,7 @@ class SSDFeatureExtractorVGG(nn.Module):
def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
if backbone_name in backbone_urls: if backbone_name in backbone_urls:
# Use custom backbones more appropriate for SSD # Use custom backbones more appropriate for SSD
arch = backbone_name.split('_')[0] arch = backbone_name.split("_")[0]
backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features
if pretrained: if pretrained:
state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress) state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress)
...@@ -519,8 +548,14 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained ...@@ -519,8 +548,14 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained
return SSDFeatureExtractorVGG(backbone, highres) return SSDFeatureExtractorVGG(backbone, highres)
def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91, def ssd300_vgg16(
pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any): pretrained: bool = False,
progress: bool = True,
num_classes: int = 91,
pretrained_backbone: bool = True,
trainable_backbone_layers: Optional[int] = None,
**kwargs: Any,
):
"""Constructs an SSD model with input size 300x300 and a VGG16 backbone. """Constructs an SSD model with input size 300x300 and a VGG16 backbone.
Reference: `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_. Reference: `"SSD: Single Shot MultiBox Detector" <https://arxiv.org/abs/1512.02325>`_.
...@@ -569,16 +604,19 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i ...@@ -569,16 +604,19 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
warnings.warn("The size of the model is already fixed; ignoring the argument.") warnings.warn("The size of the model is already fixed; ignoring the argument.")
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5) pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5
)
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers) backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers)
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], anchor_generator = DefaultBoxGenerator(
scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
steps=[8, 16, 32, 64, 100, 300]) scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
steps=[8, 16, 32, 64, 100, 300],
)
defaults = { defaults = {
# Rescale the input in a way compatible with the backbone # Rescale the input in a way compatible with the backbone
...@@ -588,7 +626,7 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i ...@@ -588,7 +626,7 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
kwargs = {**defaults, **kwargs} kwargs = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
if pretrained: if pretrained:
weights_name = 'ssd300_vgg16_coco' weights_name = "ssd300_vgg16_coco"
if model_urls.get(weights_name, None) is None: if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name)) raise ValueError("No checkpoint is available for model {}".format(weights_name))
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
......
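A brief usage sketch, not part of this diff, for the ssd300_vgg16 constructor defined above; the flags and input sizes are illustrative assumptions.

import torch
from torchvision.models.detection import ssd300_vgg16

# pretrained_backbone=False also skips the VGG16 feature download; set pretrained=True
# to load the COCO checkpoint registered in model_urls above.
model = ssd300_vgg16(pretrained=False, pretrained_backbone=False, num_classes=91)
model.eval()
images = [torch.rand(3, 300, 300), torch.rand(3, 400, 500)]
with torch.no_grad():
    detections = model(images)  # list of dicts with "boxes", "scores" and "labels" per image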
import torch
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import nn, Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
from .. import mobilenet
from . import _utils as det_utils from . import _utils as det_utils
from .ssd import SSD, SSDScoringHead
from .anchor_utils import DefaultBoxGenerator from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers from .backbone_utils import _validate_trainable_layers
from .. import mobilenet from .ssd import SSD, SSDScoringHead
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
__all__ = ['ssdlite320_mobilenet_v3_large'] __all__ = ["ssdlite320_mobilenet_v3_large"]
model_urls = { model_urls = {
'ssdlite320_mobilenet_v3_large_coco': "ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth"
'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth'
} }
# Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper # Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper
def _prediction_block(in_channels: int, out_channels: int, kernel_size: int, def _prediction_block(
norm_layer: Callable[..., nn.Module]) -> nn.Sequential: in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module]
) -> nn.Sequential:
return nn.Sequential( return nn.Sequential(
# 3x3 depthwise with stride 1 and padding 1 # 3x3 depthwise with stride 1 and padding 1
ConvNormActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, ConvNormActivation(
norm_layer=norm_layer, activation_layer=nn.ReLU6), in_channels,
in_channels,
kernel_size=kernel_size,
groups=in_channels,
norm_layer=norm_layer,
activation_layer=nn.ReLU6,
),
# 1x1 projection to output channels # 1x1 projection to output channels
nn.Conv2d(in_channels, out_channels, 1) nn.Conv2d(in_channels, out_channels, 1),
) )
...@@ -41,16 +46,23 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., ...@@ -41,16 +46,23 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[...,
intermediate_channels = out_channels // 2 intermediate_channels = out_channels // 2
return nn.Sequential( return nn.Sequential(
# 1x1 projection to half output channels # 1x1 projection to half output channels
ConvNormActivation(in_channels, intermediate_channels, kernel_size=1, ConvNormActivation(
norm_layer=norm_layer, activation_layer=activation), in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
),
# 3x3 depthwise with stride 2 and padding 1 # 3x3 depthwise with stride 2 and padding 1
ConvNormActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2, ConvNormActivation(
groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation), intermediate_channels,
intermediate_channels,
kernel_size=3,
stride=2,
groups=intermediate_channels,
norm_layer=norm_layer,
activation_layer=activation,
),
# 1x1 projection to output channels # 1x1 projection to output channels
ConvNormActivation(intermediate_channels, out_channels, kernel_size=1, ConvNormActivation(
norm_layer=norm_layer, activation_layer=activation), intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
),
) )
...@@ -63,22 +75,24 @@ def _normal_init(conv: nn.Module): ...@@ -63,22 +75,24 @@ def _normal_init(conv: nn.Module):
class SSDLiteHead(nn.Module): class SSDLiteHead(nn.Module):
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int, def __init__(
norm_layer: Callable[..., nn.Module]): self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
):
super().__init__() super().__init__()
self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer) self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer) self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)
def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
return { return {
'bbox_regression': self.regression_head(x), "bbox_regression": self.regression_head(x),
'cls_logits': self.classification_head(x), "cls_logits": self.classification_head(x),
} }
class SSDLiteClassificationHead(SSDScoringHead): class SSDLiteClassificationHead(SSDScoringHead):
def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int, def __init__(
norm_layer: Callable[..., nn.Module]): self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
):
cls_logits = nn.ModuleList() cls_logits = nn.ModuleList()
for channels, anchors in zip(in_channels, num_anchors): for channels, anchors in zip(in_channels, num_anchors):
cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer)) cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
...@@ -96,24 +110,33 @@ class SSDLiteRegressionHead(SSDScoringHead): ...@@ -96,24 +110,33 @@ class SSDLiteRegressionHead(SSDScoringHead):
class SSDLiteFeatureExtractorMobileNet(nn.Module): class SSDLiteFeatureExtractorMobileNet(nn.Module):
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], width_mult: float = 1.0, def __init__(
min_depth: int = 16, **kwargs: Any): self,
backbone: nn.Module,
c4_pos: int,
norm_layer: Callable[..., nn.Module],
width_mult: float = 1.0,
min_depth: int = 16,
**kwargs: Any,
):
super().__init__() super().__init__()
assert not backbone[c4_pos].use_res_connect assert not backbone[c4_pos].use_res_connect
self.features = nn.Sequential( self.features = nn.Sequential(
# As described in section 6.3 of MobileNetV3 paper # As described in section 6.3 of MobileNetV3 paper
nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer
nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]), # from C4 depthwise until end nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]), # from C4 depthwise until end
) )
get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731 get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731
extra = nn.ModuleList([ extra = nn.ModuleList(
_extra_block(backbone[-1].out_channels, get_depth(512), norm_layer), [
_extra_block(get_depth(512), get_depth(256), norm_layer), _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
_extra_block(get_depth(256), get_depth(256), norm_layer), _extra_block(get_depth(512), get_depth(256), norm_layer),
_extra_block(get_depth(256), get_depth(128), norm_layer), _extra_block(get_depth(256), get_depth(256), norm_layer),
]) _extra_block(get_depth(256), get_depth(128), norm_layer),
]
)
_normal_init(extra) _normal_init(extra)
self.extra = extra self.extra = extra
...@@ -132,10 +155,17 @@ class SSDLiteFeatureExtractorMobileNet(nn.Module): ...@@ -132,10 +155,17 @@ class SSDLiteFeatureExtractorMobileNet(nn.Module):
return OrderedDict([(str(i), v) for i, v in enumerate(output)]) return OrderedDict([(str(i), v) for i, v in enumerate(output)])
def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int, def _mobilenet_extractor(
norm_layer: Callable[..., nn.Module], **kwargs: Any): backbone_name: str,
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress, progress: bool,
norm_layer=norm_layer, **kwargs).features pretrained: bool,
trainable_layers: int,
norm_layer: Callable[..., nn.Module],
**kwargs: Any,
):
backbone = mobilenet.__dict__[backbone_name](
pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs
).features
if not pretrained: if not pretrained:
# Change the default initialization scheme if not pretrained # Change the default initialization scheme if not pretrained
_normal_init(backbone) _normal_init(backbone)
...@@ -156,10 +186,15 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t ...@@ -156,10 +186,15 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t
return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs) return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)
def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91, def ssdlite320_mobilenet_v3_large(
pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None, pretrained: bool = False,
norm_layer: Optional[Callable[..., nn.Module]] = None, progress: bool = True,
**kwargs: Any): num_classes: int = 91,
pretrained_backbone: bool = False,
trainable_backbone_layers: Optional[int] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any,
):
"""Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at
`"Searching for MobileNetV3" `"Searching for MobileNetV3"
<https://arxiv.org/abs/1905.02244>`_ and <https://arxiv.org/abs/1905.02244>`_ and
...@@ -188,7 +223,8 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru ...@@ -188,7 +223,8 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
warnings.warn("The size of the model is already fixed; ignoring the argument.") warnings.warn("The size of the model is already fixed; ignoring the argument.")
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6) pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6
)
if pretrained: if pretrained:
pretrained_backbone = False pretrained_backbone = False
...@@ -199,8 +235,15 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru ...@@ -199,8 +235,15 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
if norm_layer is None: if norm_layer is None:
norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers, backbone = _mobilenet_extractor(
norm_layer, reduced_tail=reduce_tail, **kwargs) "mobilenet_v3_large",
progress,
pretrained_backbone,
trainable_backbone_layers,
norm_layer,
reduced_tail=reduce_tail,
**kwargs,
)
size = (320, 320) size = (320, 320)
anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
...@@ -219,11 +262,17 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru ...@@ -219,11 +262,17 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru
"image_std": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5],
} }
kwargs = {**defaults, **kwargs} kwargs = {**defaults, **kwargs}
model = SSD(backbone, anchor_generator, size, num_classes, model = SSD(
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs) backbone,
anchor_generator,
size,
num_classes,
head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
**kwargs,
)
if pretrained: if pretrained:
weights_name = 'ssdlite320_mobilenet_v3_large_coco' weights_name = "ssdlite320_mobilenet_v3_large_coco"
if model_urls.get(weights_name, None) is None: if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name)) raise ValueError("No checkpoint is available for model {}".format(weights_name))
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
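For reference, a minimal usage sketch of the builder reformatted above (assuming the public torchvision.models.detection API; input sizes and num_classes are illustrative defaults, not part of this change):

import torch
from torchvision.models.detection import ssdlite320_mobilenet_v3_large

# Build an untrained SSDlite model; pretrained=True would load COCO weights instead.
model = ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=91)
model.eval()

# Inference takes a list of 3xHxW tensors and returns one dict per image
# with "boxes", "labels" and "scores" entries.
images = [torch.rand(3, 320, 320), torch.rand(3, 480, 640)]
with torch.no_grad():
    predictions = model(images)
print(predictions[0]["boxes"].shape)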
...
import math import math
from typing import List, Tuple, Dict, Optional
import torch import torch
import torchvision import torchvision
from torch import nn, Tensor from torch import nn, Tensor
from typing import List, Tuple, Dict, Optional
from .image_list import ImageList from .image_list import ImageList
from .roi_heads import paste_masks_in_image from .roi_heads import paste_masks_in_image
@@ -12,6 +12,7 @@ from .roi_heads import paste_masks_in_image
@torch.jit.unused @torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor: def _get_shape_onnx(image: Tensor) -> Tensor:
from torch.onnx import operators from torch.onnx import operators
return operators.shape_as_tensor(image)[-2:] return operators.shape_as_tensor(image)[-2:]
@@ -21,10 +22,13 @@ def _fake_cast_onnx(v: Tensor) -> float:
return v return v
def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size: float, def _resize_image_and_masks(
target: Optional[Dict[str, Tensor]] = None, image: Tensor,
fixed_size: Optional[Tuple[int, int]] = None, self_min_size: float,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: self_max_size: float,
target: Optional[Dict[str, Tensor]] = None,
fixed_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
if torchvision._is_tracing(): if torchvision._is_tracing():
im_shape = _get_shape_onnx(image) im_shape = _get_shape_onnx(image)
else: else:
@@ -46,16 +50,23 @@ def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size:
scale_factor = scale.item() scale_factor = scale.item()
recompute_scale_factor = True recompute_scale_factor = True
image = torch.nn.functional.interpolate(image[None], size=size, scale_factor=scale_factor, mode='bilinear', image = torch.nn.functional.interpolate(
recompute_scale_factor=recompute_scale_factor, align_corners=False)[0] image[None],
size=size,
scale_factor=scale_factor,
mode="bilinear",
recompute_scale_factor=recompute_scale_factor,
align_corners=False,
)[0]
if target is None: if target is None:
return image, target return image, target
if "masks" in target: if "masks" in target:
mask = target["masks"] mask = target["masks"]
mask = torch.nn.functional.interpolate(mask[:, None].float(), size=size, scale_factor=scale_factor, mask = torch.nn.functional.interpolate(
recompute_scale_factor=recompute_scale_factor)[:, 0].byte() mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor
)[:, 0].byte()
target["masks"] = mask target["masks"] = mask
return image, target return image, target
@@ -72,8 +83,15 @@ class GeneralizedRCNNTransform(nn.Module):
It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
""" """
def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float], def __init__(
size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None): self,
min_size: int,
max_size: int,
image_mean: List[float],
image_std: List[float],
size_divisible: int = 32,
fixed_size: Optional[Tuple[int, int]] = None,
):
super(GeneralizedRCNNTransform, self).__init__() super(GeneralizedRCNNTransform, self).__init__()
if not isinstance(min_size, (list, tuple)): if not isinstance(min_size, (list, tuple)):
min_size = (min_size,) min_size = (min_size,)
@@ -84,10 +102,9 @@ class GeneralizedRCNNTransform(nn.Module):
self.size_divisible = size_divisible self.size_divisible = size_divisible
self.fixed_size = fixed_size self.fixed_size = fixed_size
def forward(self, def forward(
images: List[Tensor], self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
targets: Optional[List[Dict[str, Tensor]]] = None ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
images = [img for img in images] images = [img for img in images]
if targets is not None: if targets is not None:
# make a copy of targets to avoid modifying it in-place # make a copy of targets to avoid modifying it in-place
@@ -106,8 +123,9 @@ class GeneralizedRCNNTransform(nn.Module):
target_index = targets[i] if targets is not None else None target_index = targets[i] if targets is not None else None
if image.dim() != 3: if image.dim() != 3:
raise ValueError("images is expected to be a list of 3d tensors " raise ValueError(
"of shape [C, H, W], got {}".format(image.shape)) "images is expected to be a list of 3d tensors " "of shape [C, H, W], got {}".format(image.shape)
)
image = self.normalize(image) image = self.normalize(image)
image, target_index = self.resize(image, target_index) image, target_index = self.resize(image, target_index)
images[i] = image images[i] = image
@@ -141,13 +159,14 @@ class GeneralizedRCNNTransform(nn.Module):
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
is fixed. is fixed.
""" """
index = int(torch.empty(1).uniform_(0., float(len(k))).item()) index = int(torch.empty(1).uniform_(0.0, float(len(k))).item())
return k[index] return k[index]
def resize(self, def resize(
image: Tensor, self,
target: Optional[Dict[str, Tensor]] = None, image: Tensor,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: target: Optional[Dict[str, Tensor]] = None,
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
h, w = image.shape[-2:] h, w = image.shape[-2:]
if self.training: if self.training:
size = float(self.torch_choice(self.min_size)) size = float(self.torch_choice(self.min_size))
@@ -220,11 +239,12 @@ class GeneralizedRCNNTransform(nn.Module):
return batched_imgs return batched_imgs
def postprocess(self, def postprocess(
result: List[Dict[str, Tensor]], self,
image_shapes: List[Tuple[int, int]], result: List[Dict[str, Tensor]],
original_image_sizes: List[Tuple[int, int]] image_shapes: List[Tuple[int, int]],
) -> List[Dict[str, Tensor]]: original_image_sizes: List[Tuple[int, int]],
) -> List[Dict[str, Tensor]]:
if self.training: if self.training:
return result return result
for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
@@ -242,19 +262,20 @@ class GeneralizedRCNNTransform(nn.Module):
return result return result
def __repr__(self) -> str: def __repr__(self) -> str:
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + "("
_indent = '\n ' _indent = "\n "
format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std) format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size, format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(
self.max_size) _indent, self.min_size, self.max_size
format_string += '\n)' )
format_string += "\n)"
return format_string return format_string
def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor: def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
ratios = [ ratios = [
torch.tensor(s, dtype=torch.float32, device=keypoints.device) / torch.tensor(s, dtype=torch.float32, device=keypoints.device)
torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
for s, s_orig in zip(new_size, original_size) for s, s_orig in zip(new_size, original_size)
] ]
ratio_h, ratio_w = ratios ratio_h, ratio_w = ratios
@@ -271,8 +292,8 @@ def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List
def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor: def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
ratios = [ ratios = [
torch.tensor(s, dtype=torch.float32, device=boxes.device) / torch.tensor(s, dtype=torch.float32, device=boxes.device)
torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
for s, s_orig in zip(new_size, original_size) for s, s_orig in zip(new_size, original_size)
] ]
ratio_height, ratio_width = ratios ratio_height, ratio_width = ratios
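A minimal sketch of how the transform shown above is used on its own (this assumes the internal torchvision.models.detection.transform module path; the mean/std values are only an example):

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(
    min_size=800,
    max_size=1333,
    image_mean=[0.485, 0.456, 0.406],
    image_std=[0.229, 0.224, 0.225],
)
images = [torch.rand(3, 480, 640), torch.rand(3, 375, 500)]
image_list, _ = transform(images)  # targets are optional
# Images are normalized, resized, and batched into one padded tensor.
print(image_list.tensors.shape, image_list.image_sizes)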
...
import copy import copy
import math import math
import torch
from functools import partial from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, List, Optional, Sequence from typing import Any, Callable, List, Optional, Sequence
import torch
from torch import nn, Tensor
from torchvision.ops import StochasticDepth
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation from ..ops.misc import ConvNormActivation, SqueezeExcitation
from ._utils import _make_divisible from ._utils import _make_divisible
from torchvision.ops import StochasticDepth
__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3", __all__ = [
"efficientnet_b4", "efficientnet_b5", "efficientnet_b6", "efficientnet_b7"] "EfficientNet",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
"efficientnet_b3",
"efficientnet_b4",
"efficientnet_b5",
"efficientnet_b6",
"efficientnet_b7",
]
model_urls = { model_urls = {
@@ -32,10 +41,17 @@ model_urls = {
class MBConvConfig: class MBConvConfig:
# Stores information listed at Table 1 of the EfficientNet paper # Stores information listed at Table 1 of the EfficientNet paper
def __init__(self, def __init__(
expand_ratio: float, kernel: int, stride: int, self,
input_channels: int, out_channels: int, num_layers: int, expand_ratio: float,
width_mult: float, depth_mult: float) -> None: kernel: int,
stride: int,
input_channels: int,
out_channels: int,
num_layers: int,
width_mult: float,
depth_mult: float,
) -> None:
self.expand_ratio = expand_ratio self.expand_ratio = expand_ratio
self.kernel = kernel self.kernel = kernel
self.stride = stride self.stride = stride
@@ -44,14 +60,14 @@ class MBConvConfig:
self.num_layers = self.adjust_depth(num_layers, depth_mult) self.num_layers = self.adjust_depth(num_layers, depth_mult)
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + '(' s = self.__class__.__name__ + "("
s += 'expand_ratio={expand_ratio}' s += "expand_ratio={expand_ratio}"
s += ', kernel={kernel}' s += ", kernel={kernel}"
s += ', stride={stride}' s += ", stride={stride}"
s += ', input_channels={input_channels}' s += ", input_channels={input_channels}"
s += ', out_channels={out_channels}' s += ", out_channels={out_channels}"
s += ', num_layers={num_layers}' s += ", num_layers={num_layers}"
s += ')' s += ")"
return s.format(**self.__dict__) return s.format(**self.__dict__)
@staticmethod @staticmethod
@@ -64,12 +80,17 @@
class MBConv(nn.Module): class MBConv(nn.Module):
def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module], def __init__(
se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None: self,
cnf: MBConvConfig,
stochastic_depth_prob: float,
norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = SqueezeExcitation,
) -> None:
super().__init__() super().__init__()
if not (1 <= cnf.stride <= 2): if not (1 <= cnf.stride <= 2):
raise ValueError('illegal stride value') raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
@@ -79,21 +100,39 @@ class MBConv(nn.Module):
# expand # expand
expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
if expanded_channels != cnf.input_channels: if expanded_channels != cnf.input_channels:
layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1, layers.append(
norm_layer=norm_layer, activation_layer=activation_layer)) ConvNormActivation(
cnf.input_channels,
expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# depthwise # depthwise
layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel, layers.append(
stride=cnf.stride, groups=expanded_channels, ConvNormActivation(
norm_layer=norm_layer, activation_layer=activation_layer)) expanded_channels,
expanded_channels,
kernel_size=cnf.kernel,
stride=cnf.stride,
groups=expanded_channels,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# squeeze and excitation # squeeze and excitation
squeeze_channels = max(1, cnf.input_channels // 4) squeeze_channels = max(1, cnf.input_channels // 4)
layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True))) layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
# project # project
layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, layers.append(
activation_layer=None)) ConvNormActivation(
expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
self.block = nn.Sequential(*layers) self.block = nn.Sequential(*layers)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
@@ -109,14 +148,14 @@
class EfficientNet(nn.Module): class EfficientNet(nn.Module):
def __init__( def __init__(
self, self,
inverted_residual_setting: List[MBConvConfig], inverted_residual_setting: List[MBConvConfig],
dropout: float, dropout: float,
stochastic_depth_prob: float = 0.2, stochastic_depth_prob: float = 0.2,
num_classes: int = 1000, num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
""" """
EfficientNet main class EfficientNet main class
@@ -133,8 +172,10 @@ class EfficientNet(nn.Module):
if not inverted_residual_setting: if not inverted_residual_setting:
raise ValueError("The inverted_residual_setting should not be empty") raise ValueError("The inverted_residual_setting should not be empty")
elif not (isinstance(inverted_residual_setting, Sequence) and elif not (
all([isinstance(s, MBConvConfig) for s in inverted_residual_setting])): isinstance(inverted_residual_setting, Sequence)
and all([isinstance(s, MBConvConfig) for s in inverted_residual_setting])
):
raise TypeError("The inverted_residual_setting should be List[MBConvConfig]") raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
if block is None: if block is None:
@@ -147,8 +188,11 @@ class EfficientNet(nn.Module):
# building first layer # building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, layers.append(
activation_layer=nn.SiLU)) ConvNormActivation(
3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
)
)
# building inverted residual blocks # building inverted residual blocks
total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting]) total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting])
@@ -175,8 +219,15 @@ class EfficientNet(nn.Module):
# building last several layers # building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 4 * lastconv_input_channels lastconv_output_channels = 4 * lastconv_input_channels
layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, layers.append(
norm_layer=norm_layer, activation_layer=nn.SiLU)) ConvNormActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.SiLU,
)
)
self.features = nn.Sequential(*layers) self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1) self.avgpool = nn.AdaptiveAvgPool2d(1)
@@ -187,7 +238,7 @@ class EfficientNet(nn.Module):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out') nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None: if m.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
@@ -232,7 +283,7 @@ def _efficientnet_model(
dropout: float, dropout: float,
pretrained: bool, pretrained: bool,
progress: bool, progress: bool,
**kwargs: Any **kwargs: Any,
) -> EfficientNet: ) -> EfficientNet:
model = EfficientNet(inverted_residual_setting, dropout, **kwargs) model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
if pretrained: if pretrained:
@@ -318,8 +369,15 @@ def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: A
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs) inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs)
return _efficientnet_model("efficientnet_b5", inverted_residual_setting, 0.4, pretrained, progress, return _efficientnet_model(
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) "efficientnet_b5",
inverted_residual_setting,
0.4,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -332,8 +390,15 @@ def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: A
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs) inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs)
return _efficientnet_model("efficientnet_b6", inverted_residual_setting, 0.5, pretrained, progress, return _efficientnet_model(
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) "efficientnet_b6",
inverted_residual_setting,
0.5,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -346,5 +411,12 @@ def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: A
progress (bool): If True, displays a progress bar of the download to stderr progress (bool): If True, displays a progress bar of the download to stderr
""" """
inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs) inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs)
return _efficientnet_model("efficientnet_b7", inverted_residual_setting, 0.5, pretrained, progress, return _efficientnet_model(
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) "efficientnet_b7",
inverted_residual_setting,
0.5,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
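A quick sanity-check sketch for the EfficientNet builders reformatted above (efficientnet_b0 is used only because it is the smallest variant; the b1-b7 builders work the same way):

import torch
from torchvision.models import efficientnet_b0

model = efficientnet_b0(pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # (1, 1000) with the default num_classes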
from typing import Dict, Callable, List, Union, Optional, Tuple
from collections import OrderedDict
import warnings
import re import re
import warnings
from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from itertools import chain from itertools import chain
from typing import Dict, Callable, List, Union, Optional, Tuple
import torch import torch
from torch import nn
from torch import fx from torch import fx
from torch import nn
from torch.fx.graph_module import _copy_attr from torch.fx.graph_module import _copy_attr
__all__ = ['create_feature_extractor', 'get_graph_node_names'] __all__ = ["create_feature_extractor", "get_graph_node_names"]
class LeafModuleAwareTracer(fx.Tracer): class LeafModuleAwareTracer(fx.Tracer):
@@ -20,10 +20,11 @@ class LeafModuleAwareTracer(fx.Tracer):
modules that are not to be traced through. The resulting graph ends up modules that are not to be traced through. The resulting graph ends up
having single nodes referencing calls to the leaf modules' forward methods. having single nodes referencing calls to the leaf modules' forward methods.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.leaf_modules = {} self.leaf_modules = {}
if 'leaf_modules' in kwargs: if "leaf_modules" in kwargs:
leaf_modules = kwargs.pop('leaf_modules') leaf_modules = kwargs.pop("leaf_modules")
self.leaf_modules = leaf_modules self.leaf_modules = leaf_modules
super(LeafModuleAwareTracer, self).__init__(*args, **kwargs) super(LeafModuleAwareTracer, self).__init__(*args, **kwargs)
@@ -51,10 +52,11 @@ class NodePathTracer(LeafModuleAwareTracer):
- When a duplicate node name is encountered, a suffix of the form - When a duplicate node name is encountered, a suffix of the form
_{int} is added. The counter starts from 1. _{int} is added. The counter starts from 1.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(NodePathTracer, self).__init__(*args, **kwargs) super(NodePathTracer, self).__init__(*args, **kwargs)
# Track the qualified name of the Node being traced # Track the qualified name of the Node being traced
self.current_module_qualname = '' self.current_module_qualname = ""
# A map from FX Node to the qualified name\# # A map from FX Node to the qualified name\#
# NOTE: This is loosely like the "qualified name" mentioned in the # NOTE: This is loosely like the "qualified name" mentioned in the
# torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted
@@ -78,32 +80,31 @@ class NodePathTracer(LeafModuleAwareTracer):
if not self.is_leaf_module(m, module_qualname): if not self.is_leaf_module(m, module_qualname):
out = forward(*args, **kwargs) out = forward(*args, **kwargs)
return out return out
return self.create_proxy('call_module', module_qualname, args, kwargs) return self.create_proxy("call_module", module_qualname, args, kwargs)
finally: finally:
self.current_module_qualname = old_qualname self.current_module_qualname = old_qualname
def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, def create_proxy(
name=None, type_expr=None, *_) -> fx.proxy.Proxy: self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None, *_
) -> fx.proxy.Proxy:
""" """
Override of `Tracer.create_proxy`. This override intercepts the recording Override of `Tracer.create_proxy`. This override intercepts the recording
of every operation and stores away the current traced module's qualified of every operation and stores away the current traced module's qualified
name in `node_to_qualname` name in `node_to_qualname`
""" """
proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
self.node_to_qualname[proxy.node] = self._get_node_qualname( self.node_to_qualname[proxy.node] = self._get_node_qualname(self.current_module_qualname, proxy.node)
self.current_module_qualname, proxy.node)
return proxy return proxy
def _get_node_qualname( def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str:
self, module_qualname: str, node: fx.node.Node) -> str:
node_qualname = module_qualname node_qualname = module_qualname
if node.op != 'call_module': if node.op != "call_module":
# In this case module_qualname from torch.fx doesn't go all the # In this case module_qualname from torch.fx doesn't go all the
# way to the leaf function/op so we need to append it # way to the leaf function/op so we need to append it
if len(node_qualname) > 0: if len(node_qualname) > 0:
# Only append '.' if we are deeper than the top level module # Only append '.' if we are deeper than the top level module
node_qualname += '.' node_qualname += "."
node_qualname += str(node) node_qualname += str(node)
# Now we need to add an _{index} postfix on any repeated node names # Now we need to add an _{index} postfix on any repeated node names
@@ -111,23 +112,22 @@ class NodePathTracer(LeafModuleAwareTracer):
# But for anything else, torch.fx already has a globally scoped # But for anything else, torch.fx already has a globally scoped
# _{index} postfix. But we want it locally (relative to direct parent) # _{index} postfix. But we want it locally (relative to direct parent)
# scoped. So first we need to undo the torch.fx postfix # scoped. So first we need to undo the torch.fx postfix
if re.match(r'.+_[0-9]+$', node_qualname) is not None: if re.match(r".+_[0-9]+$", node_qualname) is not None:
node_qualname = node_qualname.rsplit('_', 1)[0] node_qualname = node_qualname.rsplit("_", 1)[0]
# ... and now we add on our own postfix # ... and now we add on our own postfix
for existing_qualname in reversed(self.node_to_qualname.values()): for existing_qualname in reversed(self.node_to_qualname.values()):
# Check to see if existing_qualname is of the form # Check to see if existing_qualname is of the form
# {node_qualname} or {node_qualname}_{int} # {node_qualname} or {node_qualname}_{int}
if re.match(rf'{node_qualname}(_[0-9]+)?$', if re.match(rf"{node_qualname}(_[0-9]+)?$", existing_qualname) is not None:
existing_qualname) is not None: postfix = existing_qualname.replace(node_qualname, "")
postfix = existing_qualname.replace(node_qualname, '')
if len(postfix): if len(postfix):
# existing_qualname is of the form {node_qualname}_{int} # existing_qualname is of the form {node_qualname}_{int}
next_index = int(postfix[1:]) + 1 next_index = int(postfix[1:]) + 1
else: else:
# existing_qualname is of the form {node_qualname} # existing_qualname is of the form {node_qualname}
next_index = 1 next_index = 1
node_qualname += f'_{next_index}' node_qualname += f"_{next_index}"
break break
return node_qualname return node_qualname
@@ -141,8 +141,7 @@ def _is_subseq(x, y):
return all(any(x_item == y_item for x_item in iter_x) for y_item in y) return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
def _warn_graph_differences( def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
""" """
Utility function for warning the user if there are differences between Utility function for warning the user if there are differences between
the train graph nodes and the eval graph nodes. the train graph nodes and the eval graph nodes.
@@ -150,29 +149,32 @@ def _warn_graph_differences(
train_nodes = list(train_tracer.node_to_qualname.values()) train_nodes = list(train_tracer.node_to_qualname.values())
eval_nodes = list(eval_tracer.node_to_qualname.values()) eval_nodes = list(eval_tracer.node_to_qualname.values())
if len(train_nodes) == len(eval_nodes) and all( if len(train_nodes) == len(eval_nodes) and all(t == e for t, e in zip(train_nodes, eval_nodes)):
t == e for t, e in zip(train_nodes, eval_nodes)):
return return
suggestion_msg = ( suggestion_msg = (
"When choosing nodes for feature extraction, you may need to specify " "When choosing nodes for feature extraction, you may need to specify "
"output nodes for train and eval mode separately.") "output nodes for train and eval mode separately."
)
if _is_subseq(train_nodes, eval_nodes): if _is_subseq(train_nodes, eval_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in eval mode " msg = (
"are a subsequence of those obtained in train mode. ") "NOTE: The nodes obtained by tracing the model in eval mode "
"are a subsequence of those obtained in train mode. "
)
elif _is_subseq(eval_nodes, train_nodes): elif _is_subseq(eval_nodes, train_nodes):
msg = ("NOTE: The nodes obtained by tracing the model in train mode " msg = (
"are a subsequence of those obtained in eval mode. ") "NOTE: The nodes obtained by tracing the model in train mode "
"are a subsequence of those obtained in eval mode. "
)
else: else:
msg = ("The nodes obtained by tracing the model in train mode " msg = "The nodes obtained by tracing the model in train mode " "are different to those obtained in eval mode. "
"are different to those obtained in eval mode. ")
warnings.warn(msg + suggestion_msg) warnings.warn(msg + suggestion_msg)
def get_graph_node_names( def get_graph_node_names(
model: nn.Module, tracer_kwargs: Dict = {}, model: nn.Module, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False
suppress_diff_warning: bool = False) -> Tuple[List[str], List[str]]: ) -> Tuple[List[str], List[str]]:
""" """
Dev utility to return node names in order of execution. See note on node Dev utility to return node names in order of execution. See note on node
names under :func:`create_feature_extractor`. Useful for seeing which node names under :func:`create_feature_extractor`. Useful for seeing which node
@@ -230,11 +232,10 @@ class DualGraphModule(fx.GraphModule):
- Copies submodules according to the nodes of both train and eval graphs. - Copies submodules according to the nodes of both train and eval graphs.
- Calling train(mode) switches between train graph and eval graph. - Calling train(mode) switches between train graph and eval graph.
""" """
def __init__(self,
root: torch.nn.Module, def __init__(
train_graph: fx.Graph, self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
eval_graph: fx.Graph, ):
class_name: str = 'GraphModule'):
""" """
Args: Args:
root (nn.Module): module from which the copied module hierarchy is root (nn.Module): module from which the copied module hierarchy is
@@ -252,7 +253,7 @@ class DualGraphModule(fx.GraphModule):
# Copy all get_attr and call_module ops (indicated by BOTH train and # Copy all get_attr and call_module ops (indicated by BOTH train and
# eval graphs) # eval graphs)
for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
if node.op in ['get_attr', 'call_module']: if node.op in ["get_attr", "call_module"]:
assert isinstance(node.target, str) assert isinstance(node.target, str)
_copy_attr(root, self, node.target) _copy_attr(root, self, node.target)
@@ -266,10 +267,11 @@ class DualGraphModule(fx.GraphModule):
# Locally defined Tracers are not pickleable. This is needed because torch.package will # Locally defined Tracers are not pickleable. This is needed because torch.package will
# serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
# to re-create the Graph during deserialization. # to re-create the Graph during deserialization.
assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \ assert (
"Train mode and eval mode should use the same tracer class" self.eval_graph._tracer_cls == self.train_graph._tracer_cls
), "Train mode and eval mode should use the same tracer class"
self._tracer_cls = None self._tracer_cls = None
if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__: if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
self._tracer_cls = self.graph._tracer_cls self._tracer_cls = self.graph._tracer_cls
def train(self, mode=True): def train(self, mode=True):
@@ -288,12 +290,13 @@ class DualGraphModule(fx.GraphModule):
def create_feature_extractor( def create_feature_extractor(
model: nn.Module, model: nn.Module,
return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Dict = {}, tracer_kwargs: Dict = {},
suppress_diff_warning: bool = False) -> fx.GraphModule: suppress_diff_warning: bool = False,
) -> fx.GraphModule:
""" """
Creates a new graph module that returns intermediate nodes from a given Creates a new graph module that returns intermediate nodes from a given
model as dictionary with user specified keys as strings, and the requested model as dictionary with user specified keys as strings, and the requested
@@ -396,18 +399,17 @@ def create_feature_extractor(
""" """
is_training = model.training is_training = model.training
assert any(arg is not None for arg in [ assert any(arg is not None for arg in [return_nodes, train_return_nodes, eval_return_nodes]), (
return_nodes, train_return_nodes, eval_return_nodes]), ( "Either `return_nodes` or `train_return_nodes` and " "`eval_return_nodes` together, should be specified"
"Either `return_nodes` or `train_return_nodes` and " )
"`eval_return_nodes` together, should be specified")
assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), (
("If any of `train_return_nodes` and `eval_return_nodes` are " "If any of `train_return_nodes` and `eval_return_nodes` are " "specified, then both should be specified"
"specified, then both should be specified") )
assert ((return_nodes is None) ^ (train_return_nodes is None)), \ assert (return_nodes is None) ^ (train_return_nodes is None), (
("If `train_return_nodes` and `eval_return_nodes` are specified, " "If `train_return_nodes` and `eval_return_nodes` are specified, " "then both should be specified"
"then both should be specified") )
# Put *_return_nodes into Dict[str, str] format # Put *_return_nodes into Dict[str, str] format
def to_strdict(n) -> Dict[str, str]: def to_strdict(n) -> Dict[str, str]:
@@ -426,45 +428,42 @@ def create_feature_extractor(
# Repeat the tracing and graph rewriting for train and eval mode # Repeat the tracing and graph rewriting for train and eval mode
tracers = {} tracers = {}
graphs = {} graphs = {}
mode_return_nodes: Dict[str, Dict[str, str]] = { mode_return_nodes: Dict[str, Dict[str, str]] = {"train": train_return_nodes, "eval": eval_return_nodes}
'train': train_return_nodes, for mode in ["train", "eval"]:
'eval': eval_return_nodes if mode == "train":
}
for mode in ['train', 'eval']:
if mode == 'train':
model.train() model.train()
elif mode == 'eval': elif mode == "eval":
model.eval() model.eval()
# Instantiate our NodePathTracer and use that to trace the model # Instantiate our NodePathTracer and use that to trace the model
tracer = NodePathTracer(**tracer_kwargs) tracer = NodePathTracer(**tracer_kwargs)
graph = tracer.trace(model) graph = tracer.trace(model)
name = model.__class__.__name__ if isinstance( name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
model, nn.Module) else model.__name__
graph_module = fx.GraphModule(tracer.root, graph, name) graph_module = fx.GraphModule(tracer.root, graph, name)
available_nodes = list(tracer.node_to_qualname.values()) available_nodes = list(tracer.node_to_qualname.values())
# FIXME We don't know if we should expect this to happen # FIXME We don't know if we should expect this to happen
assert len(set(available_nodes)) == len(available_nodes), \ assert len(set(available_nodes)) == len(
"There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" available_nodes
), "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
# Check that all outputs in return_nodes are present in the model # Check that all outputs in return_nodes are present in the model
for query in mode_return_nodes[mode].keys(): for query in mode_return_nodes[mode].keys():
# To check if a query is available we need to check that at least # To check if a query is available we need to check that at least
# one of the available names starts with it up to a . # one of the available names starts with it up to a .
if not any([re.match(rf'^{query}(\.|$)', n) is not None if not any([re.match(rf"^{query}(\.|$)", n) is not None for n in available_nodes]):
for n in available_nodes]):
raise ValueError( raise ValueError(
f"node: '{query}' is not present in model. Hint: use " f"node: '{query}' is not present in model. Hint: use "
"`get_graph_node_names` to make sure the " "`get_graph_node_names` to make sure the "
"`return_nodes` you specified are present. It may even " "`return_nodes` you specified are present. It may even "
"be that you need to specify `train_return_nodes` and " "be that you need to specify `train_return_nodes` and "
"`eval_return_nodes` separately.") "`eval_return_nodes` separately."
)
# Remove existing output nodes (train mode) # Remove existing output nodes (train mode)
orig_output_nodes = [] orig_output_nodes = []
for n in reversed(graph_module.graph.nodes): for n in reversed(graph_module.graph.nodes):
if n.op == 'output': if n.op == "output":
orig_output_nodes.append(n) orig_output_nodes.append(n)
assert len(orig_output_nodes) assert len(orig_output_nodes)
for n in orig_output_nodes: for n in orig_output_nodes:
@@ -482,8 +481,8 @@ def create_feature_extractor(
# - When packing outputs into a named tuple like in InceptionV3 # - When packing outputs into a named tuple like in InceptionV3
continue continue
for query in mode_return_nodes[mode]: for query in mode_return_nodes[mode]:
depth = query.count('.') depth = query.count(".")
if '.'.join(module_qualname.split('.')[:depth + 1]) == query: if ".".join(module_qualname.split(".")[: depth + 1]) == query:
output_nodes[mode_return_nodes[mode][query]] = n output_nodes[mode_return_nodes[mode][query]] = n
mode_return_nodes[mode].pop(query) mode_return_nodes[mode].pop(query)
break break
@@ -504,11 +503,10 @@ def create_feature_extractor(
# Warn user if there are any discrepancies between the graphs of the # Warn user if there are any discrepancies between the graphs of the
# train and eval modes # train and eval modes
if not suppress_diff_warning: if not suppress_diff_warning:
_warn_graph_differences(tracers['train'], tracers['eval']) _warn_graph_differences(tracers["train"], tracers["eval"])
# Build the final graph module # Build the final graph module
graph_module = DualGraphModule( graph_module = DualGraphModule(model, graphs["train"], graphs["eval"], class_name=name)
model, graphs['train'], graphs['eval'], class_name=name)
# Restore original training mode # Restore original training mode
model.train(is_training) model.train(is_training)
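A minimal sketch of the two public entry points of this module, assuming a torchvision classification model such as resnet50 as the input (node and output names below are arbitrary examples):

import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

model = resnet50()
# Node names can differ between train and eval mode, hence two lists.
train_nodes, eval_nodes = get_graph_node_names(model)

# Map qualified node names to user-chosen output keys.
extractor = create_feature_extractor(model, return_nodes={"layer1": "feat1", "layer4": "feat4"})
out = extractor(torch.rand(1, 3, 224, 224))
print({name: feat.shape for name, feat in out.items()})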
...
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from typing import Optional, Tuple, List, Callable, Any
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from typing import Optional, Tuple, List, Callable, Any
__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"] __all__ = ["GoogLeNet", "googlenet", "GoogLeNetOutputs", "_GoogLeNetOutputs"]
model_urls = { model_urls = {
# GoogLeNet ported from TensorFlow # GoogLeNet ported from TensorFlow
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
} }
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor], GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]}
'aux_logits1': Optional[Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ... # Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat # _GoogLeNetOutputs set here for backwards compat
@@ -37,19 +38,19 @@ def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
was trained on ImageNet. Default: *False* was trained on ImageNet. Default: *False*
""" """
if pretrained: if pretrained:
if 'transform_input' not in kwargs: if "transform_input" not in kwargs:
kwargs['transform_input'] = True kwargs["transform_input"] = True
if 'aux_logits' not in kwargs: if "aux_logits" not in kwargs:
kwargs['aux_logits'] = False kwargs["aux_logits"] = False
if kwargs['aux_logits']: if kwargs["aux_logits"]:
warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ' warnings.warn(
'so make sure to train them') "auxiliary heads in the pretrained googlenet model are NOT pretrained, " "so make sure to train them"
original_aux_logits = kwargs['aux_logits'] )
kwargs['aux_logits'] = True original_aux_logits = kwargs["aux_logits"]
kwargs['init_weights'] = False kwargs["aux_logits"] = True
kwargs["init_weights"] = False
model = GoogLeNet(**kwargs) model = GoogLeNet(**kwargs)
state_dict = load_state_dict_from_url(model_urls['googlenet'], state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
@@ -61,7 +62,7 @@ def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
class GoogLeNet(nn.Module): class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input'] __constants__ = ["aux_logits", "transform_input"]
def __init__( def __init__(
self, self,
@@ -69,15 +70,18 @@ class GoogLeNet(nn.Module):
aux_logits: bool = True, aux_logits: bool = True,
transform_input: bool = False, transform_input: bool = False,
init_weights: Optional[bool] = None, init_weights: Optional[bool] = None,
blocks: Optional[List[Callable[..., nn.Module]]] = None blocks: Optional[List[Callable[..., nn.Module]]] = None,
) -> None: ) -> None:
super(GoogLeNet, self).__init__() super(GoogLeNet, self).__init__()
if blocks is None: if blocks is None:
blocks = [BasicConv2d, Inception, InceptionAux] blocks = [BasicConv2d, Inception, InceptionAux]
if init_weights is None: if init_weights is None:
warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of ' warnings.warn(
'torchvision. If you wish to keep the old behavior (which leads to long initialization times' "The default weight initialization of GoogleNet will be changed in future releases of "
' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning) "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
" due to scipy/scipy#11299), please set init_weights=True.",
FutureWarning,
)
init_weights = True init_weights = True
assert len(blocks) == 3 assert len(blocks) == 3
conv_block = blocks[0] conv_block = blocks[0]
@@ -197,7 +201,7 @@ class GoogLeNet(nn.Module):
if self.training and self.aux_logits: if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1) return _GoogLeNetOutputs(x, aux2, aux1)
else: else:
return x # type: ignore[return-value] return x # type: ignore[return-value]
def forward(self, x: Tensor) -> GoogLeNetOutputs: def forward(self, x: Tensor) -> GoogLeNetOutputs:
x = self._transform_input(x) x = self._transform_input(x)
@@ -212,7 +216,6 @@ class GoogLeNet(nn.Module):
class Inception(nn.Module): class Inception(nn.Module):
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
@@ -222,7 +225,7 @@ class Inception(nn.Module):
ch5x5red: int, ch5x5red: int,
ch5x5: int, ch5x5: int,
pool_proj: int, pool_proj: int,
conv_block: Optional[Callable[..., nn.Module]] = None conv_block: Optional[Callable[..., nn.Module]] = None,
) -> None: ) -> None:
super(Inception, self).__init__() super(Inception, self).__init__()
if conv_block is None: if conv_block is None:
@@ -230,20 +233,19 @@ class Inception(nn.Module):
self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1) self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
self.branch2 = nn.Sequential( self.branch2 = nn.Sequential(
conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
) )
self.branch3 = nn.Sequential( self.branch3 = nn.Sequential(
conv_block(in_channels, ch5x5red, kernel_size=1), conv_block(in_channels, ch5x5red, kernel_size=1),
# Here, kernel_size=3 instead of kernel_size=5 is a known bug. # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
# Please see https://github.com/pytorch/vision/issues/906 for details. # Please see https://github.com/pytorch/vision/issues/906 for details.
conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1) conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
) )
self.branch4 = nn.Sequential( self.branch4 = nn.Sequential(
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
conv_block(in_channels, pool_proj, kernel_size=1) conv_block(in_channels, pool_proj, kernel_size=1),
) )
def _forward(self, x: Tensor) -> List[Tensor]: def _forward(self, x: Tensor) -> List[Tensor]:
@@ -261,12 +263,8 @@ class Inception(nn.Module):
class InceptionAux(nn.Module): class InceptionAux(nn.Module):
def __init__( def __init__(
self, self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
in_channels: int,
num_classes: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None: ) -> None:
super(InceptionAux, self).__init__() super(InceptionAux, self).__init__()
if conv_block is None: if conv_block is None:
@@ -295,13 +293,7 @@ class InceptionAux(nn.Module):
class BasicConv2d(nn.Module): class BasicConv2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
def __init__(
self,
in_channels: int,
out_channels: int,
**kwargs: Any
) -> None:
super(BasicConv2d, self).__init__() super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001) self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
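A minimal sketch of the googlenet builder (no pretrained weights, so nothing is downloaded; init_weights=True is passed only to keep the default-initialization behavior explicit, and in eval mode only the main logits are returned):

import torch
from torchvision.models import googlenet

model = googlenet(init_weights=True)
model.eval()
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # (1, 1000)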
...
from collections import namedtuple
import warnings import warnings
from collections import namedtuple
from typing import Callable, Any, Optional, Tuple, List
import torch import torch
from torch import nn, Tensor
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, Tensor
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from typing import Callable, Any, Optional, Tuple, List
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs'] __all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"]
model_urls = { model_urls = {
# Inception v3 ported from TensorFlow # Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth', "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
} }
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]} InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}
# Script annotations failed with _GoogleNetOutputs = namedtuple ... # Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat # _InceptionOutputs set here for backwards compat
@@ -41,17 +43,16 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any)
was trained on ImageNet. Default: *False* was trained on ImageNet. Default: *False*
""" """
if pretrained: if pretrained:
if 'transform_input' not in kwargs: if "transform_input" not in kwargs:
kwargs['transform_input'] = True kwargs["transform_input"] = True
if 'aux_logits' in kwargs: if "aux_logits" in kwargs:
original_aux_logits = kwargs['aux_logits'] original_aux_logits = kwargs["aux_logits"]
kwargs['aux_logits'] = True kwargs["aux_logits"] = True
else: else:
original_aux_logits = True original_aux_logits = True
kwargs['init_weights'] = False # we are loading weights from a pretrained model kwargs["init_weights"] = False # we are loading weights from a pretrained model
model = Inception3(**kwargs) model = Inception3(**kwargs)
state_dict = load_state_dict_from_url(model_urls['inception_v3_google'], state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
@@ -62,25 +63,24 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any)
class Inception3(nn.Module): class Inception3(nn.Module):
def __init__( def __init__(
self, self,
num_classes: int = 1000, num_classes: int = 1000,
aux_logits: bool = True, aux_logits: bool = True,
transform_input: bool = False, transform_input: bool = False,
inception_blocks: Optional[List[Callable[..., nn.Module]]] = None, inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
init_weights: Optional[bool] = None init_weights: Optional[bool] = None,
) -> None: ) -> None:
super(Inception3, self).__init__() super(Inception3, self).__init__()
if inception_blocks is None: if inception_blocks is None:
inception_blocks = [ inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
BasicConv2d, InceptionA, InceptionB, InceptionC,
InceptionD, InceptionE, InceptionAux
]
if init_weights is None: if init_weights is None:
warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of ' warnings.warn(
'torchvision. If you wish to keep the old behavior (which leads to long initialization times' "The default weight initialization of inception_v3 will be changed in future releases of "
' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning) "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
" due to scipy/scipy#11299), please set init_weights=True.",
FutureWarning,
)
init_weights = True init_weights = True
assert len(inception_blocks) == 7 assert len(inception_blocks) == 7
conv_block = inception_blocks[0] conv_block = inception_blocks[0]
@@ -120,7 +120,7 @@ class Inception3(nn.Module):
if init_weights: if init_weights:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
stddev = float(m.stddev) if hasattr(m, 'stddev') else 0.1 # type: ignore stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore
torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2) torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
@@ -208,12 +208,8 @@ class Inception3(nn.Module):
class InceptionA(nn.Module): class InceptionA(nn.Module):
def __init__( def __init__(
self, self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
in_channels: int,
pool_features: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None: ) -> None:
super(InceptionA, self).__init__() super(InceptionA, self).__init__()
if conv_block is None: if conv_block is None:
@@ -251,12 +247,7 @@ class InceptionA(nn.Module):
class InceptionB(nn.Module): class InceptionB(nn.Module):
def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionB, self).__init__() super(InceptionB, self).__init__()
if conv_block is None: if conv_block is None:
conv_block = BasicConv2d conv_block = BasicConv2d
...@@ -284,12 +275,8 @@ class InceptionB(nn.Module): ...@@ -284,12 +275,8 @@ class InceptionB(nn.Module):
class InceptionC(nn.Module): class InceptionC(nn.Module):
def __init__( def __init__(
self, self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
in_channels: int,
channels_7x7: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None: ) -> None:
super(InceptionC, self).__init__() super(InceptionC, self).__init__()
if conv_block is None: if conv_block is None:
...@@ -334,12 +321,7 @@ class InceptionC(nn.Module): ...@@ -334,12 +321,7 @@ class InceptionC(nn.Module):
class InceptionD(nn.Module): class InceptionD(nn.Module):
def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionD, self).__init__() super(InceptionD, self).__init__()
if conv_block is None: if conv_block is None:
conv_block = BasicConv2d conv_block = BasicConv2d
...@@ -370,12 +352,7 @@ class InceptionD(nn.Module): ...@@ -370,12 +352,7 @@ class InceptionD(nn.Module):
class InceptionE(nn.Module): class InceptionE(nn.Module):
def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
def __init__(
self,
in_channels: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None:
super(InceptionE, self).__init__() super(InceptionE, self).__init__()
if conv_block is None: if conv_block is None:
conv_block = BasicConv2d conv_block = BasicConv2d
...@@ -422,12 +399,8 @@ class InceptionE(nn.Module): ...@@ -422,12 +399,8 @@ class InceptionE(nn.Module):
class InceptionAux(nn.Module): class InceptionAux(nn.Module):
def __init__( def __init__(
self, self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
in_channels: int,
num_classes: int,
conv_block: Optional[Callable[..., nn.Module]] = None
) -> None: ) -> None:
super(InceptionAux, self).__init__() super(InceptionAux, self).__init__()
if conv_block is None: if conv_block is None:
...@@ -457,13 +430,7 @@ class InceptionAux(nn.Module): ...@@ -457,13 +430,7 @@ class InceptionAux(nn.Module):
class BasicConv2d(nn.Module): class BasicConv2d(nn.Module):
def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
def __init__(
self,
in_channels: int,
out_channels: int,
**kwargs: Any
) -> None:
super(BasicConv2d, self).__init__() super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001) self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
......
import warnings import warnings
from typing import Any, Dict, List
import torch import torch
from torch import Tensor
import torch.nn as nn import torch.nn as nn
from torch import Tensor
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from typing import Any, Dict, List
__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] __all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"]
_MODEL_URLS = { _MODEL_URLS = {
"mnasnet0_5": "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
"https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
"mnasnet0_75": None, "mnasnet0_75": None,
"mnasnet1_0": "mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", "mnasnet1_3": None,
"mnasnet1_3": None
} }
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
...@@ -23,34 +22,27 @@ _BN_MOMENTUM = 1 - 0.9997 ...@@ -23,34 +22,27 @@ _BN_MOMENTUM = 1 - 0.9997
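The comment above records why _BN_MOMENTUM is defined as 1 - 0.9997: TensorFlow and PyTorch parameterize the running-statistics update of batch norm differently. A minimal sketch of the two conventions (these are the standard update rules, not code taken from this commit):

    decay = 0.9997          # TensorFlow-style moving-average decay, as suggested by the paper
    momentum = 1 - decay    # PyTorch-style momentum, i.e. _BN_MOMENTUM == 0.0003

    # TensorFlow: running_stat = decay * running_stat + (1 - decay) * batch_stat
    # PyTorch:    running_stat = (1 - momentum) * running_stat + momentum * batch_stat
    # Both weight the new batch statistic by 0.0003, so the two settings are equivalent.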
class _InvertedResidual(nn.Module): class _InvertedResidual(nn.Module):
def __init__( def __init__(
self, self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
in_ch: int,
out_ch: int,
kernel_size: int,
stride: int,
expansion_factor: int,
bn_momentum: float = 0.1
) -> None: ) -> None:
super(_InvertedResidual, self).__init__() super(_InvertedResidual, self).__init__()
assert stride in [1, 2] assert stride in [1, 2]
assert kernel_size in [3, 5] assert kernel_size in [3, 5]
mid_ch = in_ch * expansion_factor mid_ch = in_ch * expansion_factor
self.apply_residual = (in_ch == out_ch and stride == 1) self.apply_residual = in_ch == out_ch and stride == 1
self.layers = nn.Sequential( self.layers = nn.Sequential(
# Pointwise # Pointwise
nn.Conv2d(in_ch, mid_ch, 1, bias=False), nn.Conv2d(in_ch, mid_ch, 1, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum), nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# Depthwise # Depthwise
nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
stride=stride, groups=mid_ch, bias=False),
nn.BatchNorm2d(mid_ch, momentum=bn_momentum), nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# Linear pointwise. Note that there's no activation. # Linear pointwise. Note that there's no activation.
nn.Conv2d(mid_ch, out_ch, 1, bias=False), nn.Conv2d(mid_ch, out_ch, 1, bias=False),
nn.BatchNorm2d(out_ch, momentum=bn_momentum)) nn.BatchNorm2d(out_ch, momentum=bn_momentum),
)
def forward(self, input: Tensor) -> Tensor: def forward(self, input: Tensor) -> Tensor:
if self.apply_residual: if self.apply_residual:
...@@ -59,39 +51,37 @@ class _InvertedResidual(nn.Module): ...@@ -59,39 +51,37 @@ class _InvertedResidual(nn.Module):
return self.layers(input) return self.layers(input)
def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, def _stack(
bn_momentum: float) -> nn.Sequential: in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
""" Creates a stack of inverted residuals. """ ) -> nn.Sequential:
"""Creates a stack of inverted residuals."""
assert repeats >= 1 assert repeats >= 1
# First one has no skip, because feature map size changes. # First one has no skip, because feature map size changes.
first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
bn_momentum=bn_momentum)
remaining = [] remaining = []
for _ in range(1, repeats): for _ in range(1, repeats):
remaining.append( remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
bn_momentum=bn_momentum))
return nn.Sequential(first, *remaining) return nn.Sequential(first, *remaining)
def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int: def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
""" Asymmetric rounding to make `val` divisible by `divisor`. With default """Asymmetric rounding to make `val` divisible by `divisor`. With default
bias, will round up, unless the number is no more than 10% greater than the bias, will round up, unless the number is no more than 10% greater than the
smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
assert 0.0 < round_up_bias < 1.0 assert 0.0 < round_up_bias < 1.0
new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
return new_val if new_val >= round_up_bias * val else new_val + divisor return new_val if new_val >= round_up_bias * val else new_val + divisor
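As a quick sanity check of the asymmetric rounding described in the docstring above, here is a small self-contained sketch; the helper body is copied from this file and the asserts reuse the docstring's own example values:

    def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
        # Identical to the helper above: round to the nearest multiple of divisor,
        # then bump up one multiple if the result fell more than 10% below val.
        new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
        return new_val if new_val >= round_up_bias * val else new_val + divisor

    assert _round_to_multiple_of(83, 8) == 80  # 80 is within 10% of 83, so keep the rounded-down value
    assert _round_to_multiple_of(84, 8) == 88  # 84 + divisor/2 already reaches the next multiple, 88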
def _get_depths(alpha: float) -> List[int]: def _get_depths(alpha: float) -> List[int]:
""" Scales tensor depths as in reference MobileNet code, prefers rouding up """Scales tensor depths as in reference MobileNet code, prefers rouding up
rather than down. """ rather than down."""
depths = [32, 16, 24, 40, 80, 96, 192, 320] depths = [32, 16, 24, 40, 80, 96, 192, 320]
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
class MNASNet(torch.nn.Module): class MNASNet(torch.nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model. implements the B1 variant of the model.
>>> model = MNASNet(1.0, num_classes=1000) >>> model = MNASNet(1.0, num_classes=1000)
>>> x = torch.rand(1, 3, 224, 224) >>> x = torch.rand(1, 3, 224, 224)
...@@ -101,15 +91,11 @@ class MNASNet(torch.nn.Module): ...@@ -101,15 +91,11 @@ class MNASNet(torch.nn.Module):
>>> y.nelement() >>> y.nelement()
1000 1000
""" """
# Version 2 adds depth scaling in the initial stages of the network. # Version 2 adds depth scaling in the initial stages of the network.
_version = 2 _version = 2
def __init__( def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
self,
alpha: float,
num_classes: int = 1000,
dropout: float = 0.2
) -> None:
super(MNASNet, self).__init__() super(MNASNet, self).__init__()
assert alpha > 0.0 assert alpha > 0.0
self.alpha = alpha self.alpha = alpha
...@@ -121,8 +107,7 @@ class MNASNet(torch.nn.Module): ...@@ -121,8 +107,7 @@ class MNASNet(torch.nn.Module):
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# Depthwise separable, no skip. # Depthwise separable, no skip.
nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
groups=depths[0], bias=False),
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False), nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
...@@ -140,8 +125,7 @@ class MNASNet(torch.nn.Module): ...@@ -140,8 +125,7 @@ class MNASNet(torch.nn.Module):
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
] ]
self.layers = nn.Sequential(*layers) self.layers = nn.Sequential(*layers)
self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))
nn.Linear(1280, num_classes))
self._initialize_weights() self._initialize_weights()
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
...@@ -153,20 +137,26 @@ class MNASNet(torch.nn.Module): ...@@ -153,20 +137,26 @@ class MNASNet(torch.nn.Module):
def _initialize_weights(self) -> None: def _initialize_weights(self) -> None:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
nonlinearity="relu")
if m.bias is not None: if m.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight) nn.init.ones_(m.weight)
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear): elif isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight, mode="fan_out", nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
nonlinearity="sigmoid")
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool, def _load_from_state_dict(
missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None: self,
state_dict: Dict,
prefix: str,
local_metadata: Dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
version = local_metadata.get("version", None) version = local_metadata.get("version", None)
assert version in [1, 2] assert version in [1, 2]
...@@ -180,8 +170,7 @@ class MNASNet(torch.nn.Module): ...@@ -180,8 +170,7 @@ class MNASNet(torch.nn.Module):
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
...@@ -199,20 +188,19 @@ class MNASNet(torch.nn.Module): ...@@ -199,20 +188,19 @@ class MNASNet(torch.nn.Module):
"This checkpoint will load and work as before, but " "This checkpoint will load and work as before, but "
"you may want to upgrade by training a newer model or " "you may want to upgrade by training a newer model or "
"transfer learning from an updated ImageNet checkpoint.", "transfer learning from an updated ImageNet checkpoint.",
UserWarning) UserWarning,
)
super(MNASNet, self)._load_from_state_dict( super(MNASNet, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
unexpected_keys, error_msgs) )
def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None: def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
raise ValueError( raise ValueError("No checkpoint is available for model type {}".format(model_name))
"No checkpoint is available for model type {}".format(model_name))
checkpoint_url = _MODEL_URLS[model_name] checkpoint_url = _MODEL_URLS[model_name]
model.load_state_dict( model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
load_state_dict_from_url(checkpoint_url, progress=progress))
def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
......
import torch
import warnings import warnings
from functools import partial from functools import partial
from torch import nn from typing import Callable, Any, Optional, List
import torch
from torch import Tensor from torch import Tensor
from torch import nn
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation from ..ops.misc import ConvNormActivation
from ._utils import _make_divisible from ._utils import _make_divisible
from typing import Callable, Any, Optional, List
__all__ = ['MobileNetV2', 'mobilenet_v2'] __all__ = ["MobileNetV2", "mobilenet_v2"]
model_urls = { model_urls = {
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
} }
...@@ -23,7 +24,9 @@ class _DeprecatedConvBNAct(ConvNormActivation): ...@@ -23,7 +24,9 @@ class _DeprecatedConvBNAct(ConvNormActivation):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
warnings.warn( warnings.warn(
"The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. " "The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning) "Use torchvision.ops.misc.ConvNormActivation instead.",
FutureWarning,
)
if kwargs.get("norm_layer", None) is None: if kwargs.get("norm_layer", None) is None:
kwargs["norm_layer"] = nn.BatchNorm2d kwargs["norm_layer"] = nn.BatchNorm2d
if kwargs.get("activation_layer", None) is None: if kwargs.get("activation_layer", None) is None:
...@@ -37,12 +40,7 @@ ConvBNActivation = _DeprecatedConvBNAct ...@@ -37,12 +40,7 @@ ConvBNActivation = _DeprecatedConvBNAct
class InvertedResidual(nn.Module): class InvertedResidual(nn.Module):
def __init__( def __init__(
self, self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
inp: int,
oup: int,
stride: int,
expand_ratio: int,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None: ) -> None:
super(InvertedResidual, self).__init__() super(InvertedResidual, self).__init__()
self.stride = stride self.stride = stride
...@@ -57,16 +55,25 @@ class InvertedResidual(nn.Module): ...@@ -57,16 +55,25 @@ class InvertedResidual(nn.Module):
layers: List[nn.Module] = [] layers: List[nn.Module] = []
if expand_ratio != 1: if expand_ratio != 1:
# pw # pw
layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, layers.append(
activation_layer=nn.ReLU6)) ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
layers.extend([ )
# dw layers.extend(
ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer, [
activation_layer=nn.ReLU6), # dw
# pw-linear ConvNormActivation(
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), hidden_dim,
norm_layer(oup), hidden_dim,
]) stride=stride,
groups=hidden_dim,
norm_layer=norm_layer,
activation_layer=nn.ReLU6,
),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
norm_layer(oup),
]
)
self.conv = nn.Sequential(*layers) self.conv = nn.Sequential(*layers)
self.out_channels = oup self.out_channels = oup
self._is_cn = stride > 1 self._is_cn = stride > 1
...@@ -86,7 +93,7 @@ class MobileNetV2(nn.Module): ...@@ -86,7 +93,7 @@ class MobileNetV2(nn.Module):
inverted_residual_setting: Optional[List[List[int]]] = None, inverted_residual_setting: Optional[List[List[int]]] = None,
round_nearest: int = 8, round_nearest: int = 8,
block: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> None: ) -> None:
""" """
MobileNet V2 main class MobileNet V2 main class
...@@ -126,14 +133,17 @@ class MobileNetV2(nn.Module): ...@@ -126,14 +133,17 @@ class MobileNetV2(nn.Module):
# only check the first element, assuming user knows t,c,n,s are required # only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty " raise ValueError(
"or a 4-element list, got {}".format(inverted_residual_setting)) "inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting)
)
# building first layer # building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest) input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, features: List[nn.Module] = [
activation_layer=nn.ReLU6)] ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
]
# building inverted residual blocks # building inverted residual blocks
for t, c, n, s in inverted_residual_setting: for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest) output_channel = _make_divisible(c * width_mult, round_nearest)
...@@ -142,8 +152,11 @@ class MobileNetV2(nn.Module): ...@@ -142,8 +152,11 @@ class MobileNetV2(nn.Module):
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
input_channel = output_channel input_channel = output_channel
# building last several layers # building last several layers
features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, features.append(
activation_layer=nn.ReLU6)) ConvNormActivation(
input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
)
)
# make it nn.Sequential # make it nn.Sequential
self.features = nn.Sequential(*features) self.features = nn.Sequential(*features)
...@@ -156,7 +169,7 @@ class MobileNetV2(nn.Module): ...@@ -156,7 +169,7 @@ class MobileNetV2(nn.Module):
# weight initialization # weight initialization
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out') nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None: if m.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
...@@ -191,7 +204,6 @@ def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) ...@@ -191,7 +204,6 @@ def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any)
""" """
model = MobileNetV2(**kwargs) model = MobileNetV2(**kwargs)
if pretrained: if pretrained:
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], state_dict = load_state_dict_from_url(model_urls["mobilenet_v2"], progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
import warnings import warnings
import torch
from functools import partial from functools import partial
from torch import nn, Tensor
from typing import Any, Callable, List, Optional, Sequence from typing import Any, Callable, List, Optional, Sequence
import torch
from torch import nn, Tensor
from .._internally_replaced_utils import load_state_dict_from_url from .._internally_replaced_utils import load_state_dict_from_url
from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer
from ._utils import _make_divisible from ._utils import _make_divisible
...@@ -20,22 +20,34 @@ model_urls = { ...@@ -20,22 +20,34 @@ model_urls = {
class SqueezeExcitation(SElayer): class SqueezeExcitation(SElayer):
"""DEPRECATED """DEPRECATED"""
"""
def __init__(self, input_channels: int, squeeze_factor: int = 4): def __init__(self, input_channels: int, squeeze_factor: int = 4):
squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid) super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
self.relu = self.activation self.relu = self.activation
delattr(self, 'activation') delattr(self, "activation")
warnings.warn( warnings.warn(
"This SqueezeExcitation class is deprecated and will be removed in future versions. " "This SqueezeExcitation class is deprecated and will be removed in future versions. "
"Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning) "Use torchvision.ops.misc.SqueezeExcitation instead.",
FutureWarning,
)
class InvertedResidualConfig: class InvertedResidualConfig:
# Stores information listed at Tables 1 and 2 of the MobileNetV3 paper # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool, def __init__(
activation: str, stride: int, dilation: int, width_mult: float): self,
input_channels: int,
kernel: int,
expanded_channels: int,
out_channels: int,
use_se: bool,
activation: str,
stride: int,
dilation: int,
width_mult: float,
):
self.input_channels = self.adjust_channels(input_channels, width_mult) self.input_channels = self.adjust_channels(input_channels, width_mult)
self.kernel = kernel self.kernel = kernel
self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
...@@ -52,11 +64,15 @@ class InvertedResidualConfig: ...@@ -52,11 +64,15 @@ class InvertedResidualConfig:
class InvertedResidual(nn.Module): class InvertedResidual(nn.Module):
# Implemented as described at section 5 of MobileNetV3 paper # Implemented as described at section 5 of MobileNetV3 paper
def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module], def __init__(
se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)): self,
cnf: InvertedResidualConfig,
norm_layer: Callable[..., nn.Module],
se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid),
):
super().__init__() super().__init__()
if not (1 <= cnf.stride <= 2): if not (1 <= cnf.stride <= 2):
raise ValueError('illegal stride value') raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
...@@ -65,21 +81,40 @@ class InvertedResidual(nn.Module): ...@@ -65,21 +81,40 @@ class InvertedResidual(nn.Module):
# expand # expand
if cnf.expanded_channels != cnf.input_channels: if cnf.expanded_channels != cnf.input_channels:
layers.append(ConvNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, layers.append(
norm_layer=norm_layer, activation_layer=activation_layer)) ConvNormActivation(
cnf.input_channels,
cnf.expanded_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# depthwise # depthwise
stride = 1 if cnf.dilation > 1 else cnf.stride stride = 1 if cnf.dilation > 1 else cnf.stride
layers.append(ConvNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, layers.append(
stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, ConvNormActivation(
norm_layer=norm_layer, activation_layer=activation_layer)) cnf.expanded_channels,
cnf.expanded_channels,
kernel_size=cnf.kernel,
stride=stride,
dilation=cnf.dilation,
groups=cnf.expanded_channels,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
if cnf.use_se: if cnf.use_se:
squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8) squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
layers.append(se_layer(cnf.expanded_channels, squeeze_channels)) layers.append(se_layer(cnf.expanded_channels, squeeze_channels))
# project # project
layers.append(ConvNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, layers.append(
activation_layer=None)) ConvNormActivation(
cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
self.block = nn.Sequential(*layers) self.block = nn.Sequential(*layers)
self.out_channels = cnf.out_channels self.out_channels = cnf.out_channels
...@@ -93,15 +128,14 @@ class InvertedResidual(nn.Module): ...@@ -93,15 +128,14 @@ class InvertedResidual(nn.Module):
class MobileNetV3(nn.Module): class MobileNetV3(nn.Module):
def __init__( def __init__(
self, self,
inverted_residual_setting: List[InvertedResidualConfig], inverted_residual_setting: List[InvertedResidualConfig],
last_channel: int, last_channel: int,
num_classes: int = 1000, num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None, block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None, norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any **kwargs: Any,
) -> None: ) -> None:
""" """
MobileNet V3 main class MobileNet V3 main class
...@@ -117,8 +151,10 @@ class MobileNetV3(nn.Module): ...@@ -117,8 +151,10 @@ class MobileNetV3(nn.Module):
if not inverted_residual_setting: if not inverted_residual_setting:
raise ValueError("The inverted_residual_setting should not be empty") raise ValueError("The inverted_residual_setting should not be empty")
elif not (isinstance(inverted_residual_setting, Sequence) and elif not (
all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): isinstance(inverted_residual_setting, Sequence)
and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
):
raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")
if block is None: if block is None:
...@@ -131,8 +167,16 @@ class MobileNetV3(nn.Module): ...@@ -131,8 +167,16 @@ class MobileNetV3(nn.Module):
# building first layer # building first layer
firstconv_output_channels = inverted_residual_setting[0].input_channels firstconv_output_channels = inverted_residual_setting[0].input_channels
layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, layers.append(
activation_layer=nn.Hardswish)) ConvNormActivation(
3,
firstconv_output_channels,
kernel_size=3,
stride=2,
norm_layer=norm_layer,
activation_layer=nn.Hardswish,
)
)
# building inverted residual blocks # building inverted residual blocks
for cnf in inverted_residual_setting: for cnf in inverted_residual_setting:
...@@ -141,8 +185,15 @@ class MobileNetV3(nn.Module): ...@@ -141,8 +185,15 @@ class MobileNetV3(nn.Module):
# building last several layers # building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = 6 * lastconv_input_channels lastconv_output_channels = 6 * lastconv_input_channels
layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, layers.append(
norm_layer=norm_layer, activation_layer=nn.Hardswish)) ConvNormActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.Hardswish,
)
)
self.features = nn.Sequential(*layers) self.features = nn.Sequential(*layers)
self.avgpool = nn.AdaptiveAvgPool2d(1) self.avgpool = nn.AdaptiveAvgPool2d(1)
...@@ -155,7 +206,7 @@ class MobileNetV3(nn.Module): ...@@ -155,7 +206,7 @@ class MobileNetV3(nn.Module):
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out') nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None: if m.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
...@@ -179,8 +230,9 @@ class MobileNetV3(nn.Module): ...@@ -179,8 +230,9 @@ class MobileNetV3(nn.Module):
return self._forward_impl(x) return self._forward_impl(x)
def _mobilenet_v3_conf(arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, def _mobilenet_v3_conf(
**kwargs: Any): arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any
):
reduce_divider = 2 if reduced_tail else 1 reduce_divider = 2 if reduced_tail else 1
dilation = 2 if dilated else 1 dilation = 2 if dilated else 1
...@@ -233,7 +285,7 @@ def _mobilenet_v3_model( ...@@ -233,7 +285,7 @@ def _mobilenet_v3_model(
last_channel: int, last_channel: int,
pretrained: bool, pretrained: bool,
progress: bool, progress: bool,
**kwargs: Any **kwargs: Any,
): ):
model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
if pretrained: if pretrained:
......
import warnings import warnings
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn import functional as F
from typing import Any
from torch import Tensor from torch import Tensor
from torch.nn import functional as F
from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from torchvision.models.googlenet import (
GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls)
from .utils import _replace_relu, quantize_model from .utils import _replace_relu, quantize_model
__all__ = ['QuantizableGoogLeNet', 'googlenet'] __all__ = ["QuantizableGoogLeNet", "googlenet"]
quant_model_urls = { quant_model_urls = {
# fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch
'googlenet_fbgemm': 'https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth', "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
} }
...@@ -44,35 +43,35 @@ def googlenet( ...@@ -44,35 +43,35 @@ def googlenet(
was trained on ImageNet. Default: *False* was trained on ImageNet. Default: *False*
""" """
if pretrained: if pretrained:
if 'transform_input' not in kwargs: if "transform_input" not in kwargs:
kwargs['transform_input'] = True kwargs["transform_input"] = True
if 'aux_logits' not in kwargs: if "aux_logits" not in kwargs:
kwargs['aux_logits'] = False kwargs["aux_logits"] = False
if kwargs['aux_logits']: if kwargs["aux_logits"]:
warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ' warnings.warn(
'so make sure to train them') "auxiliary heads in the pretrained googlenet model are NOT pretrained, " "so make sure to train them"
original_aux_logits = kwargs['aux_logits'] )
kwargs['aux_logits'] = True original_aux_logits = kwargs["aux_logits"]
kwargs['init_weights'] = False kwargs["aux_logits"] = True
kwargs["init_weights"] = False
model = QuantizableGoogLeNet(**kwargs) model = QuantizableGoogLeNet(**kwargs)
_replace_relu(model) _replace_relu(model)
if quantize: if quantize:
# TODO use pretrained as a string to specify the backend # TODO use pretrained as a string to specify the backend
backend = 'fbgemm' backend = "fbgemm"
quantize_model(model, backend) quantize_model(model, backend)
else: else:
assert pretrained in [True, False] assert pretrained in [True, False]
if pretrained: if pretrained:
if quantize: if quantize:
model_url = quant_model_urls['googlenet' + '_' + backend] model_url = quant_model_urls["googlenet" + "_" + backend]
else: else:
model_url = model_urls['googlenet'] model_url = model_urls["googlenet"]
state_dict = load_state_dict_from_url(model_url, state_dict = load_state_dict_from_url(model_url, progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
...@@ -84,7 +83,6 @@ def googlenet( ...@@ -84,7 +83,6 @@ def googlenet(
class QuantizableBasicConv2d(BasicConv2d): class QuantizableBasicConv2d(BasicConv2d):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableBasicConv2d, self).__init__(*args, **kwargs) super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
self.relu = nn.ReLU() self.relu = nn.ReLU()
...@@ -100,10 +98,10 @@ class QuantizableBasicConv2d(BasicConv2d): ...@@ -100,10 +98,10 @@ class QuantizableBasicConv2d(BasicConv2d):
class QuantizableInception(Inception): class QuantizableInception(Inception):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInception, self).__init__( # type: ignore[misc] super(QuantizableInception, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, *args, **kwargs) conv_block=QuantizableBasicConv2d, *args, **kwargs
)
self.cat = nn.quantized.FloatFunctional() self.cat = nn.quantized.FloatFunctional()
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
...@@ -115,9 +113,7 @@ class QuantizableInceptionAux(InceptionAux): ...@@ -115,9 +113,7 @@ class QuantizableInceptionAux(InceptionAux):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionAux, self).__init__( # type: ignore[misc] super(QuantizableInceptionAux, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, conv_block=QuantizableBasicConv2d, *args, **kwargs
*args,
**kwargs
) )
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.7) self.dropout = nn.Dropout(0.7)
...@@ -144,9 +140,7 @@ class QuantizableGoogLeNet(GoogLeNet): ...@@ -144,9 +140,7 @@ class QuantizableGoogLeNet(GoogLeNet):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableGoogLeNet, self).__init__( # type: ignore[misc] super(QuantizableGoogLeNet, self).__init__( # type: ignore[misc]
blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs
*args,
**kwargs
) )
self.quant = torch.quantization.QuantStub() self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub() self.dequant = torch.quantization.DeQuantStub()
......
import warnings import warnings
from typing import Any, List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from typing import Any, List
from torchvision.models import inception as inception_module from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs from torchvision.models.inception import InceptionOutputs
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model from .utils import _replace_relu, quantize_model
...@@ -20,8 +20,7 @@ __all__ = [ ...@@ -20,8 +20,7 @@ __all__ = [
quant_model_urls = { quant_model_urls = {
# fp32 weights ported from TensorFlow, quantized in PyTorch # fp32 weights ported from TensorFlow, quantized in PyTorch
"inception_v3_google_fbgemm": "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
"https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
} }
...@@ -66,7 +65,7 @@ def inception_v3( ...@@ -66,7 +65,7 @@ def inception_v3(
if quantize: if quantize:
# TODO use pretrained as a string to specify the backend # TODO use pretrained as a string to specify the backend
backend = 'fbgemm' backend = "fbgemm"
quantize_model(model, backend) quantize_model(model, backend)
else: else:
assert pretrained in [True, False] assert pretrained in [True, False]
...@@ -76,12 +75,11 @@ def inception_v3( ...@@ -76,12 +75,11 @@ def inception_v3(
if not original_aux_logits: if not original_aux_logits:
model.aux_logits = False model.aux_logits = False
model.AuxLogits = None model.AuxLogits = None
model_url = quant_model_urls['inception_v3_google' + '_' + backend] model_url = quant_model_urls["inception_v3_google" + "_" + backend]
else: else:
model_url = inception_module.model_urls['inception_v3_google'] model_url = inception_module.model_urls["inception_v3_google"]
state_dict = load_state_dict_from_url(model_url, state_dict = load_state_dict_from_url(model_url, progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
...@@ -111,9 +109,7 @@ class QuantizableInceptionA(inception_module.InceptionA): ...@@ -111,9 +109,7 @@ class QuantizableInceptionA(inception_module.InceptionA):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionA, self).__init__( # type: ignore[misc] super(QuantizableInceptionA, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, conv_block=QuantizableBasicConv2d, *args, **kwargs
*args,
**kwargs
) )
self.myop = nn.quantized.FloatFunctional() self.myop = nn.quantized.FloatFunctional()
...@@ -126,9 +122,7 @@ class QuantizableInceptionB(inception_module.InceptionB): ...@@ -126,9 +122,7 @@ class QuantizableInceptionB(inception_module.InceptionB):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionB, self).__init__( # type: ignore[misc] super(QuantizableInceptionB, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, conv_block=QuantizableBasicConv2d, *args, **kwargs
*args,
**kwargs
) )
self.myop = nn.quantized.FloatFunctional() self.myop = nn.quantized.FloatFunctional()
...@@ -141,9 +135,7 @@ class QuantizableInceptionC(inception_module.InceptionC): ...@@ -141,9 +135,7 @@ class QuantizableInceptionC(inception_module.InceptionC):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionC, self).__init__( # type: ignore[misc] super(QuantizableInceptionC, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, conv_block=QuantizableBasicConv2d, *args, **kwargs
*args,
**kwargs
) )
self.myop = nn.quantized.FloatFunctional() self.myop = nn.quantized.FloatFunctional()
...@@ -156,9 +148,7 @@ class QuantizableInceptionD(inception_module.InceptionD): ...@@ -156,9 +148,7 @@ class QuantizableInceptionD(inception_module.InceptionD):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionD, self).__init__( # type: ignore[misc] super(QuantizableInceptionD, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, conv_block=QuantizableBasicConv2d, *args, **kwargs
*args,
**kwargs
) )
self.myop = nn.quantized.FloatFunctional() self.myop = nn.quantized.FloatFunctional()
...@@ -171,9 +161,7 @@ class QuantizableInceptionE(inception_module.InceptionE): ...@@ -171,9 +161,7 @@ class QuantizableInceptionE(inception_module.InceptionE):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionE, self).__init__( # type: ignore[misc] super(QuantizableInceptionE, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, conv_block=QuantizableBasicConv2d, *args, **kwargs
*args,
**kwargs
) )
self.myop1 = nn.quantized.FloatFunctional() self.myop1 = nn.quantized.FloatFunctional()
self.myop2 = nn.quantized.FloatFunctional() self.myop2 = nn.quantized.FloatFunctional()
...@@ -209,9 +197,7 @@ class QuantizableInceptionAux(inception_module.InceptionAux): ...@@ -209,9 +197,7 @@ class QuantizableInceptionAux(inception_module.InceptionAux):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super(QuantizableInceptionAux, self).__init__( # type: ignore[misc] super(QuantizableInceptionAux, self).__init__( # type: ignore[misc]
conv_block=QuantizableBasicConv2d, conv_block=QuantizableBasicConv2d, *args, **kwargs
*args,
**kwargs
) )
...@@ -233,8 +219,8 @@ class QuantizableInception3(inception_module.Inception3): ...@@ -233,8 +219,8 @@ class QuantizableInception3(inception_module.Inception3):
QuantizableInceptionC, QuantizableInceptionC,
QuantizableInceptionD, QuantizableInceptionD,
QuantizableInceptionE, QuantizableInceptionE,
QuantizableInceptionAux QuantizableInceptionAux,
] ],
) )
self.quant = torch.quantization.QuantStub() self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub() self.dequant = torch.quantization.DeQuantStub()
......
from torch import nn
from torch import Tensor
from ..._internally_replaced_utils import load_state_dict_from_url
from typing import Any from typing import Any
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls from torch import Tensor
from torch import nn
from torch.quantization import QuantStub, DeQuantStub, fuse_modules from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from .utils import _replace_relu, quantize_model from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls
from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation from ...ops.misc import ConvNormActivation
from .utils import _replace_relu, quantize_model
__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2'] __all__ = ["QuantizableMobileNetV2", "mobilenet_v2"]
quant_model_urls = { quant_model_urls = {
'mobilenet_v2_qnnpack': "mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth"
'https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth'
} }
...@@ -57,7 +55,7 @@ class QuantizableMobileNetV2(MobileNetV2): ...@@ -57,7 +55,7 @@ class QuantizableMobileNetV2(MobileNetV2):
def fuse_model(self) -> None: def fuse_model(self) -> None:
for m in self.modules(): for m in self.modules():
if type(m) == ConvNormActivation: if type(m) == ConvNormActivation:
fuse_modules(m, ['0', '1', '2'], inplace=True) fuse_modules(m, ["0", "1", "2"], inplace=True)
if type(m) == QuantizableInvertedResidual: if type(m) == QuantizableInvertedResidual:
m.fuse_model() m.fuse_model()
...@@ -87,19 +85,18 @@ def mobilenet_v2( ...@@ -87,19 +85,18 @@ def mobilenet_v2(
if quantize: if quantize:
# TODO use pretrained as a string to specify the backend # TODO use pretrained as a string to specify the backend
backend = 'qnnpack' backend = "qnnpack"
quantize_model(model, backend) quantize_model(model, backend)
else: else:
assert pretrained in [True, False] assert pretrained in [True, False]
if pretrained: if pretrained:
if quantize: if quantize:
model_url = quant_model_urls['mobilenet_v2_' + backend] model_url = quant_model_urls["mobilenet_v2_" + backend]
else: else:
model_url = model_urls['mobilenet_v2'] model_url = model_urls["mobilenet_v2"]
state_dict = load_state_dict_from_url(model_url, state_dict = load_state_dict_from_url(model_url, progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
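A short usage sketch for the quantized builder defined above (the call goes through the public torchvision.models.quantization API; the input tensor shape is only an example):

    import torch
    from torchvision.models.quantization import mobilenet_v2

    # quantize=True returns an int8 model; this builder selects the qnnpack backend
    model = mobilenet_v2(pretrained=True, quantize=True)
    model.eval()
    with torch.no_grad():
        logits = model(torch.rand(1, 3, 224, 224))  # 1000-way ImageNet logits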
from typing import Any, List, Optional
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation, SqueezeExcitation from ...ops.misc import ConvNormActivation, SqueezeExcitation
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3,\ from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
model_urls, _mobilenet_v3_conf
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
from typing import Any, List, Optional
from .utils import _replace_relu from .utils import _replace_relu
__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large'] __all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"]
quant_model_urls = { quant_model_urls = {
'mobilenet_v3_large_qnnpack': "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
"https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
} }
...@@ -29,7 +29,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation): ...@@ -29,7 +29,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation):
return self.skip_mul.mul(self._scale(input), input) return self.skip_mul.mul(self._scale(input), input)
def fuse_model(self) -> None: def fuse_model(self) -> None:
fuse_modules(self, ['fc1', 'activation'], inplace=True) fuse_modules(self, ["fc1", "activation"], inplace=True)
def _load_from_state_dict( def _load_from_state_dict(
self, self,
...@@ -45,7 +45,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation): ...@@ -45,7 +45,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation):
if version is None or version < 2: if version is None or version < 2:
default_state_dict = { default_state_dict = {
"scale_activation.activation_post_process.scale": torch.tensor([1.]), "scale_activation.activation_post_process.scale": torch.tensor([1.0]),
"scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32), "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
"scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]), "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
"scale_activation.activation_post_process.observer_enabled": torch.tensor([1]), "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
...@@ -69,11 +69,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation): ...@@ -69,11 +69,7 @@ class QuantizableSqueezeExcitation(SqueezeExcitation):
class QuantizableInvertedResidual(InvertedResidual): class QuantizableInvertedResidual(InvertedResidual):
# TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__( # type: ignore[misc] super().__init__(se_layer=QuantizableSqueezeExcitation, *args, **kwargs) # type: ignore[misc]
se_layer=QuantizableSqueezeExcitation,
*args,
**kwargs
)
self.skip_add = nn.quantized.FloatFunctional() self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
...@@ -104,20 +100,15 @@ class QuantizableMobileNetV3(MobileNetV3): ...@@ -104,20 +100,15 @@ class QuantizableMobileNetV3(MobileNetV3):
def fuse_model(self) -> None: def fuse_model(self) -> None:
for m in self.modules(): for m in self.modules():
if type(m) == ConvNormActivation: if type(m) == ConvNormActivation:
modules_to_fuse = ['0', '1'] modules_to_fuse = ["0", "1"]
if len(m) == 3 and type(m[2]) == nn.ReLU: if len(m) == 3 and type(m[2]) == nn.ReLU:
modules_to_fuse.append('2') modules_to_fuse.append("2")
fuse_modules(m, modules_to_fuse, inplace=True) fuse_modules(m, modules_to_fuse, inplace=True)
elif type(m) == QuantizableSqueezeExcitation: elif type(m) == QuantizableSqueezeExcitation:
m.fuse_model() m.fuse_model()
def _load_weights( def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None:
arch: str,
model: QuantizableMobileNetV3,
model_url: Optional[str],
progress: bool
) -> None:
if model_url is None: if model_url is None:
raise ValueError("No checkpoint is available for {}".format(arch)) raise ValueError("No checkpoint is available for {}".format(arch))
state_dict = load_state_dict_from_url(model_url, progress=progress) state_dict = load_state_dict_from_url(model_url, progress=progress)
...@@ -138,14 +129,14 @@ def _mobilenet_v3_model( ...@@ -138,14 +129,14 @@ def _mobilenet_v3_model(
_replace_relu(model) _replace_relu(model)
if quantize: if quantize:
backend = 'qnnpack' backend = "qnnpack"
model.fuse_model() model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig(backend) model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(model, inplace=True) torch.quantization.prepare_qat(model, inplace=True)
if pretrained: if pretrained:
_load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress) _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress)
torch.quantization.convert(model, inplace=True) torch.quantization.convert(model, inplace=True)
model.eval() model.eval()
......
from typing import Any, Type, Union, List
import torch import torch
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from typing import Any, Type, Union, List from torch.quantization import fuse_modules
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from torch.quantization import fuse_modules
from .utils import _replace_relu, quantize_model from .utils import _replace_relu, quantize_model
__all__ = ['QuantizableResNet', 'resnet18', 'resnet50', __all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"]
'resnext101_32x8d']
quant_model_urls = { quant_model_urls = {
'resnet18_fbgemm': "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
'https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth', "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
'resnet50_fbgemm': "resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
'https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth',
'resnext101_32x8d_fbgemm':
'https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth',
} }
...@@ -45,10 +42,9 @@ class QuantizableBasicBlock(BasicBlock): ...@@ -45,10 +42,9 @@ class QuantizableBasicBlock(BasicBlock):
return out return out
def fuse_model(self) -> None: def fuse_model(self) -> None:
torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'], torch.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True)
['conv2', 'bn2']], inplace=True)
if self.downsample: if self.downsample:
torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
class QuantizableBottleneck(Bottleneck):
@@ -77,15 +73,12 @@ class QuantizableBottleneck(Bottleneck):
        return out

    def fuse_model(self) -> None:
        fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True)
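Both fuse_model() methods hand fuse_modules a list of child-module names. As a standalone illustration of the effect (the Sequential below and its numeric child names are not part of this diff), fusing an eval-mode conv/bn/relu triple folds the batch norm into the convolution and replaces the leftover slots with identities:

import torch
from torch import nn

block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
fused = torch.quantization.fuse_modules(block, [["0", "1", "2"]])
print(type(fused[0]).__name__)  # ConvReLU2d: conv with bn folded in and relu attached
print(type(fused[1]).__name__, type(fused[2]).__name__)  # Identity Identity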
class QuantizableResNet(ResNet):

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableResNet, self).__init__(*args, **kwargs)
@@ -109,7 +102,7 @@ class QuantizableResNet(ResNet):
        and the model after modification is in floating point
        """
        fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
        for m in self.modules():
            if type(m) == QuantizableBottleneck or type(m) == QuantizableBasicBlock:
                m.fuse_model()
@@ -129,19 +122,18 @@ def _resnet(
    _replace_relu(model)
    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls[arch + "_" + backend]
        else:
            model_url = model_urls[arch]

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)
    return model
@@ -161,8 +153,7 @@ def resnet18(
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs)
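A short usage sketch of the entry point defined above; it assumes an x86 machine where the fbgemm backend is available and network access for the pretrained checkpoint.

import torch
from torchvision.models.quantization import resnet18

model = resnet18(pretrained=True, quantize=True)  # int8 weights from quant_model_urls
model.eval()
with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])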
def resnet50(
@@ -180,8 +171,7 @@ def resnet50(
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs)
def resnext101_32x8d(
@@ -198,7 +188,6 @@ def resnext101_32x8d(
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 8
    return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs)