Unverified Commit 7c35e133 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Refactored set_cell_anchors() in AnchorGenerator (#3755)

* Refactored set_cell_anchors() in AnchorGenerator

* Addressed review comment

* Fixed test failure
parent 58afa511
import torch
from torch import nn, Tensor
from typing import List, Optional
from typing import List
from .image_list import ImageList
......@@ -27,7 +27,7 @@ class AnchorGenerator(nn.Module):
"""
__annotations__ = {
"cell_anchors": Optional[List[torch.Tensor]],
"cell_anchors": List[torch.Tensor],
}
def __init__(
......@@ -47,7 +47,8 @@ class AnchorGenerator(nn.Module):
self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = None
self.cell_anchors = [self.generate_anchors(size, aspect_ratio)
for size, aspect_ratio in zip(sizes, aspect_ratios)]
# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
......@@ -67,24 +68,8 @@ class AnchorGenerator(nn.Module):
return base_anchors.round()
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return
cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
dtype,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
]
self.cell_anchors = cell_anchors
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device)
for cell_anchor in self.cell_anchors]
def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
......@@ -130,7 +115,7 @@ class AnchorGenerator(nn.Module):
return anchors
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
......@@ -138,7 +123,7 @@ class AnchorGenerator(nn.Module):
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
anchors: List[List[torch.Tensor]] = []
for i in range(len(image_list.image_sizes)):
for _ in range(len(image_list.image_sizes)):
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment