Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
58afa511
Unverified
Commit
58afa511
authored
Apr 29, 2021
by
Prabhat Roy
Committed by
GitHub
Apr 29, 2021
Browse files
Removed caching from AnchorGenerator (#3745)
parent
cac8a97b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
14 deletions
+2
-14
torchvision/models/detection/anchor_utils.py
torchvision/models/detection/anchor_utils.py
+2
-14
No files found.
torchvision/models/detection/anchor_utils.py
View file @
58afa511
import
torch
import
torch
from
torch
import
nn
,
Tensor
from
torch
import
nn
,
Tensor
from
typing
import
List
,
Optional
,
Dict
from
typing
import
List
,
Optional
from
.image_list
import
ImageList
from
.image_list
import
ImageList
...
@@ -28,7 +28,6 @@ class AnchorGenerator(nn.Module):
...
@@ -28,7 +28,6 @@ class AnchorGenerator(nn.Module):
__annotations__
=
{
__annotations__
=
{
"cell_anchors"
:
Optional
[
List
[
torch
.
Tensor
]],
"cell_anchors"
:
Optional
[
List
[
torch
.
Tensor
]],
"_cache"
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
}
}
def
__init__
(
def
__init__
(
...
@@ -49,7 +48,6 @@ class AnchorGenerator(nn.Module):
...
@@ -49,7 +48,6 @@ class AnchorGenerator(nn.Module):
self
.
sizes
=
sizes
self
.
sizes
=
sizes
self
.
aspect_ratios
=
aspect_ratios
self
.
aspect_ratios
=
aspect_ratios
self
.
cell_anchors
=
None
self
.
cell_anchors
=
None
self
.
_cache
=
{}
# TODO: https://github.com/pytorch/pytorch/issues/26792
# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
...
@@ -131,14 +129,6 @@ class AnchorGenerator(nn.Module):
...
@@ -131,14 +129,6 @@ class AnchorGenerator(nn.Module):
return
anchors
return
anchors
def
cached_grid_anchors
(
self
,
grid_sizes
:
List
[
List
[
int
]],
strides
:
List
[
List
[
Tensor
]])
->
List
[
Tensor
]:
key
=
str
(
grid_sizes
)
+
str
(
strides
)
if
key
in
self
.
_cache
:
return
self
.
_cache
[
key
]
anchors
=
self
.
grid_anchors
(
grid_sizes
,
strides
)
self
.
_cache
[
key
]
=
anchors
return
anchors
def
forward
(
self
,
image_list
:
ImageList
,
feature_maps
:
List
[
Tensor
])
->
List
[
Tensor
]:
def
forward
(
self
,
image_list
:
ImageList
,
feature_maps
:
List
[
Tensor
])
->
List
[
Tensor
]:
grid_sizes
=
list
([
feature_map
.
shape
[
-
2
:]
for
feature_map
in
feature_maps
])
grid_sizes
=
list
([
feature_map
.
shape
[
-
2
:]
for
feature_map
in
feature_maps
])
image_size
=
image_list
.
tensors
.
shape
[
-
2
:]
image_size
=
image_list
.
tensors
.
shape
[
-
2
:]
...
@@ -146,12 +136,10 @@ class AnchorGenerator(nn.Module):
...
@@ -146,12 +136,10 @@ class AnchorGenerator(nn.Module):
strides
=
[[
torch
.
tensor
(
image_size
[
0
]
//
g
[
0
],
dtype
=
torch
.
int64
,
device
=
device
),
strides
=
[[
torch
.
tensor
(
image_size
[
0
]
//
g
[
0
],
dtype
=
torch
.
int64
,
device
=
device
),
torch
.
tensor
(
image_size
[
1
]
//
g
[
1
],
dtype
=
torch
.
int64
,
device
=
device
)]
for
g
in
grid_sizes
]
torch
.
tensor
(
image_size
[
1
]
//
g
[
1
],
dtype
=
torch
.
int64
,
device
=
device
)]
for
g
in
grid_sizes
]
self
.
set_cell_anchors
(
dtype
,
device
)
self
.
set_cell_anchors
(
dtype
,
device
)
anchors_over_all_feature_maps
=
self
.
cached_
grid_anchors
(
grid_sizes
,
strides
)
anchors_over_all_feature_maps
=
self
.
grid_anchors
(
grid_sizes
,
strides
)
anchors
:
List
[
List
[
torch
.
Tensor
]]
=
[]
anchors
:
List
[
List
[
torch
.
Tensor
]]
=
[]
for
i
in
range
(
len
(
image_list
.
image_sizes
)):
for
i
in
range
(
len
(
image_list
.
image_sizes
)):
anchors_in_image
=
[
anchors_per_feature_map
for
anchors_per_feature_map
in
anchors_over_all_feature_maps
]
anchors_in_image
=
[
anchors_per_feature_map
for
anchors_per_feature_map
in
anchors_over_all_feature_maps
]
anchors
.
append
(
anchors_in_image
)
anchors
.
append
(
anchors_in_image
)
anchors
=
[
torch
.
cat
(
anchors_per_image
)
for
anchors_per_image
in
anchors
]
anchors
=
[
torch
.
cat
(
anchors_per_image
)
for
anchors_per_image
in
anchors
]
# Clear the cache in case that memory leaks.
self
.
_cache
.
clear
()
return
anchors
return
anchors
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment