Unverified Commit f71316fa authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix mypy type annotations (#1696)



* Fix mypy type annotations

* follow torchscript Tuple type

* redefine torch_choice output type

* change the type in cached_grid_anchors

* minor bug
Co-authored-by: default avatarGuanheng Zhang <zhangguanheng@devfair0197.h2.fair>
Co-authored-by: default avatarGuanheng Zhang <zhangguanheng@learnfair0341.h2.fair>
parent 3ac864dc
......@@ -20,7 +20,7 @@ class BalancedPositiveNegativeSampler(object):
"""
def __init__(self, batch_size_per_image, positive_fraction):
# type: (int, float)
# type: (int, float) -> None
"""
Arguments:
batch_size_per_image (int): number of elements to be selected per image
......@@ -30,7 +30,7 @@ class BalancedPositiveNegativeSampler(object):
self.positive_fraction = positive_fraction
def __call__(self, matched_idxs):
# type: (List[Tensor])
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
"""
Arguments:
matched idxs: list of tensors containing -1, 0 or positive values.
......@@ -139,7 +139,7 @@ class BoxCoder(object):
"""
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
# type: (Tuple[float, float, float, float], float)
# type: (Tuple[float, float, float, float], float) -> None
"""
Arguments:
weights (4-element tuple)
......@@ -149,7 +149,7 @@ class BoxCoder(object):
self.bbox_xform_clip = bbox_xform_clip
def encode(self, reference_boxes, proposals):
# type: (List[Tensor], List[Tensor])
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
boxes_per_image = [len(b) for b in reference_boxes]
reference_boxes = torch.cat(reference_boxes, dim=0)
proposals = torch.cat(proposals, dim=0)
......@@ -173,7 +173,7 @@ class BoxCoder(object):
return targets
def decode(self, rel_codes, boxes):
# type: (Tensor, List[Tensor])
# type: (Tensor, List[Tensor]) -> Tensor
assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [b.size(0) for b in boxes]
......@@ -251,7 +251,7 @@ class Matcher(object):
}
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
# type: (float, float, bool)
# type: (float, float, bool) -> None
"""
Args:
high_threshold (float): quality values greater than or equal to
......
......@@ -42,7 +42,7 @@ class GeneralizedRCNN(nn.Module):
return detections
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
"""
Arguments:
images (list[Tensor]): images to be processed
......
......@@ -14,7 +14,7 @@ class ImageList(object):
"""
def __init__(self, tensors, image_sizes):
# type: (Tensor, List[Tuple[int, int]])
# type: (Tensor, List[Tuple[int, int]]) -> None
"""
Arguments:
tensors (tensor)
......@@ -24,6 +24,6 @@ class ImageList(object):
self.image_sizes = image_sizes
def to(self, device):
# type: (Device) # noqa
# type: (Device) -> ImageList # noqa
cast_tensor = self.tensors.to(device)
return ImageList(cast_tensor, self.image_sizes)
......@@ -15,7 +15,7 @@ from torch.jit.annotations import Optional, List, Dict, Tuple
def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor])
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Computes the loss for Faster R-CNN.
......@@ -55,7 +55,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
def maskrcnn_inference(x, labels):
# type: (Tensor, List[Tensor])
# type: (Tensor, List[Tensor]) -> List[Tensor]
"""
From the results of the CNN, post process the masks
by taking the mask corresponding to the class with max
......@@ -85,7 +85,7 @@ def maskrcnn_inference(x, labels):
def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
# type: (Tensor, Tensor, Tensor, int)
# type: (Tensor, Tensor, Tensor, int) -> Tensor
"""
Given segmentation masks and the bounding boxes corresponding
to the location of the masks in the image, this function
......@@ -100,7 +100,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs):
# type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor])
# type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) -> Tensor
"""
Arguments:
proposals (list[BoxList])
......@@ -133,7 +133,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs
def keypoints_to_heatmap(keypoints, rois, heatmap_size):
# type: (Tensor, Tensor, int)
# type: (Tensor, Tensor, int) -> Tuple[Tensor, Tensor]
offset_x = rois[:, 0]
offset_y = rois[:, 1]
scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
......@@ -277,7 +277,7 @@ def heatmaps_to_keypoints(maps, rois):
def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
# type: (Tensor, List[Tensor], List[Tensor], List[Tensor])
# type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
N, K, H, W = keypoint_logits.shape
assert H == W
discretization_size = H
......@@ -307,7 +307,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched
def keypointrcnn_inference(x, boxes):
# type: (Tensor, List[Tensor])
# type: (Tensor, List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
kp_probs = []
kp_scores = []
......@@ -323,7 +323,7 @@ def keypointrcnn_inference(x, boxes):
def _onnx_expand_boxes(boxes, scale):
# type: (Tensor, float)
# type: (Tensor, float) -> Tensor
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
h_half = (boxes[:, 3] - boxes[:, 1]) * .5
x_c = (boxes[:, 2] + boxes[:, 0]) * .5
......@@ -344,7 +344,7 @@ def _onnx_expand_boxes(boxes, scale):
# but are kept here for the moment while we need them
# temporarily for paste_mask_in_image
def expand_boxes(boxes, scale):
# type: (Tensor, float)
# type: (Tensor, float) -> Tensor
if torchvision._is_tracing():
return _onnx_expand_boxes(boxes, scale)
w_half = (boxes[:, 2] - boxes[:, 0]) * .5
......@@ -370,7 +370,7 @@ def expand_masks_tracing_scale(M, padding):
def expand_masks(mask, padding):
# type: (Tensor, int)
# type: (Tensor, int) -> Tuple[Tensor, float]
M = mask.shape[-1]
if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why
scale = expand_masks_tracing_scale(M, padding)
......@@ -381,7 +381,7 @@ def expand_masks(mask, padding):
def paste_mask_in_image(mask, box, im_h, im_w):
# type: (Tensor, Tensor, int, int)
# type: (Tensor, Tensor, int, int) -> Tensor
TO_REMOVE = 1
w = int(box[2] - box[0] + TO_REMOVE)
h = int(box[3] - box[1] + TO_REMOVE)
......@@ -459,7 +459,7 @@ def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w):
def paste_masks_in_image(masks, boxes, img_shape, padding=1):
# type: (Tensor, Tensor, Tuple[int, int], int)
# type: (Tensor, Tensor, Tuple[int, int], int) -> Tensor
masks, scale = expand_masks(masks, padding=padding)
boxes = expand_boxes(boxes, scale).to(dtype=torch.int64)
im_h, im_w = img_shape
......@@ -558,7 +558,7 @@ class RoIHeads(torch.nn.Module):
return True
def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
# type: (List[Tensor], List[Tensor], List[Tensor])
# type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
matched_idxs = []
labels = []
for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
......@@ -595,7 +595,7 @@ class RoIHeads(torch.nn.Module):
return matched_idxs, labels
def subsample(self, labels):
# type: (List[Tensor])
# type: (List[Tensor]) -> List[Tensor]
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
sampled_inds = []
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
......@@ -606,7 +606,7 @@ class RoIHeads(torch.nn.Module):
return sampled_inds
def add_gt_proposals(self, proposals, gt_boxes):
# type: (List[Tensor], List[Tensor])
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
proposals = [
torch.cat((proposal, gt_box))
for proposal, gt_box in zip(proposals, gt_boxes)
......@@ -615,22 +615,25 @@ class RoIHeads(torch.nn.Module):
return proposals
def DELTEME_all(self, the_list):
# type: (List[bool])
# type: (List[bool]) -> bool
for i in the_list:
if not i:
return False
return True
def check_targets(self, targets):
# type: (Optional[List[Dict[str, Tensor]]])
# type: (Optional[List[Dict[str, Tensor]]]) -> None
assert targets is not None
assert self.DELTEME_all(["boxes" in t for t in targets])
assert self.DELTEME_all(["labels" in t for t in targets])
if self.has_mask():
assert self.DELTEME_all(["masks" in t for t in targets])
def select_training_samples(self, proposals, targets):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
def select_training_samples(self,
proposals, # type: List[Tensor]
targets # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
self.check_targets(targets)
assert targets is not None
dtype = proposals[0].dtype
......@@ -662,8 +665,13 @@ class RoIHeads(torch.nn.Module):
regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
return proposals, matched_idxs, labels, regression_targets
def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
# type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]])
def postprocess_detections(self,
class_logits, # type: Tensor
box_regression, # type: Tensor
proposals, # type: List[Tensor]
image_shapes # type: List[Tuple[int, int]]
):
# type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]
device = class_logits.device
num_classes = class_logits.shape[-1]
......@@ -715,8 +723,13 @@ class RoIHeads(torch.nn.Module):
return all_boxes, all_scores, all_labels
def forward(self, features, proposals, image_shapes, targets=None):
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]])
def forward(self,
features, # type: Dict[str, Tensor]
proposals, # type: List[Tensor]
image_shapes, # type: List[Tuple[int, int]]
targets=None # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
"""
Arguments:
features (List[Tensor])
......
......@@ -75,7 +75,7 @@ class AnchorGenerator(nn.Module):
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
# This method assumes aspect ratio = height / width for an anchor.
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
# type: (List[int], List[float], int, Device) # noqa: F821
# type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
......@@ -88,7 +88,7 @@ class AnchorGenerator(nn.Module):
return base_anchors.round()
def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
......@@ -114,7 +114,7 @@ class AnchorGenerator(nn.Module):
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]])
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
......@@ -147,7 +147,7 @@ class AnchorGenerator(nn.Module):
return anchors
def cached_grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]])
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
key = str(grid_sizes) + str(strides)
if key in self._cache:
return self._cache[key]
......@@ -156,7 +156,7 @@ class AnchorGenerator(nn.Module):
return anchors
def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor])
# type: (ImageList, List[Tensor]) -> List[Tensor]
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
......@@ -200,7 +200,7 @@ class RPNHead(nn.Module):
torch.nn.init.constant_(l.bias, 0)
def forward(self, x):
# type: (List[Tensor])
# type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
logits = []
bbox_reg = []
for feature in x:
......@@ -211,7 +211,7 @@ class RPNHead(nn.Module):
def permute_and_flatten(layer, N, A, C, H, W):
# type: (Tensor, int, int, int, int, int)
# type: (Tensor, int, int, int, int, int) -> Tensor
layer = layer.view(N, -1, C, H, W)
layer = layer.permute(0, 3, 4, 1, 2)
layer = layer.reshape(N, -1, C)
......@@ -219,7 +219,7 @@ def permute_and_flatten(layer, N, A, C, H, W):
def concat_box_prediction_layers(box_cls, box_regression):
# type: (List[Tensor], List[Tensor])
# type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
box_cls_flattened = []
box_regression_flattened = []
# for each feature level, permute the outputs to make them be in the
......@@ -325,7 +325,7 @@ class RegionProposalNetwork(torch.nn.Module):
return self._post_nms_top_n['testing']
def assign_targets_to_anchors(self, anchors, targets):
# type: (List[Tensor], List[Dict[str, Tensor]])
# type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
labels = []
matched_gt_boxes = []
for anchors_per_image, targets_per_image in zip(anchors, targets):
......@@ -361,7 +361,7 @@ class RegionProposalNetwork(torch.nn.Module):
return labels, matched_gt_boxes
def _get_top_n_idx(self, objectness, num_anchors_per_level):
# type: (Tensor, List[int])
# type: (Tensor, List[int]) -> Tensor
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
......@@ -376,7 +376,7 @@ class RegionProposalNetwork(torch.nn.Module):
return torch.cat(r, dim=1)
def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
# type: (Tensor, Tensor, List[Tuple[int, int]], List[int])
# type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
num_images = proposals.shape[0]
device = proposals.device
# do not backprop throught objectness
......@@ -416,7 +416,7 @@ class RegionProposalNetwork(torch.nn.Module):
return final_boxes, final_scores
def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor])
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Arguments:
objectness (Tensor)
......@@ -453,8 +453,12 @@ class RegionProposalNetwork(torch.nn.Module):
return objectness_loss, box_loss
def forward(self, images, features, targets=None):
# type: (ImageList, Dict[str, Tensor], Optional[List[Dict[str, Tensor]]])
def forward(self,
images, # type: ImageList
features, # type: Dict[str, Tensor]
targets=None # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
"""
Arguments:
images (ImageList): images for which we want to compute the predictions
......
......@@ -76,8 +76,11 @@ class GeneralizedRCNNTransform(nn.Module):
self.image_mean = image_mean
self.image_std = image_std
def forward(self, images, targets=None):
# type: (List[Tensor], Optional[List[Dict[str, Tensor]]])
def forward(self,
images, # type: List[Tensor]
targets=None # type: Optional[List[Dict[str, Tensor]]]
):
# type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
images = [img for img in images]
for i in range(len(images)):
image = images[i]
......@@ -109,7 +112,7 @@ class GeneralizedRCNNTransform(nn.Module):
return (image - mean[:, None, None]) / std[:, None, None]
def torch_choice(self, l):
# type: (List[int])
# type: (List[int]) -> int
"""
Implements `random.choice` via torch ops so it can be compiled with
TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
......@@ -119,7 +122,7 @@ class GeneralizedRCNNTransform(nn.Module):
return l[index]
def resize(self, image, target):
# type: (Tensor, Optional[Dict[str, Tensor]])
# type: (Tensor, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
h, w = image.shape[-2:]
if self.training:
size = float(self.torch_choice(self.min_size))
......@@ -178,7 +181,7 @@ class GeneralizedRCNNTransform(nn.Module):
return maxes
def batch_images(self, images, size_divisible=32):
# type: (List[Tensor], int)
# type: (List[Tensor], int) -> Tensor
if torchvision._is_tracing():
# batch_images() does not export well to ONNX
# call _onnx_batch_images() instead
......@@ -197,8 +200,12 @@ class GeneralizedRCNNTransform(nn.Module):
return batched_imgs
def postprocess(self, result, image_shapes, original_image_sizes):
# type: (List[Dict[str, Tensor]], List[Tuple[int, int]], List[Tuple[int, int]])
def postprocess(self,
result, # type: List[Dict[str, Tensor]]
image_shapes, # type: List[Tuple[int, int]]
original_image_sizes # type: List[Tuple[int, int]]
):
# type: (...) -> List[Dict[str, Tensor]]
if self.training:
return result
for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
......@@ -226,7 +233,7 @@ class GeneralizedRCNNTransform(nn.Module):
def resize_keypoints(keypoints, original_size, new_size):
# type: (Tensor, List[int], List[int])
# type: (Tensor, List[int], List[int]) -> Tensor
ratios = [
torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
......@@ -245,7 +252,7 @@ def resize_keypoints(keypoints, original_size, new_size):
def resize_boxes(boxes, original_size, new_size):
# type: (Tensor, List[int], List[int])
# type: (Tensor, List[int], List[int]) -> Tensor
ratios = [
torch.tensor(s, dtype=torch.float32, device=boxes.device) /
torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
......
......@@ -6,7 +6,7 @@ import torchvision
@torch.jit.script
def nms(boxes, scores, iou_threshold):
# type: (Tensor, Tensor, float)
# type: (Tensor, Tensor, float) -> Tensor
"""
Performs non-maximum suppression (NMS) on the boxes according
to their intersection-over-union (IoU).
......@@ -43,7 +43,7 @@ def nms(boxes, scores, iou_threshold):
@torch.jit.script
def batched_nms(boxes, scores, idxs, iou_threshold):
# type: (Tensor, Tensor, Tensor, float)
# type: (Tensor, Tensor, Tensor, float) -> Tensor
"""
Performs non-maximum suppression in a batched fashion.
......@@ -85,7 +85,7 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
def remove_small_boxes(boxes, min_size):
# type: (Tensor, float)
# type: (Tensor, float) -> Tensor
"""
Remove boxes which contains at least one side smaller than min_size.
......@@ -104,7 +104,7 @@ def remove_small_boxes(boxes, min_size):
def clip_boxes_to_image(boxes, size):
# type: (Tensor, Tuple[int, int])
# type: (Tensor, Tuple[int, int]) -> Tensor
"""
Clip boxes so that they lie inside an image of size `size`.
......
......@@ -67,7 +67,7 @@ class FeaturePyramidNetwork(nn.Module):
self.extra_blocks = extra_blocks
def get_result_from_inner_blocks(self, x, idx):
# type: (Tensor, int)
# type: (Tensor, int) -> Tensor
"""
This is equivalent to self.inner_blocks[idx](x),
but torchscript doesn't support this yet
......@@ -86,7 +86,7 @@ class FeaturePyramidNetwork(nn.Module):
return out
def get_result_from_layer_blocks(self, x, idx):
# type: (Tensor, int)
# type: (Tensor, int) -> Tensor
"""
This is equivalent to self.layer_blocks[idx](x),
but torchscript doesn't support this yet
......@@ -105,7 +105,7 @@ class FeaturePyramidNetwork(nn.Module):
return out
def forward(self, x):
# type: (Dict[str, Tensor])
# type: (Dict[str, Tensor]) -> Dict[str, Tensor]
"""
Computes the FPN for a set of feature maps.
......@@ -164,7 +164,7 @@ class LastLevelMaxPool(ExtraFPNBlock):
Applies a max_pool2d on top of the last feature map
"""
def forward(self, x, y, names):
# type: (List[Tensor], List[Tensor], List[str])
# type: (List[Tensor], List[Tensor], List[str]) -> Tuple[List[Tensor], List[str]]
names.append("pool")
x.append(F.max_pool2d(x[-1], 1, 2, 0))
return x, names
......
......@@ -33,7 +33,7 @@ def _onnx_merge_levels(levels, unmerged_results):
# TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744/pytorch/issues/26744
def initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
# type: (int, int, int, int, float)
# type: (int, int, int, int, float) -> LevelMapper
return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps)
......@@ -51,7 +51,7 @@ class LevelMapper(object):
"""
def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
# type: (int, int, int, int, float)
# type: (int, int, int, int, float) -> None
self.k_min = k_min
self.k_max = k_max
self.s0 = canonical_scale
......@@ -59,7 +59,7 @@ class LevelMapper(object):
self.eps = eps
def __call__(self, boxlists):
# type: (List[Tensor])
# type: (List[Tensor]) -> Tensor
"""
Arguments:
boxlists (list[BoxList])
......@@ -118,7 +118,7 @@ class MultiScaleRoIAlign(nn.Module):
self.map_levels = None
def convert_to_roi_format(self, boxes):
# type: (List[Tensor])
# type: (List[Tensor]) -> Tensor
concat_boxes = torch.cat(boxes, dim=0)
device, dtype = concat_boxes.device, concat_boxes.dtype
ids = torch.cat(
......@@ -132,7 +132,7 @@ class MultiScaleRoIAlign(nn.Module):
return rois
def infer_scale(self, feature, original_size):
# type: (Tensor, List[int])
# type: (Tensor, List[int]) -> float
# assumption: the scale is of the form 2 ** (-k), with k integer
size = feature.shape[-2:]
possible_scales = torch.jit.annotate(List[float], [])
......@@ -144,7 +144,7 @@ class MultiScaleRoIAlign(nn.Module):
return possible_scales[0]
def setup_scales(self, features, image_shapes):
# type: (List[Tensor], List[Tuple[int, int]])
# type: (List[Tensor], List[Tuple[int, int]]) -> None
assert len(image_shapes) != 0
max_x = 0
max_y = 0
......@@ -162,7 +162,7 @@ class MultiScaleRoIAlign(nn.Module):
self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max))
def forward(self, x, boxes, image_shapes):
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]])
# type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> Tensor
"""
Arguments:
x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment