OpenDAS / vision

Commit 58afa511, authored Apr 29, 2021 by Prabhat Roy, committed by GitHub on Apr 29, 2021.
Removed caching from AnchorGenerator (#3745)
Parent: cac8a97b
Showing 1 changed file with 2 additions and 14 deletions.

torchvision/models/detection/anchor_utils.py (+2, -14)
 import torch
 from torch import nn, Tensor
-from typing import List, Optional, Dict
+from typing import List, Optional
 from .image_list import ImageList

@@ -28,7 +28,6 @@ class AnchorGenerator(nn.Module):
     __annotations__ = {
         "cell_anchors": Optional[List[torch.Tensor]],
-        "_cache": Dict[str, List[torch.Tensor]]
     }

     def __init__(

@@ -49,7 +48,6 @@ class AnchorGenerator(nn.Module):
         self.sizes = sizes
         self.aspect_ratios = aspect_ratios
         self.cell_anchors = None
-        self._cache = {}

     # TODO: https://github.com/pytorch/pytorch/issues/26792
     # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.

@@ -131,14 +129,6 @@ class AnchorGenerator(nn.Module):
         return anchors

-    def cached_grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
-        key = str(grid_sizes) + str(strides)
-        if key in self._cache:
-            return self._cache[key]
-        anchors = self.grid_anchors(grid_sizes, strides)
-        self._cache[key] = anchors
-        return anchors
-
     def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]

@@ -146,12 +136,10 @@ class AnchorGenerator(nn.Module):
         strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                     torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
         self.set_cell_anchors(dtype, device)
-        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
+        anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
         anchors: List[List[torch.Tensor]] = []
         for i in range(len(image_list.image_sizes)):
             anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
             anchors.append(anchors_in_image)
         anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
-        # Clear the cache in case that memory leaks.
-        self._cache.clear()
         return anchors
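
After this change, forward calls grid_anchors directly, so the anchors for each feature-map grid are recomputed on every pass rather than looked up in the removed _cache dict. A minimal sketch of the post-change call path (not part of the commit; the sizes, aspect ratios, and tensor shapes below are illustrative assumptions):

import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

# One feature-map level: 3 sizes x 3 aspect ratios = 9 anchors per grid cell.
anchor_gen = AnchorGenerator(sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),))

images = torch.rand(2, 3, 224, 224)                      # batch of 2 images (illustrative)
image_list = ImageList(images, image_sizes=[(224, 224)] * 2)
feature_maps = [torch.rand(2, 256, 7, 7)]                # e.g. one backbone output level

# grid_anchors() now runs on every call; there is no cross-call cache to consult or clear.
anchors = anchor_gen(image_list, feature_maps)           # one Tensor per image
print(anchors[0].shape)                                  # torch.Size([441, 4]): 7*7*9 anchors

Because nothing is memoized anymore, repeated forward passes with the same grid sizes redo the anchor arithmetic, trading a small amount of compute for dropping the string-keyed cache and its "clear the cache in case that memory leaks" workaround entirely.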