Unverified commit f71316fa, authored by Francisco Massa and committed by GitHub

Fix mypy type annotations (#1696)



* Fix mypy type annotations

* follow torchscript Tuple type

* redefine torch_choice output type

* change the type in cached_grid_anchors

* minor bug

Co-authored-by: Guanheng Zhang <zhangguanheng@devfair0197.h2.fair>
Co-authored-by: Guanheng Zhang <zhangguanheng@learnfair0341.h2.fair>
parent 3ac864dc
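The changes below follow one pattern: every TorchScript-compatible comment annotation gains an explicit return type, and the longest signatures are reflowed to the per-argument comment form. As a minimal sketch of the two styles (stub functions written purely for illustration, with signatures loosely adapted from the diff rather than copied from torchvision):

from typing import Dict, List, Tuple

from torch import Tensor


def encode(reference_boxes, proposals):
    # type: (List[Tensor], List[Tensor]) -> List[Tensor]
    # Single-line form: argument types in order, then an explicit return type.
    return proposals


def postprocess(result,               # type: List[Dict[str, Tensor]]
                image_shapes,         # type: List[Tuple[int, int]]
                original_image_sizes  # type: List[Tuple[int, int]]
                ):
    # type: (...) -> List[Dict[str, Tensor]]
    # Multi-line form: one comment per argument on the signature itself,
    # with the return type given as "(...) -> ..." on the first body line.
    return result

Both comment forms are recognized by mypy.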
@@ -20,7 +20,7 @@ class BalancedPositiveNegativeSampler(object):
"""
def __init__(self, batch_size_per_image, positive_fraction):
-# type: (int, float)
+# type: (int, float) -> None
"""
Arguments:
batch_size_per_image (int): number of elements to be selected per image
@@ -30,7 +30,7 @@ class BalancedPositiveNegativeSampler(object):
self.positive_fraction = positive_fraction
def __call__(self, matched_idxs):
-# type: (List[Tensor])
+# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
"""
Arguments:
matched idxs: list of tensors containing -1, 0 or positive values.
@@ -139,7 +139,7 @@ class BoxCoder(object):
"""
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
-# type: (Tuple[float, float, float, float], float)
+# type: (Tuple[float, float, float, float], float) -> None
"""
Arguments:
weights (4-element tuple)
@@ -149,7 +149,7 @@ class BoxCoder(object):
self.bbox_xform_clip = bbox_xform_clip
def encode(self, reference_boxes, proposals):
-# type: (List[Tensor], List[Tensor])
+# type: (List[Tensor], List[Tensor]) -> List[Tensor]
boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0)
proposals = torch.cat(proposals, dim=0)
@@ -173,7 +173,7 @@ class BoxCoder(object):
return targets
def decode(self, rel_codes, boxes):
-# type: (Tensor, List[Tensor])
+# type: (Tensor, List[Tensor]) -> Tensor
assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [b.size(0) for b in boxes]
@@ -251,7 +251,7 @@ class Matcher(object):
}
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
-# type: (float, float, bool)
+# type: (float, float, bool) -> None
"""
Args:
high_threshold (float): quality values greater than or equal to
......
@@ -42,7 +42,7 @@ class GeneralizedRCNN(nn.Module):
return detections
def forward(self, images, targets=None):
-# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
+# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
"""
Arguments:
images (list[Tensor]): images to be processed
......
@@ -14,7 +14,7 @@ class ImageList(object):
"""
def __init__(self, tensors, image_sizes):
-# type: (Tensor, List[Tuple[int, int]])
+# type: (Tensor, List[Tuple[int, int]]) -> None
"""
Arguments:
tensors (tensor)
@@ -24,6 +24,6 @@ class ImageList(object):
self.image_sizes = image_sizes
def to(self, device):
-# type: (Device) # noqa
+# type: (Device) -> ImageList # noqa
cast_tensor = self.tensors.to(device)
return ImageList(cast_tensor, self.image_sizes)
@@ -15,7 +15,7 @@ from torch.jit.annotations import Optional, List, Dict, Tuple
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
-# type: (Tensor, Tensor, List[Tensor], List[Tensor])
+# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Computes the loss for Faster R-CNN.
@@ -55,7 +55,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
def maskrcnn_inference(x, labels):
-# type: (Tensor, List[Tensor])
+# type: (Tensor, List[Tensor]) -> List[Tensor]
"""
From the results of the CNN, post process the masks
by taking the mask corresponding to the class with max
@@ -85,7 +85,7 @@ def maskrcnn_inference(x, labels):
def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
-# type: (Tensor, Tensor, Tensor, int)
+# type: (Tensor, Tensor, Tensor, int) -> Tensor
"""
Given segmentation masks and the bounding boxes corresponding
to the location of the masks in the image, this function
@@ -100,7 +100,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
-# type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor])
+# type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
"""
Arguments:
proposals (list[BoxList])
@@ -133,7 +133,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
def keypoints_to_heatmap(keypoints, rois, heatmap_size):
-# type: (Tensor, Tensor, int)
+# type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
offset_x = rois[:, 0]
offset_y = rois[:, 1]
scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
@@ -277,7 +277,7 @@ def heatmaps_to_keypoints(maps, rois):
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
-# type: (Tensor, List[Tensor], List[Tensor], List[Tensor])
+# type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
N, K, H, W = keypoint_logits.shape
assert H == W
discretization_size = H
@@ -307,7 +307,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
def keypointrcnn_inference(x, boxes):
-# type: (Tensor, List[Tensor])
+# type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
kp_probs = []
kp_scores = []
@@ -323,7 +323,7 @@ def keypointrcnn_inference(x, boxes):
def _onnx_expand_boxes(boxes, scale):
-# type: (Tensor, float)
+# type: (Tensor, float) -> Tensor
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
@@ -344,7 +344,7 @@ def _onnx_expand_boxes(boxes, scale):
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
-# type: (Tensor, float)
+# type: (Tensor, float) -> Tensor
if torchvision._is_tracing():
return _onnx_expand_boxes(boxes, scale)
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
@@ -370,7 +370,7 @@ def expand_masks_tracing_scale(M, padding):
def expand_masks(mask, padding):
-# type: (Tensor, int)
+# type: (Tensor, int) -> Tuple[Tensor, float]
M = mask.shape[-1]
if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
scale = expand_masks_tracing_scale(M, padding)
@@ -381,7 +381,7 @@ def expand_masks(mask, padding):
def paste_mask_in_image(mask, box, im_h, im_w):
-# type: (Tensor, Tensor, int, int)
+# type: (Tensor, Tensor, int, int) -> Tensor
TO_REMOVE = 1
w = int(box[2] - box[0] + TO_REMOVE)
h = int(box[3] - box[1] + TO_REMOVE)
@@ -459,7 +459,7 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
def paste_masks_in_image(masks, boxes, img_shape, padding=1):
-# type: (Tensor, Tensor, Tuple[int, int], int)
+# type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
masks, scale = expand_masks(masks, padding=padding)
boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
im_h, im_w = img_shape
@@ -558,7 +558,7 @@ class RoIHeads(torch.nn.Module):
return True
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
-# type: (List[Tensor], List[Tensor], List[Tensor])
+# type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
matched_idxs = []
labels = []
for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
@@ -595,7 +595,7 @@ class RoIHeads(torch.nn.Module):
return matched_idxs, labels
def subsample(self, labels):
-# type: (List[Tensor])
+# type: (List[Tensor]) -> List[Tensor]
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
sampled_inds = []
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
@@ -606,7 +606,7 @@ class RoIHeads(torch.nn.Module):
return sampled_inds
def add_gt_proposals(self, proposals, gt_boxes):
-# type: (List[Tensor], List[Tensor])
+# type: (List[Tensor], List[Tensor]) -> List[Tensor]
proposals = [
torch.cat((proposal, gt_box))
for proposal, gt_box in zip(proposals, gt_boxes)
@@ -615,22 +615,25 @@ class RoIHeads(torch.nn.Module):
return proposals
def DELTEME_all(self, the_list):
-# type: (List[bool])
+# type: (List[bool]) -> bool
for i in the_list:
if not i:
return False
return True
def check_targets(self, targets):
-# type: (Optional[List[Dict[str, Tensor]]])
+# type: (Optional[List[Dict[str, Tensor]]]) -> None
assert targets is not None
assert self.DELTEME_all(["boxes" in t for t in targets])
assert self.DELTEME_all(["labels" in t for t in targets])
if self.has_mask():
assert self.DELTEME_all(["masks" in t for t in targets])
-def select_training_samples(self, proposals, targets):
-# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
+def select_training_samples(self,
+proposals, # type: List[Tensor]
+targets # type: Optional[List[Dict[str, Tensor]]]
+):
+# type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
self.check_targets(targets)
assert targets is not None
dtype = proposals[0].dtype
@@ -662,8 +665,13 @@ class RoIHeads(torch.nn.Module):
regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
return proposals, matched_idxs, labels, regression_targets
-def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
-# type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
+def postprocess_detections(self,
+class_logits, # type: Tensor
+box_regression, # type: Tensor
+proposals, # type: List[Tensor]
+image_shapes # type: List[Tuple[int, int]]
+):
+# type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
device = class_logits.device
num_classes = class_logits.shape[-1]
@@ -715,8 +723,13 @@ class RoIHeads(torch.nn.Module):
return all_boxes, all_scores, all_labels
-def forward(self, features, proposals, image_shapes, targets=None):
-# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
+def forward(self,
+features, # type: Dict[str, Tensor]
+proposals, # type: List[Tensor]
+image_shapes, # type: List[Tuple[int, int]]
+targets=None # type: Optional[List[Dict[str, Tensor]]]
+):
+# type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
"""
Arguments:
features (List[Tensor])
......
@@ -75,7 +75,7 @@ class AnchorGenerator(nn.Module):
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
# This method assumes aspect ratio = height / width for an anchor.
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
-# type: (List[int], List[float], int, Device) # noqa: F821
+# type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
@@ -88,7 +88,7 @@ class AnchorGenerator(nn.Module):
return base_anchors.round()
def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
@@ -114,7 +114,7 @@ class AnchorGenerator(nn.Module):
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
-# type: (List[List[int]], List[List[Tensor]])
+# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
@@ -147,7 +147,7 @@ class AnchorGenerator(nn.Module):
return anchors
def cached_grid_anchors(self, grid_sizes, strides):
-# type: (List[List[int]], List[List[Tensor]])
+# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
key = str(grid_sizes) + str(strides)
if key in self._cache:
return self._cache[key]
@@ -156,7 +156,7 @@ class AnchorGenerator(nn.Module):
return anchors
def forward(self, image_list, feature_maps):
-# type: (ImageList, List[Tensor])
+# type: (ImageList, List[Tensor]) -> List[Tensor]
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
@@ -200,7 +200,7 @@ class RPNHead(nn.Module):
torch.nn.init.constant_(l.bias, 0)
def forward(self, x):
-# type: (List[Tensor])
+# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
logits = []
bbox_reg = []
for feature in x:
@@ -211,7 +211,7 @@ class RPNHead(nn.Module):
def permute_and_flatten(layer, N, A, C, H, W):
-# type: (Tensor, int, int, int, int, int)
+# type: (Tensor, int, int, int, int, int) -> Tensor
layer = layer.view(N, -1, C, H, W)
layer = layer.permute(0, 3, 4, 1, 2)
layer = layer.reshape(N, -1, C)
@@ -219,7 +219,7 @@ def permute_and_flatten(layer, N, A, C, H, W):
def concat_box_prediction_layers(box_cls, box_regression):
-# type: (List[Tensor], List[Tensor])
+# type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
box_cls_flattened = []
box_regression_flattened = []
# for each feature level, permute the outputs to make them be in the
@@ -325,7 +325,7 @@ class RegionProposalNetwork(torch.nn.Module):
return self._post_nms_top_n['testing']
def assign_targets_to_anchors(self, anchors, targets):
-# type: (List[Tensor], List[Dict[str, Tensor]])
+# type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
labels = []
matched_gt_boxes = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
@@ -361,7 +361,7 @@ class RegionProposalNetwork(torch.nn.Module):
return labels, matched_gt_boxes
def _get_top_n_idx(self, objectness, num_anchors_per_level):
-# type: (Tensor, List[int])
+# type: (Tensor, List[int]) -> Tensor
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
@@ -376,7 +376,7 @@ class RegionProposalNetwork(torch.nn.Module):
return torch.cat(r, dim=1)
def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
-# type: (Tensor, Tensor, List[Tuple[int, int]], List[int])
+# type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
num_images = proposals.shape[0]
device = proposals.device
# do not backprop throught objectness
@@ -416,7 +416,7 @@ class RegionProposalNetwork(torch.nn.Module):
return final_boxes, final_scores
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
-# type: (Tensor, Tensor, List[Tensor], List[Tensor])
+# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Arguments:
objectness (Tensor)
@@ -453,8 +453,12 @@ class RegionProposalNetwork(torch.nn.Module):
return objectness_loss, box_loss
-def forward(self, images, features, targets=None):
-# type: (ImageList, Dict[str, Tensor], Optional[List[Dict[str, Tensor]]])
+def forward(self,
+images, # type: ImageList
+features, # type: Dict[str, Tensor]
+targets=None # type: Optional[List[Dict[str, Tensor]]]
+):
+# type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
"""
Arguments:
images (ImageList): images for which we want to compute the predictions
......
@@ -76,8 +76,11 @@ class GeneralizedRCNNTransform(nn.Module):
self.image_mean = image_mean
self.image_std = image_std
-def forward(self, images, targets=None):
-# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
+def forward(self,
+images, # type: List[Tensor]
+targets=None # type: Optional[List[Dict[str, Tensor]]]
+):
+# type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
images = [img for img in images]
for i in range(len(images)):
image = images[i]
@@ -109,7 +112,7 @@ class GeneralizedRCNNTransform(nn.Module):
return (image - mean[:, None, None]) / std[:, None, None]
def torch_choice(self, l):
-# type: (List[int])
+# type: (List[int]) -> int
"""
Implements `random.choice` via torch ops so it can be compiled with
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
@@ -119,7 +122,7 @@ class GeneralizedRCNNTransform(nn.Module):
return l[index]
def resize(self, image, target):
-# type: (Tensor, Optional[Dict[str, Tensor]])
+# type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
h, w = image.shape[-2:]
if self.training:
size = float(self.torch_choice(self.min_size))
@@ -178,7 +181,7 @@ class GeneralizedRCNNTransform(nn.Module):
return maxes
def batch_images(self, images, size_divisible=32):
-# type: (List[Tensor], int)
+# type: (List[Tensor], int) -> Tensor
if torchvision._is_tracing():
# batch_images() does not export well to ONNX
# call _onnx_batch_images() instead
@@ -197,8 +200,12 @@ class GeneralizedRCNNTransform(nn.Module):
return batched_imgs
-def postprocess(self, result, image_shapes, original_image_sizes):
-# type: (List[Dict[str, Tensor]], List[Tuple[int, int]], List[Tuple[int, int]])
+def postprocess(self,
+result, # type: List[Dict[str, Tensor]]
+image_shapes, # type: List[Tuple[int, int]]
+original_image_sizes # type: List[Tuple[int, int]]
+):
+# type: (...) -> List[Dict[str, Tensor]]
if self.training:
return result
for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
@@ -226,7 +233,7 @@ class GeneralizedRCNNTransform(nn.Module):
def resize_keypoints(keypoints, original_size, new_size):
-# type: (Tensor, List[int], List[int])
+# type: (Tensor, List[int], List[int]) -> Tensor
ratios = [
torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
@@ -245,7 +252,7 @@ def resize_keypoints(keypoints, original_size, new_size):
def resize_boxes(boxes, original_size, new_size):
-# type: (Tensor, List[int], List[int])
+# type: (Tensor, List[int], List[int]) -> Tensor
ratios = [
torch.tensor(s, dtype=torch.float32, device=boxes.device) /
torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
......
@@ -6,7 +6,7 @@ import torchvision
@torch.jit.script
def nms(boxes, scores, iou_threshold):
-# type: (Tensor, Tensor, float)
+# type: (Tensor, Tensor, float) -> Tensor
"""
Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU).
@@ -43,7 +43,7 @@ def nms(boxes, scores, iou_threshold):
@torch.jit.script
def batched_nms(boxes, scores, idxs, iou_threshold):
-# type: (Tensor, Tensor, Tensor, float)
+# type: (Tensor, Tensor, Tensor, float) -> Tensor
"""
Performs non-maximum suppression in a batched fashion.
@@ -85,7 +85,7 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
def remove_small_boxes(boxes, min_size):
-# type: (Tensor, float)
+# type: (Tensor, float) -> Tensor
"""
Remove boxes which contains at least one side smaller than min_size.
@@ -104,7 +104,7 @@ def remove_small_boxes(boxes, min_size):
def clip_boxes_to_image(boxes, size):
-# type: (Tensor, Tuple[int, int])
+# type: (Tensor, Tuple[int, int]) -> Tensor
"""
Clip boxes so that they lie inside an image of size `size`.
......
@@ -67,7 +67,7 @@ class FeaturePyramidNetwork(nn.Module):
self.extra_blocks = extra_blocks
def get_result_from_inner_blocks(self, x, idx):
-# type: (Tensor, int)
+# type: (Tensor, int) -> Tensor
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
@@ -86,7 +86,7 @@ class FeaturePyramidNetwork(nn.Module):
return out
def get_result_from_layer_blocks(self, x, idx):
-# type: (Tensor, int)
+# type: (Tensor, int) -> Tensor
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
@@ -105,7 +105,7 @@ class FeaturePyramidNetwork(nn.Module):
return out
def forward(self, x):
-# type: (Dict[str, Tensor])
+# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
"""
Computes the FPN for a set of feature maps.
@@ -164,7 +164,7 @@ class LastLevelMaxPool(ExtraFPNBlock):
Applies a max_pool2d on top of the last feature map
"""
def forward(self, x, y, names):
-# type: (List[Tensor], List[Tensor], List[str])
+# type: (List[Tensor], List[Tensor], List[str]) -> Tuple[List[Tensor], List[str]]
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names
......
@@ -33,7 +33,7 @@ def _onnx_merge_levels(levels, unmerged_results):
# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744/pytorch/issues/26744
def initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
-# type: (int, int, int, int, float)
+# type: (int, int, int, int, float) -> LevelMapper
return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
@@ -51,7 +51,7 @@ class LevelMapper(object):
"""
def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
-# type: (int, int, int, int, float)
+# type: (int, int, int, int, float) -> None
self.k_min = k_min
self.k_max = k_max
self.s0 = canonical_scale
@@ -59,7 +59,7 @@ class LevelMapper(object):
self.eps = eps
def __call__(self, boxlists):
-# type: (List[Tensor])
+# type: (List[Tensor]) -> Tensor
"""
Arguments:
boxlists (list[BoxList])
@@ -118,7 +118,7 @@ class MultiScaleRoIAlign(nn.Module):
self.map_levels = None
def convert_to_roi_format(self, boxes):
-# type: (List[Tensor])
+# type: (List[Tensor]) -> Tensor
concat_boxes = torch.cat(boxes, dim=0)
device, dtype = concat_boxes.device, concat_boxes.dtype
ids = torch.cat(
@@ -132,7 +132,7 @@ class MultiScaleRoIAlign(nn.Module):
return rois
def infer_scale(self, feature, original_size):
-# type: (Tensor, List[int])
+# type: (Tensor, List[int]) -> float
# assumption: the scale is of the form 2 ** (-k), with k integer
size = feature.shape[-2:]
possible_scales = torch.jit.annotate(List[float], [])
@@ -144,7 +144,7 @@ class MultiScaleRoIAlign(nn.Module):
return possible_scales[0]
def setup_scales(self, features, image_shapes):
-# type: (List[Tensor], List[Tuple[int, int]])
+# type: (List[Tensor], List[Tuple[int, int]]) -> None
assert len(image_shapes) != 0
max_x = 0
max_y = 0
@@ -162,7 +162,7 @@ class MultiScaleRoIAlign(nn.Module):
self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))
def forward(self, x, boxes, image_shapes):
-# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]])
+# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Tensor
"""
Arguments:
x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
......