Unverified commit b3adace6, authored by Vasilis Vryniotis, committed by GitHub

Adding Python type hints, correcting incorrect types, removing unnecessary vars and simplifying code. (#3045)
parent b8b08ac3
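
The diff below replaces the mypy-style "# type:" comments that TorchScript parses with standard Python 3 annotations. A minimal sketch of that pattern, assuming a hypothetical helper named scale_boxes (not part of this commit):

import torch
from typing import List
from torch import Tensor

# Old style: the signature types live in a comment that TorchScript reads.
def scale_boxes_old(boxes, scale):
    # type: (Tensor, List[float]) -> Tensor
    return boxes * boxes.new_tensor(scale)

# New style: the same information expressed as inline annotations.
def scale_boxes(boxes: Tensor, scale: List[float]) -> Tensor:
    return boxes * boxes.new_tensor(scale)

print(scale_boxes(torch.ones(2, 4), [2.0, 2.0, 2.0, 2.0]))  # doubles every coordinate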
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.jit.annotations import List, Optional, Dict
 from .image_list import ImageList
@@ -56,8 +56,8 @@ class AnchorGenerator(nn.Module):
     # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
     # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
     # This method assumes aspect ratio = height / width for an anchor.
-    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
-        # type: (List[int], List[float], int, Device) -> Tensor  # noqa: F821
+    def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32,
+                         device: torch.device = torch.device("cpu")):
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
         aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)
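
For reference, a standalone re-computation of the same zero-centered anchor formula (an illustrative sketch, not a call into the library): with scale 32 and aspect ratios 1.0 and 2.0 (height / width), it yields the rounded base anchors shown in the comments.

import torch

scales = torch.tensor([32.0])
aspect_ratios = torch.tensor([1.0, 2.0])   # ratio = height / width

h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1.0 / h_ratios
ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)
base_anchors = (torch.stack([-ws, -hs, ws, hs], dim=1) / 2).round()
print(base_anchors)
# tensor([[-16., -16.,  16.,  16.],
#         [-11., -23.,  11.,  23.]])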
@@ -69,8 +69,7 @@ class AnchorGenerator(nn.Module):
         base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
         return base_anchors.round()

-    def set_cell_anchors(self, dtype, device):
-        # type: (int, Device) -> None  # noqa: F821
+    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
         if self.cell_anchors is not None:
             cell_anchors = self.cell_anchors
             assert cell_anchors is not None
@@ -95,8 +94,7 @@ class AnchorGenerator(nn.Module):
     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
     # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
-    def grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
+    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
         anchors = []
         cell_anchors = self.cell_anchors
         assert cell_anchors is not None
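
The comments above describe the grid-shift step: every base anchor is replicated at each stride-spaced grid location. A self-contained sketch of that idea (mirroring the approach, not the exact code from this file):

import torch

base_anchors = torch.tensor([[-16., -16., 16., 16.]])   # one zero-centered anchor
grid_h, grid_w = 2, 3
stride_h, stride_w = 16, 16

shifts_y = torch.arange(grid_h, dtype=torch.float32) * stride_h
shifts_x = torch.arange(grid_w, dtype=torch.float32) * stride_w
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shifts = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1),
                      shift_x.reshape(-1), shift_y.reshape(-1)), dim=1)

# One anchor per (grid cell, base anchor) pair, flattened to rows of [x1, y1, x2, y2].
anchors = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
print(anchors.shape)   # torch.Size([6, 4]) -> 2 * 3 grid cells * 1 base anchor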
@@ -134,8 +132,7 @@ class AnchorGenerator(nn.Module):
         return anchors

-    def cached_grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
+    def cached_grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
         key = str(grid_sizes) + str(strides)
         if key in self._cache:
             return self._cache[key]
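
A small illustration of the cache key used above (with plain Python values standing in for the real inputs): the grid sizes and strides are stringified so identical inputs map to the same dictionary entry.

import torch

grid_sizes = [[28, 28]]
strides = [[torch.tensor(8), torch.tensor(8)]]
key = str(grid_sizes) + str(strides)
print(key)   # [[28, 28]][[tensor(8), tensor(8)]]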
@@ -143,8 +140,7 @@ class AnchorGenerator(nn.Module):
         self._cache[key] = anchors
         return anchors

-    def forward(self, image_list, feature_maps):
-        # type: (ImageList, List[Tensor]) -> List[Tensor]
+    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
@@ -153,10 +149,8 @@ class AnchorGenerator(nn.Module):
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
         anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
-        for i, (image_height, image_width) in enumerate(image_list.image_sizes):
-            anchors_in_image = []
-            for anchors_per_feature_map in anchors_over_all_feature_maps:
-                anchors_in_image.append(anchors_per_feature_map)
+        for i in range(len(image_list.image_sizes)):
+            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
             anchors.append(anchors_in_image)
         anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
         # Clear the cache in case that memory leaks.
...
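
Putting it together, a usage sketch of the generator after this change. The import paths are an assumption about where AnchorGenerator and ImageList live (e.g. torchvision.models.detection.anchor_utils and .image_list); adjust them to the actual module layout of your checkout.

import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

# One feature-map level: 3 sizes x 3 aspect ratios = 9 anchors per grid cell.
anchor_gen = AnchorGenerator(sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),))

images = torch.rand(2, 3, 224, 224)                       # batch of two images
image_list = ImageList(images, [(224, 224), (224, 224)])  # padded tensors + per-image sizes
feature_maps = [torch.rand(2, 256, 28, 28)]               # a single 28x28 feature map

anchors = anchor_gen(image_list, feature_maps)            # List[Tensor], one entry per image
print(len(anchors), anchors[0].shape)                     # 2 torch.Size([7056, 4])  (28 * 28 * 9)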