Unverified Commit b3adace6 authored by Vasilis Vryniotis, committed by GitHub


Adding Python type hints, correcting incorrect types, removing unnecessary vars and simplifying code. (#3045)
parent b8b08ac3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.jit.annotations import List, Optional, Dict
 from .image_list import ImageList
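For reference, the pattern this commit applies throughout the file is the replacement of Mypy-style `# type:` comments with Python 3 annotations, which TorchScript accepts equally well. A minimal before/after sketch; the `scale_boxes` functions are hypothetical stand-ins, not torchvision code:

```python
import torch
from torch import Tensor


def scale_boxes_old(boxes, scale):
    # type: (Tensor, float) -> Tensor
    return boxes * scale


# Same function with the annotation style the commit introduces.
def scale_boxes_new(boxes: Tensor, scale: float) -> Tensor:
    return boxes * scale


# Both forms compile identically under TorchScript.
assert torch.equal(torch.jit.script(scale_boxes_old)(torch.ones(2, 4), 2.0),
                   torch.jit.script(scale_boxes_new)(torch.ones(2, 4), 2.0))
```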
@@ -56,8 +56,8 @@ class AnchorGenerator(nn.Module):
     # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
     # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
     # This method assumes aspect ratio = height / width for an anchor.
-    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
-        # type: (List[int], List[float], int, Device) -> Tensor  # noqa: F821
+    def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32,
+                         device: torch.device = torch.device("cpu")):
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
         aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)
@@ -69,8 +69,7 @@ class AnchorGenerator(nn.Module):
         base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
         return base_anchors.round()
 
-    def set_cell_anchors(self, dtype, device):
-        # type: (int, Device) -> None  # noqa: F821
+    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
         if self.cell_anchors is not None:
             cell_anchors = self.cell_anchors
             assert cell_anchors is not None
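To make the comments above concrete, here is a standalone sketch (not the torchvision code itself) of what `generate_anchors` produces for one scale and three aspect ratios, following the `base_anchors` formula visible in the hunk above; the intermediate `w_ratios`/`ws`/`hs` lines are collapsed in this diff, so their construction here is a best-effort reconstruction:

```python
import torch

scales = torch.as_tensor([128], dtype=torch.float32)
aspect_ratios = torch.as_tensor([0.5, 1.0, 2.0], dtype=torch.float32)

# h/w == aspect_ratio while the box area stays scale**2.
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1.0 / h_ratios
ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)

# Zero-centered boxes in (x1, y1, x2, y2) form.
base_anchors = (torch.stack([-ws, -hs, ws, hs], dim=1) / 2).round()
print(base_anchors)
# tensor([[-91., -45.,  91.,  45.],
#         [-64., -64.,  64.,  64.],
#         [-45., -91.,  45.,  91.]])
```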
@@ -95,8 +94,7 @@ class AnchorGenerator(nn.Module):
 
     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
     # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
-    def grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
+    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
         anchors = []
         cell_anchors = self.cell_anchors
         assert cell_anchors is not None
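The comment above describes a shifting scheme whose body is collapsed in this hunk. A minimal sketch of the idea, simplified relative to the real `grid_anchors` (which loops over feature maps and keeps tensors on the anchors' device):

```python
import torch

base_anchors = torch.tensor([[-64., -64., 64., 64.]])  # one cell anchor
grid_h, grid_w = 2, 3                                   # feature-map size
stride_h, stride_w = 16, 16                             # image px per cell

# One (x, y) offset per grid position, spaced one stride apart.
shifts_y = torch.arange(grid_h, dtype=torch.float32) * stride_h
shifts_x = torch.arange(grid_w, dtype=torch.float32) * stride_w
sy, sx = torch.meshgrid(shifts_y, shifts_x)
shifts = torch.stack([sx.reshape(-1), sy.reshape(-1),
                      sx.reshape(-1), sy.reshape(-1)], dim=1)

# Broadcast: (positions, 1, 4) + (1, cell_anchors, 4) -> flat anchor list.
anchors = (shifts[:, None, :] + base_anchors[None, :, :]).reshape(-1, 4)
print(anchors.shape)  # torch.Size([6, 4])
```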
@@ -134,8 +132,7 @@ class AnchorGenerator(nn.Module):
         return anchors
 
-    def cached_grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
+    def cached_grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
         key = str(grid_sizes) + str(strides)
         if key in self._cache:
             return self._cache[key]
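The cache key is simply the string form of the arguments, so structurally equal inputs map to the same entry. A small sketch of that property, outside the class:

```python
import torch

grid_sizes = [[20, 20]]
strides = [[torch.tensor(16), torch.tensor(16)]]

key_a = str(grid_sizes) + str(strides)
key_b = str([[20, 20]]) + str([[torch.tensor(16), torch.tensor(16)]])
print(key_a == key_b)  # True -> a repeated call reuses the cached anchors
```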
@@ -143,8 +140,7 @@ class AnchorGenerator(nn.Module):
         self._cache[key] = anchors
         return anchors
 
-    def forward(self, image_list, feature_maps):
-        # type: (ImageList, List[Tensor]) -> List[Tensor]
+    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
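Between this hunk and the next, `forward` derives per-level strides from the image and grid sizes; that derivation is collapsed here. A sketch with illustrative shapes, matching the `List[List[Tensor]]` type that `grid_anchors` expects:

```python
import torch

image_size = (320, 320)
feature_maps = [torch.rand(2, 256, 20, 20), torch.rand(2, 256, 10, 10)]

grid_sizes = [fm.shape[-2:] for fm in feature_maps]
# Stride = how many image pixels one feature-map cell covers, per axis.
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64),
            torch.tensor(image_size[1] // g[1], dtype=torch.int64)]
           for g in grid_sizes]
print(grid_sizes)  # [torch.Size([20, 20]), torch.Size([10, 10])]
print(strides)     # [[tensor(16), tensor(16)], [tensor(32), tensor(32)]]
```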
@@ -153,10 +149,8 @@ class AnchorGenerator(nn.Module):
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
         anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
-        for i, (image_height, image_width) in enumerate(image_list.image_sizes):
-            anchors_in_image = []
-            for anchors_per_feature_map in anchors_over_all_feature_maps:
-                anchors_in_image.append(anchors_per_feature_map)
+        for i in range(len(image_list.image_sizes)):
+            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
             anchors.append(anchors_in_image)
         anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
         # Clear the cache in case that memory leaks.
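Putting it together, a usage sketch for the annotated `forward`; module paths and shapes assume the torchvision detection API of this era (`torchvision.models.detection.anchor_utils`):

```python
import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

# 3 sizes x 3 aspect ratios -> 9 anchors per feature-map location.
anchor_gen = AnchorGenerator(sizes=((128, 256, 512),),
                             aspect_ratios=((0.5, 1.0, 2.0),))

images = torch.rand(2, 3, 320, 320)
image_list = ImageList(images, image_sizes=[(320, 320), (320, 320)])
feature_maps = [torch.rand(2, 256, 20, 20)]  # single level, stride 16

anchors = anchor_gen(image_list, feature_maps)
print(len(anchors))      # 2 (one tensor per image)
print(anchors[0].shape)  # torch.Size([3600, 4]) = 20 * 20 * 9 anchors
```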