You need to sign in or sign up before continuing.
Unverified Commit 7c35e133 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Refactored set_cell_anchors() in AnchorGenerator (#3755)

* Refactored set_cell_anchors() in AnchorGenerator

* Addressed review comment

* Fixed test failure
parent 58afa511
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from typing import List, Optional from typing import List
from .image_list import ImageList from .image_list import ImageList
...@@ -27,7 +27,7 @@ class AnchorGenerator(nn.Module): ...@@ -27,7 +27,7 @@ class AnchorGenerator(nn.Module):
""" """
__annotations__ = { __annotations__ = {
"cell_anchors": Optional[List[torch.Tensor]], "cell_anchors": List[torch.Tensor],
} }
def __init__( def __init__(
...@@ -47,7 +47,8 @@ class AnchorGenerator(nn.Module): ...@@ -47,7 +47,8 @@ class AnchorGenerator(nn.Module):
self.sizes = sizes self.sizes = sizes
self.aspect_ratios = aspect_ratios self.aspect_ratios = aspect_ratios
self.cell_anchors = None self.cell_anchors = [self.generate_anchors(size, aspect_ratio)
for size, aspect_ratio in zip(sizes, aspect_ratios)]
# TODO: https://github.com/pytorch/pytorch/issues/26792 # TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
...@@ -67,24 +68,8 @@ class AnchorGenerator(nn.Module): ...@@ -67,24 +68,8 @@ class AnchorGenerator(nn.Module):
return base_anchors.round() return base_anchors.round()
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
if self.cell_anchors is not None: self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device)
cell_anchors = self.cell_anchors for cell_anchor in self.cell_anchors]
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return
cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
dtype,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
]
self.cell_anchors = cell_anchors
def num_anchors_per_location(self): def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
...@@ -130,7 +115,7 @@ class AnchorGenerator(nn.Module): ...@@ -130,7 +115,7 @@ class AnchorGenerator(nn.Module):
return anchors return anchors
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
image_size = image_list.tensors.shape[-2:] image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
...@@ -138,7 +123,7 @@ class AnchorGenerator(nn.Module): ...@@ -138,7 +123,7 @@ class AnchorGenerator(nn.Module):
self.set_cell_anchors(dtype, device) self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides) anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
anchors: List[List[torch.Tensor]] = [] anchors: List[List[torch.Tensor]] = []
for i in range(len(image_list.image_sizes)): for _ in range(len(image_list.image_sizes)):
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps] anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
anchors.append(anchors_in_image) anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment