OpenDAS / vision

Commit 2611f5cc (unverified)
Authored Mar 10, 2020 by Soham Tamba; committed by GitHub on Mar 10, 2020

Commented AnchorGenerator (#1941)

Parent: 7d1cd1de
Showing 1 changed file with 9 additions and 0 deletions
torchvision/models/detection/rpn.py  +9  -0
@@ -74,6 +74,8 @@ class AnchorGenerator(nn.Module):
         self._cache = {}
 
     # TODO: https://github.com/pytorch/pytorch/issues/26792
+    # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
+    # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
     def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
         # type: (List[int], List[float], int, Device) # noqa: F821
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
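For reference, the behaviour that the two new comments describe can be sketched as a standalone snippet. This is an illustration only, not the body of generate_anchors from this file; the helper name zero_centered_anchors and the example values are made up here.

import torch

def zero_centered_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
    # One zero-centered box [-w/2, -h/2, w/2, h/2] per (scale, aspect_ratio) combination.
    scales = torch.as_tensor(scales, dtype=dtype, device=device)
    aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
    h_ratios = torch.sqrt(aspect_ratios)   # treating aspect_ratio as height / width
    w_ratios = 1.0 / h_ratios
    ws = (w_ratios[:, None] * scales[None, :]).view(-1)
    hs = (h_ratios[:, None] * scales[None, :]).view(-1)
    return torch.stack([-ws, -hs, ws, hs], dim=1) / 2

print(zero_centered_anchors([32, 64], [0.5, 1.0, 2.0]).shape)  # torch.Size([6, 4])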
@@ -111,6 +113,8 @@ class AnchorGenerator(nn.Module):
     def num_anchors_per_location(self):
         return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
 
+    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
+    # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
     def grid_anchors(self, grid_sizes, strides):
         # type: (List[List[int]], List[List[int]])
         anchors = []
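The unchanged num_anchors_per_location line in this hunk shows the count that the new comment implies: every spatial location of a feature map gets one anchor per (size, aspect_ratio) combination. A small illustration with made-up values:

# Two sizes and three aspect ratios per feature map give 2 * 3 = 6 anchors
# at every spatial location of that feature map.
sizes = ((32, 64),)
aspect_ratios = ((0.5, 1.0, 2.0),)
print([len(s) * len(a) for s, a in zip(sizes, aspect_ratios)])  # [6]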
@@ -127,6 +131,8 @@ class AnchorGenerator(nn.Module):
             stride_width = torch.tensor(stride_width, dtype=torch.float32)
             stride_height = torch.tensor(stride_height, dtype=torch.float32)
             device = base_anchors.device
+
+            # For output anchor, compute [x_center, y_center, x_center, y_center]
             shifts_x = torch.arange(
                 0, grid_width, dtype=torch.float32, device=device
             ) * stride_width
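The new comment refers to the grid of output-anchor centers built from these shifts. A toy example of the arange-times-stride pattern above, with made-up values and device handling omitted:

import torch

grid_width = 4
stride_width = torch.tensor(16, dtype=torch.float32)
# x-centers of the output anchors land every stride_width pixels in the input image.
shifts_x = torch.arange(0, grid_width, dtype=torch.float32) * stride_width
print(shifts_x)  # tensor([ 0., 16., 32., 48.])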
@@ -138,6 +144,8 @@ class AnchorGenerator(nn.Module):
             shift_y = shift_y.reshape(-1)
             shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
 
+            # For every (base anchor, output anchor) pair,
+            # offset each zero-centered base anchor by the center of the output anchor.
             anchors.append(
                 (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
             )
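The broadcast inside anchors.append(...) is what the two new comments describe: every grid shift is paired with every zero-centered base anchor. A minimal sketch with made-up numbers:

import torch

shifts = torch.tensor([[0., 0., 0., 0.],
                       [16., 16., 16., 16.]])        # 2 output-anchor centers
base_anchors = torch.tensor([[-8., -8., 8., 8.],
                             [-16., -4., 16., 4.]])  # 2 zero-centered shapes
# (2, 1, 4) + (1, 2, 4) broadcasts to (2, 2, 4): one box per (shift, base anchor) pair.
out = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
print(out.shape)  # torch.Size([4, 4])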
@@ -158,6 +166,7 @@ class AnchorGenerator(nn.Module):
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
         strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
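For context on the unchanged forward() lines above: the stride of each feature map is recovered by dividing the input image size by the feature-map (grid) size. A small sketch with made-up shapes:

import torch

feature_maps = [torch.zeros(1, 256, 50, 50)]   # one 50x50 feature map
image_size = (800, 800)                        # input image H x W
grid_sizes = [fm.shape[-2:] for fm in feature_maps]
strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
print(strides)  # [[16, 16]]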