Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
4b2ad55f
Unverified
Commit
4b2ad55f
authored
Nov 17, 2021
by
Joao Gomes
Committed by
GitHub
Nov 17, 2021
Browse files
Refactor poolers (#4951)
* refactoring methods from MultiScaleRoIAlign
parent
9841a907
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
26 deletions
+25
-26
torchvision/ops/poolers.py
torchvision/ops/poolers.py
+25
-26
No files found.
torchvision/ops/poolers.py
View file @
4b2ad55f
...
@@ -83,6 +83,29 @@ class LevelMapper:
...
@@ -83,6 +83,29 @@ class LevelMapper:
return
(
target_lvls
.
to
(
torch
.
int64
)
-
self
.
k_min
).
to
(
torch
.
int64
)
return
(
target_lvls
.
to
(
torch
.
int64
)
-
self
.
k_min
).
to
(
torch
.
int64
)
def
_convert_to_roi_format
(
boxes
:
List
[
Tensor
])
->
Tensor
:
concat_boxes
=
torch
.
cat
(
boxes
,
dim
=
0
)
device
,
dtype
=
concat_boxes
.
device
,
concat_boxes
.
dtype
ids
=
torch
.
cat
(
[
torch
.
full_like
(
b
[:,
:
1
],
i
,
dtype
=
dtype
,
layout
=
torch
.
strided
,
device
=
device
)
for
i
,
b
in
enumerate
(
boxes
)],
dim
=
0
,
)
rois
=
torch
.
cat
([
ids
,
concat_boxes
],
dim
=
1
)
return
rois
def
_infer_scale
(
feature
:
Tensor
,
original_size
:
List
[
int
])
->
float
:
# assumption: the scale is of the form 2 ** (-k), with k integer
size
=
feature
.
shape
[
-
2
:]
possible_scales
:
List
[
float
]
=
[]
for
s1
,
s2
in
zip
(
size
,
original_size
):
approx_scale
=
float
(
s1
)
/
float
(
s2
)
scale
=
2
**
float
(
torch
.
tensor
(
approx_scale
).
log2
().
round
())
possible_scales
.
append
(
scale
)
assert
possible_scales
[
0
]
==
possible_scales
[
1
]
return
possible_scales
[
0
]
class
MultiScaleRoIAlign
(
nn
.
Module
):
class
MultiScaleRoIAlign
(
nn
.
Module
):
"""
"""
Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
...
@@ -142,30 +165,6 @@ class MultiScaleRoIAlign(nn.Module):
...
@@ -142,30 +165,6 @@ class MultiScaleRoIAlign(nn.Module):
self
.
canonical_scale
=
canonical_scale
self
.
canonical_scale
=
canonical_scale
self
.
canonical_level
=
canonical_level
self
.
canonical_level
=
canonical_level
def
convert_to_roi_format
(
self
,
boxes
:
List
[
Tensor
])
->
Tensor
:
concat_boxes
=
torch
.
cat
(
boxes
,
dim
=
0
)
device
,
dtype
=
concat_boxes
.
device
,
concat_boxes
.
dtype
ids
=
torch
.
cat
(
[
torch
.
full_like
(
b
[:,
:
1
],
i
,
dtype
=
dtype
,
layout
=
torch
.
strided
,
device
=
device
)
for
i
,
b
in
enumerate
(
boxes
)
],
dim
=
0
,
)
rois
=
torch
.
cat
([
ids
,
concat_boxes
],
dim
=
1
)
return
rois
def
infer_scale
(
self
,
feature
:
Tensor
,
original_size
:
List
[
int
])
->
float
:
# assumption: the scale is of the form 2 ** (-k), with k integer
size
=
feature
.
shape
[
-
2
:]
possible_scales
:
List
[
float
]
=
[]
for
s1
,
s2
in
zip
(
size
,
original_size
):
approx_scale
=
float
(
s1
)
/
float
(
s2
)
scale
=
2
**
float
(
torch
.
tensor
(
approx_scale
).
log2
().
round
())
possible_scales
.
append
(
scale
)
assert
possible_scales
[
0
]
==
possible_scales
[
1
]
return
possible_scales
[
0
]
def
setup_scales
(
def
setup_scales
(
self
,
self
,
features
:
List
[
Tensor
],
features
:
List
[
Tensor
],
...
@@ -179,7 +178,7 @@ class MultiScaleRoIAlign(nn.Module):
...
@@ -179,7 +178,7 @@ class MultiScaleRoIAlign(nn.Module):
max_y
=
max
(
shape
[
1
],
max_y
)
max_y
=
max
(
shape
[
1
],
max_y
)
original_input_shape
=
(
max_x
,
max_y
)
original_input_shape
=
(
max_x
,
max_y
)
scales
=
[
self
.
infer_scale
(
feat
,
original_input_shape
)
for
feat
in
features
]
scales
=
[
_
infer_scale
(
feat
,
original_input_shape
)
for
feat
in
features
]
# get the levels in the feature map by leveraging the fact that the network always
# get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level.
# downsamples by a factor of 2 at each level.
lvl_min
=
-
torch
.
log2
(
torch
.
tensor
(
scales
[
0
],
dtype
=
torch
.
float32
)).
item
()
lvl_min
=
-
torch
.
log2
(
torch
.
tensor
(
scales
[
0
],
dtype
=
torch
.
float32
)).
item
()
...
@@ -216,7 +215,7 @@ class MultiScaleRoIAlign(nn.Module):
...
@@ -216,7 +215,7 @@ class MultiScaleRoIAlign(nn.Module):
if
k
in
self
.
featmap_names
:
if
k
in
self
.
featmap_names
:
x_filtered
.
append
(
v
)
x_filtered
.
append
(
v
)
num_levels
=
len
(
x_filtered
)
num_levels
=
len
(
x_filtered
)
rois
=
self
.
convert_to_roi_format
(
boxes
)
rois
=
_
convert_to_roi_format
(
boxes
)
if
self
.
scales
is
None
:
if
self
.
scales
is
None
:
self
.
setup_scales
(
x_filtered
,
image_shapes
)
self
.
setup_scales
(
x_filtered
,
image_shapes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment