Unverified commit 58afa511, authored by Prabhat Roy, committed by GitHub

Removed caching from AnchorGenerator (#3745)
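Context, inferred from the diff below rather than stated in the commit message: cached_grid_anchors keyed its cache by str(grid_sizes) + str(strides), and forward already cleared the cache at the end of every call to guard against memory leaks, so a cached entry could never be reused across calls. Removing the cache therefore deletes dead state without changing the anchors that are produced.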

parent cac8a97b
 import torch
 from torch import nn, Tensor
-from typing import List, Optional, Dict
+from typing import List, Optional
 from .image_list import ImageList
@@ -28,7 +28,6 @@ class AnchorGenerator(nn.Module):
     __annotations__ = {
         "cell_anchors": Optional[List[torch.Tensor]],
-        "_cache": Dict[str, List[torch.Tensor]]
     }

     def __init__(
@@ -49,7 +48,6 @@ class AnchorGenerator(nn.Module):
         self.sizes = sizes
         self.aspect_ratios = aspect_ratios
         self.cell_anchors = None
-        self._cache = {}

     # TODO: https://github.com/pytorch/pytorch/issues/26792
     # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
@@ -131,14 +129,6 @@ class AnchorGenerator(nn.Module):
         return anchors

-    def cached_grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
-        key = str(grid_sizes) + str(strides)
-        if key in self._cache:
-            return self._cache[key]
-        anchors = self.grid_anchors(grid_sizes, strides)
-        self._cache[key] = anchors
-        return anchors
-
     def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
@@ -146,12 +136,10 @@ class AnchorGenerator(nn.Module):
         strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                     torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
         self.set_cell_anchors(dtype, device)
-        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
+        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
         anchors: List[List[torch.Tensor]] = []
         for i in range(len(image_list.image_sizes)):
             anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
             anchors.append(anchors_in_image)
         anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
-        # Clear the cache in case that memory leaks.
-        self._cache.clear()
         return anchors
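For reference, a minimal usage sketch, not part of the commit: it shows that AnchorGenerator now recomputes grid anchors on every forward call instead of serving them from the removed _cache. The sizes, aspect ratios, and tensor shapes below are illustrative assumptions.

# Minimal sketch (illustrative, not from the commit): after this change,
# AnchorGenerator recomputes grid anchors on each forward call rather than
# looking them up in the removed _cache dict.
import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

# One feature-map level; 3 sizes x 3 aspect ratios = 9 anchors per location.
generator = AnchorGenerator(
    sizes=((32, 64, 128),),
    aspect_ratios=((0.5, 1.0, 2.0),),
)

images = torch.rand(2, 3, 224, 224)                       # padded batch of 2 images
image_list = ImageList(images, [(224, 224), (224, 224)])  # per-image sizes before padding
feature_maps = [torch.rand(2, 8, 28, 28)]                 # one stride-8 feature level

anchors = generator(image_list, feature_maps)  # List[Tensor], one entry per image
print(anchors[0].shape)  # torch.Size([7056, 4]) -> 28 * 28 * 9 anchors, 4 coords each

Calling the generator twice with differently sized inputs now simply recomputes the anchors; under the old code each distinct (grid_sizes, strides) string key would have added an entry to _cache.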