Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
a7a36756
Unverified
Commit
a7a36756
authored
Nov 30, 2021
by
Joao Gomes
Committed by
GitHub
Nov 30, 2021
Browse files
Feature extraction default arguments - ops (#4810)
making torchvision ops leaf nodes by default
parent
39cf02a6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
303 additions
and
94 deletions
+303
-94
test/test_ops.py
test/test_ops.py
+121
-4
torchvision/models/feature_extraction.py
torchvision/models/feature_extraction.py
+35
-3
torchvision/ops/poolers.py
torchvision/ops/poolers.py
+147
-87
No files found.
test/test_ops.py
View file @
a7a36756
...
@@ -7,12 +7,54 @@ from typing import Tuple
...
@@ -7,12 +7,54 @@ from typing import Tuple
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
import
torch.fx
from
common_utils
import
needs_cuda
,
cpu_and_gpu
,
assert_equal
from
common_utils
import
needs_cuda
,
cpu_and_gpu
,
assert_equal
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
,
Tensor
from
torch
import
nn
,
Tensor
from
torch.autograd
import
gradcheck
from
torch.autograd
import
gradcheck
from
torch.nn.modules.utils
import
_pair
from
torch.nn.modules.utils
import
_pair
from
torchvision
import
models
,
ops
from
torchvision
import
models
,
ops
from
torchvision.models.feature_extraction
import
get_graph_node_names
class
RoIOpTesterModuleWrapper
(
nn
.
Module
):
def
__init__
(
self
,
obj
):
super
().
__init__
()
self
.
layer
=
obj
self
.
n_inputs
=
2
def
forward
(
self
,
a
,
b
):
self
.
layer
(
a
,
b
)
class
MultiScaleRoIAlignModuleWrapper
(
nn
.
Module
):
def
__init__
(
self
,
obj
):
super
().
__init__
()
self
.
layer
=
obj
self
.
n_inputs
=
3
def
forward
(
self
,
a
,
b
,
c
):
self
.
layer
(
a
,
b
,
c
)
class
DeformConvModuleWrapper
(
nn
.
Module
):
def
__init__
(
self
,
obj
):
super
().
__init__
()
self
.
layer
=
obj
self
.
n_inputs
=
3
def
forward
(
self
,
a
,
b
,
c
):
self
.
layer
(
a
,
b
,
c
)
class
StochasticDepthWrapper
(
nn
.
Module
):
def
__init__
(
self
,
obj
):
super
().
__init__
()
self
.
layer
=
obj
self
.
n_inputs
=
1
def
forward
(
self
,
a
):
self
.
layer
(
a
)
class
RoIOpTester
(
ABC
):
class
RoIOpTester
(
ABC
):
...
@@ -46,6 +88,15 @@ class RoIOpTester(ABC):
...
@@ -46,6 +88,15 @@ class RoIOpTester(ABC):
tol
=
1e-3
if
(
x_dtype
is
torch
.
half
or
rois_dtype
is
torch
.
half
)
else
1e-5
tol
=
1e-3
if
(
x_dtype
is
torch
.
half
or
rois_dtype
is
torch
.
half
)
else
1e-5
torch
.
testing
.
assert_close
(
gt_y
.
to
(
y
),
y
,
rtol
=
tol
,
atol
=
tol
)
torch
.
testing
.
assert_close
(
gt_y
.
to
(
y
),
y
,
rtol
=
tol
,
atol
=
tol
)
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_is_leaf_node
(
self
,
device
):
op_obj
=
self
.
make_obj
(
wrap
=
True
).
to
(
device
=
device
)
graph_node_names
=
get_graph_node_names
(
op_obj
)
assert
len
(
graph_node_names
)
==
2
assert
len
(
graph_node_names
[
0
])
==
len
(
graph_node_names
[
1
])
assert
len
(
graph_node_names
[
0
])
==
1
+
op_obj
.
n_inputs
@
pytest
.
mark
.
parametrize
(
"seed"
,
range
(
10
))
@
pytest
.
mark
.
parametrize
(
"seed"
,
range
(
10
))
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"contiguous"
,
(
True
,
False
))
@
pytest
.
mark
.
parametrize
(
"contiguous"
,
(
True
,
False
))
...
@@ -91,6 +142,10 @@ class RoIOpTester(ABC):
...
@@ -91,6 +142,10 @@ class RoIOpTester(ABC):
def
fn
(
*
args
,
**
kwargs
):
def
fn
(
*
args
,
**
kwargs
):
pass
pass
@
abstractmethod
def
make_obj
(
*
args
,
**
kwargs
):
pass
@
abstractmethod
@
abstractmethod
def
get_script_fn
(
*
args
,
**
kwargs
):
def
get_script_fn
(
*
args
,
**
kwargs
):
pass
pass
...
@@ -104,6 +159,10 @@ class TestRoiPool(RoIOpTester):
...
@@ -104,6 +159,10 @@ class TestRoiPool(RoIOpTester):
def
fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
**
kwargs
):
def
fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
**
kwargs
):
return
ops
.
RoIPool
((
pool_h
,
pool_w
),
spatial_scale
)(
x
,
rois
)
return
ops
.
RoIPool
((
pool_h
,
pool_w
),
spatial_scale
)(
x
,
rois
)
def
make_obj
(
self
,
pool_h
=
5
,
pool_w
=
5
,
spatial_scale
=
1
,
wrap
=
False
):
obj
=
ops
.
RoIPool
((
pool_h
,
pool_w
),
spatial_scale
)
return
RoIOpTesterModuleWrapper
(
obj
)
if
wrap
else
obj
def
get_script_fn
(
self
,
rois
,
pool_size
):
def
get_script_fn
(
self
,
rois
,
pool_size
):
scriped
=
torch
.
jit
.
script
(
ops
.
roi_pool
)
scriped
=
torch
.
jit
.
script
(
ops
.
roi_pool
)
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
...
@@ -144,6 +203,10 @@ class TestPSRoIPool(RoIOpTester):
...
@@ -144,6 +203,10 @@ class TestPSRoIPool(RoIOpTester):
def
fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
**
kwargs
):
def
fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
**
kwargs
):
return
ops
.
PSRoIPool
((
pool_h
,
pool_w
),
1
)(
x
,
rois
)
return
ops
.
PSRoIPool
((
pool_h
,
pool_w
),
1
)(
x
,
rois
)
def
make_obj
(
self
,
pool_h
=
5
,
pool_w
=
5
,
spatial_scale
=
1
,
wrap
=
False
):
obj
=
ops
.
PSRoIPool
((
pool_h
,
pool_w
),
spatial_scale
)
return
RoIOpTesterModuleWrapper
(
obj
)
if
wrap
else
obj
def
get_script_fn
(
self
,
rois
,
pool_size
):
def
get_script_fn
(
self
,
rois
,
pool_size
):
scriped
=
torch
.
jit
.
script
(
ops
.
ps_roi_pool
)
scriped
=
torch
.
jit
.
script
(
ops
.
ps_roi_pool
)
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
...
@@ -223,6 +286,12 @@ class TestRoIAlign(RoIOpTester):
...
@@ -223,6 +286,12 @@ class TestRoIAlign(RoIOpTester):
(
pool_h
,
pool_w
),
spatial_scale
=
spatial_scale
,
sampling_ratio
=
sampling_ratio
,
aligned
=
aligned
(
pool_h
,
pool_w
),
spatial_scale
=
spatial_scale
,
sampling_ratio
=
sampling_ratio
,
aligned
=
aligned
)(
x
,
rois
)
)(
x
,
rois
)
def
make_obj
(
self
,
pool_h
=
5
,
pool_w
=
5
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
aligned
=
False
,
wrap
=
False
):
obj
=
ops
.
RoIAlign
(
(
pool_h
,
pool_w
),
spatial_scale
=
spatial_scale
,
sampling_ratio
=
sampling_ratio
,
aligned
=
aligned
)
return
RoIOpTesterModuleWrapper
(
obj
)
if
wrap
else
obj
def
get_script_fn
(
self
,
rois
,
pool_size
):
def
get_script_fn
(
self
,
rois
,
pool_size
):
scriped
=
torch
.
jit
.
script
(
ops
.
roi_align
)
scriped
=
torch
.
jit
.
script
(
ops
.
roi_align
)
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
...
@@ -374,6 +443,10 @@ class TestPSRoIAlign(RoIOpTester):
...
@@ -374,6 +443,10 @@ class TestPSRoIAlign(RoIOpTester):
def
fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
**
kwargs
):
def
fn
(
self
,
x
,
rois
,
pool_h
,
pool_w
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
**
kwargs
):
return
ops
.
PSRoIAlign
((
pool_h
,
pool_w
),
spatial_scale
=
spatial_scale
,
sampling_ratio
=
sampling_ratio
)(
x
,
rois
)
return
ops
.
PSRoIAlign
((
pool_h
,
pool_w
),
spatial_scale
=
spatial_scale
,
sampling_ratio
=
sampling_ratio
)(
x
,
rois
)
def
make_obj
(
self
,
pool_h
=
5
,
pool_w
=
5
,
spatial_scale
=
1
,
sampling_ratio
=-
1
,
wrap
=
False
):
obj
=
ops
.
PSRoIAlign
((
pool_h
,
pool_w
),
spatial_scale
=
spatial_scale
,
sampling_ratio
=
sampling_ratio
)
return
RoIOpTesterModuleWrapper
(
obj
)
if
wrap
else
obj
def
get_script_fn
(
self
,
rois
,
pool_size
):
def
get_script_fn
(
self
,
rois
,
pool_size
):
scriped
=
torch
.
jit
.
script
(
ops
.
ps_roi_align
)
scriped
=
torch
.
jit
.
script
(
ops
.
ps_roi_align
)
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
return
lambda
x
:
scriped
(
x
,
rois
,
pool_size
)
...
@@ -422,12 +495,18 @@ class TestPSRoIAlign(RoIOpTester):
...
@@ -422,12 +495,18 @@ class TestPSRoIAlign(RoIOpTester):
class
TestMultiScaleRoIAlign
:
class
TestMultiScaleRoIAlign
:
def
make_obj
(
self
,
fmap_names
=
None
,
output_size
=
(
7
,
7
),
sampling_ratio
=
2
,
wrap
=
False
):
if
fmap_names
is
None
:
fmap_names
=
[
"0"
]
obj
=
ops
.
poolers
.
MultiScaleRoIAlign
(
fmap_names
,
output_size
,
sampling_ratio
)
return
MultiScaleRoIAlignModuleWrapper
(
obj
)
if
wrap
else
obj
def
test_msroialign_repr
(
self
):
def
test_msroialign_repr
(
self
):
fmap_names
=
[
"0"
]
fmap_names
=
[
"0"
]
output_size
=
(
7
,
7
)
output_size
=
(
7
,
7
)
sampling_ratio
=
2
sampling_ratio
=
2
# Pass mock feature map names
# Pass mock feature map names
t
=
ops
.
poolers
.
MultiScaleRoIAlign
(
fmap_names
,
output_size
,
sampling_ratio
)
t
=
self
.
make_obj
(
fmap_names
,
output_size
,
sampling_ratio
,
wrap
=
False
)
# Check integrity of object __repr__ attribute
# Check integrity of object __repr__ attribute
expected_string
=
(
expected_string
=
(
...
@@ -436,6 +515,15 @@ class TestMultiScaleRoIAlign:
...
@@ -436,6 +515,15 @@ class TestMultiScaleRoIAlign:
)
)
assert
repr
(
t
)
==
expected_string
assert
repr
(
t
)
==
expected_string
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_is_leaf_node
(
self
,
device
):
op_obj
=
self
.
make_obj
(
wrap
=
True
).
to
(
device
=
device
)
graph_node_names
=
get_graph_node_names
(
op_obj
)
assert
len
(
graph_node_names
)
==
2
assert
len
(
graph_node_names
[
0
])
==
len
(
graph_node_names
[
1
])
assert
len
(
graph_node_names
[
0
])
==
1
+
op_obj
.
n_inputs
class
TestNMS
:
class
TestNMS
:
def
_reference_nms
(
self
,
boxes
,
scores
,
iou_threshold
):
def
_reference_nms
(
self
,
boxes
,
scores
,
iou_threshold
):
...
@@ -693,6 +781,21 @@ class TestDeformConv:
...
@@ -693,6 +781,21 @@ class TestDeformConv:
return
x
,
weight
,
offset
,
mask
,
bias
,
stride
,
pad
,
dilation
return
x
,
weight
,
offset
,
mask
,
bias
,
stride
,
pad
,
dilation
def
make_obj
(
self
,
in_channels
=
6
,
out_channels
=
2
,
kernel_size
=
(
3
,
2
),
groups
=
2
,
wrap
=
False
):
obj
=
ops
.
DeformConv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
(
2
,
1
),
padding
=
(
1
,
0
),
dilation
=
(
2
,
1
),
groups
=
groups
)
return
DeformConvModuleWrapper
(
obj
)
if
wrap
else
obj
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
def
test_is_leaf_node
(
self
,
device
):
op_obj
=
self
.
make_obj
(
wrap
=
True
).
to
(
device
=
device
)
graph_node_names
=
get_graph_node_names
(
op_obj
)
assert
len
(
graph_node_names
)
==
2
assert
len
(
graph_node_names
[
0
])
==
len
(
graph_node_names
[
1
])
assert
len
(
graph_node_names
[
0
])
==
1
+
op_obj
.
n_inputs
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"device"
,
cpu_and_gpu
())
@
pytest
.
mark
.
parametrize
(
"contiguous"
,
(
True
,
False
))
@
pytest
.
mark
.
parametrize
(
"contiguous"
,
(
True
,
False
))
@
pytest
.
mark
.
parametrize
(
"batch_sz"
,
(
0
,
33
))
@
pytest
.
mark
.
parametrize
(
"batch_sz"
,
(
0
,
33
))
...
@@ -705,9 +808,9 @@ class TestDeformConv:
...
@@ -705,9 +808,9 @@ class TestDeformConv:
groups
=
2
groups
=
2
tol
=
2e-3
if
dtype
is
torch
.
half
else
1e-5
tol
=
2e-3
if
dtype
is
torch
.
half
else
1e-5
layer
=
ops
.
DeformConv2d
(
layer
=
self
.
make_obj
(
in_channels
,
out_channels
,
kernel_size
,
groups
,
wrap
=
False
).
to
(
in_channels
,
out_channels
,
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
groups
=
groups
device
=
x
.
device
,
dtype
=
dtype
)
.
to
(
device
=
x
.
device
,
dtype
=
dtype
)
)
res
=
layer
(
x
,
offset
,
mask
)
res
=
layer
(
x
,
offset
,
mask
)
weight
=
layer
.
weight
.
data
weight
=
layer
.
weight
.
data
...
@@ -1200,6 +1303,20 @@ class TestStochasticDepth:
...
@@ -1200,6 +1303,20 @@ class TestStochasticDepth:
elif
p
==
1
:
elif
p
==
1
:
assert
out
.
equal
(
torch
.
zeros_like
(
x
))
assert
out
.
equal
(
torch
.
zeros_like
(
x
))
def
make_obj
(
self
,
p
,
mode
,
wrap
=
False
):
obj
=
ops
.
StochasticDepth
(
p
,
mode
)
return
StochasticDepthWrapper
(
obj
)
if
wrap
else
obj
@
pytest
.
mark
.
parametrize
(
"p"
,
(
0
,
1
))
@
pytest
.
mark
.
parametrize
(
"mode"
,
[
"batch"
,
"row"
])
def
test_is_leaf_node
(
self
,
p
,
mode
):
op_obj
=
self
.
make_obj
(
p
,
mode
,
wrap
=
True
)
graph_node_names
=
get_graph_node_names
(
op_obj
)
assert
len
(
graph_node_names
)
==
2
assert
len
(
graph_node_names
[
0
])
==
len
(
graph_node_names
[
1
])
assert
len
(
graph_node_names
[
0
])
==
1
+
op_obj
.
n_inputs
class
TestUtils
:
class
TestUtils
:
@
pytest
.
mark
.
parametrize
(
"norm_layer"
,
[
None
,
nn
.
BatchNorm2d
,
nn
.
LayerNorm
])
@
pytest
.
mark
.
parametrize
(
"norm_layer"
,
[
None
,
nn
.
BatchNorm2d
,
nn
.
LayerNorm
])
...
...
torchvision/models/feature_extraction.py
View file @
a7a36756
import
inspect
import
math
import
re
import
re
import
warnings
import
warnings
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
copy
import
deepcopy
from
copy
import
deepcopy
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
Dict
,
Callable
,
List
,
Union
,
Optional
,
Tuple
from
typing
import
Dict
,
Callable
,
List
,
Union
,
Optional
,
Tuple
,
Any
import
torch
import
torch
import
torchvision
from
torch
import
fx
from
torch
import
fx
from
torch
import
nn
from
torch
import
nn
from
torch.fx.graph_module
import
_copy_attr
from
torch.fx.graph_module
import
_copy_attr
...
@@ -172,8 +175,19 @@ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathT
...
@@ -172,8 +175,19 @@ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathT
warnings
.
warn
(
msg
+
suggestion_msg
)
warnings
.
warn
(
msg
+
suggestion_msg
)
def
_get_leaf_modules_for_ops
()
->
List
[
type
]:
members
=
inspect
.
getmembers
(
torchvision
.
ops
)
result
=
[]
for
_
,
obj
in
members
:
if
inspect
.
isclass
(
obj
)
and
issubclass
(
obj
,
torch
.
nn
.
Module
):
result
.
append
(
obj
)
return
result
def
get_graph_node_names
(
def
get_graph_node_names
(
model
:
nn
.
Module
,
tracer_kwargs
:
Dict
=
{},
suppress_diff_warning
:
bool
=
False
model
:
nn
.
Module
,
tracer_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
suppress_diff_warning
:
bool
=
False
,
)
->
Tuple
[
List
[
str
],
List
[
str
]]:
)
->
Tuple
[
List
[
str
],
List
[
str
]]:
"""
"""
Dev utility to return node names in order of execution. See note on node
Dev utility to return node names in order of execution. See note on node
...
@@ -198,6 +212,7 @@ def get_graph_node_names(
...
@@ -198,6 +212,7 @@ def get_graph_node_names(
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
``NodePathTracer`` (they are eventually passed onto
``NodePathTracer`` (they are eventually passed onto
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
By default it will be set to wrap and make leaf nodes all torchvision ops.
suppress_diff_warning (bool, optional): whether to suppress a warning
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
the graph. Defaults to False.
...
@@ -211,6 +226,14 @@ def get_graph_node_names(
...
@@ -211,6 +226,14 @@ def get_graph_node_names(
>>> model = torchvision.models.resnet18()
>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)
>>> train_nodes, eval_nodes = get_graph_node_names(model)
"""
"""
if
tracer_kwargs
is
None
:
tracer_kwargs
=
{
"autowrap_modules"
:
(
math
,
torchvision
.
ops
,
),
"leaf_modules"
:
_get_leaf_modules_for_ops
(),
}
is_training
=
model
.
training
is_training
=
model
.
training
train_tracer
=
NodePathTracer
(
**
tracer_kwargs
)
train_tracer
=
NodePathTracer
(
**
tracer_kwargs
)
train_tracer
.
trace
(
model
.
train
())
train_tracer
.
trace
(
model
.
train
())
...
@@ -294,7 +317,7 @@ def create_feature_extractor(
...
@@ -294,7 +317,7 @@ def create_feature_extractor(
return_nodes
:
Optional
[
Union
[
List
[
str
],
Dict
[
str
,
str
]]]
=
None
,
return_nodes
:
Optional
[
Union
[
List
[
str
],
Dict
[
str
,
str
]]]
=
None
,
train_return_nodes
:
Optional
[
Union
[
List
[
str
],
Dict
[
str
,
str
]]]
=
None
,
train_return_nodes
:
Optional
[
Union
[
List
[
str
],
Dict
[
str
,
str
]]]
=
None
,
eval_return_nodes
:
Optional
[
Union
[
List
[
str
],
Dict
[
str
,
str
]]]
=
None
,
eval_return_nodes
:
Optional
[
Union
[
List
[
str
],
Dict
[
str
,
str
]]]
=
None
,
tracer_kwargs
:
Dict
=
{}
,
tracer_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
suppress_diff_warning
:
bool
=
False
,
suppress_diff_warning
:
bool
=
False
,
)
->
fx
.
GraphModule
:
)
->
fx
.
GraphModule
:
"""
"""
...
@@ -353,6 +376,7 @@ def create_feature_extractor(
...
@@ -353,6 +376,7 @@ def create_feature_extractor(
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
tracer_kwargs (dict, optional): a dictionary of keywork arguments for
``NodePathTracer`` (which passes them onto it's parent class
``NodePathTracer`` (which passes them onto it's parent class
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
By default it will be set to wrap and make leaf nodes all torchvision ops.
suppress_diff_warning (bool, optional): whether to suppress a warning
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
the graph. Defaults to False.
...
@@ -397,6 +421,14 @@ def create_feature_extractor(
...
@@ -397,6 +421,14 @@ def create_feature_extractor(
>>> 'autowrap_functions': [leaf_function]})
>>> 'autowrap_functions': [leaf_function]})
"""
"""
if
tracer_kwargs
is
None
:
tracer_kwargs
=
{
"autowrap_modules"
:
(
math
,
torchvision
.
ops
,
),
"leaf_modules"
:
_get_leaf_modules_for_ops
(),
}
is_training
=
model
.
training
is_training
=
model
.
training
assert
any
(
assert
any
(
...
...
torchvision/ops/poolers.py
View file @
a7a36756
import
warnings
from
typing
import
Optional
,
List
,
Dict
,
Tuple
,
Union
from
typing
import
Optional
,
List
,
Dict
,
Tuple
,
Union
import
torch
import
torch
import
torch.fx
import
torchvision
import
torchvision
from
torch
import
nn
,
Tensor
from
torch
import
nn
,
Tensor
from
torchvision.ops.boxes
import
box_area
from
torchvision.ops.boxes
import
box_area
...
@@ -106,6 +108,126 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
...
@@ -106,6 +108,126 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
return
possible_scales
[
0
]
return
possible_scales
[
0
]
@
torch
.
fx
.
wrap
def
_setup_scales
(
features
:
List
[
Tensor
],
image_shapes
:
List
[
Tuple
[
int
,
int
]],
canonical_scale
:
int
,
canonical_level
:
int
)
->
Tuple
[
List
[
float
],
LevelMapper
]:
assert
len
(
image_shapes
)
!=
0
max_x
=
0
max_y
=
0
for
shape
in
image_shapes
:
max_x
=
max
(
shape
[
0
],
max_x
)
max_y
=
max
(
shape
[
1
],
max_y
)
original_input_shape
=
(
max_x
,
max_y
)
scales
=
[
_infer_scale
(
feat
,
original_input_shape
)
for
feat
in
features
]
# get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level.
lvl_min
=
-
torch
.
log2
(
torch
.
tensor
(
scales
[
0
],
dtype
=
torch
.
float32
)).
item
()
lvl_max
=
-
torch
.
log2
(
torch
.
tensor
(
scales
[
-
1
],
dtype
=
torch
.
float32
)).
item
()
map_levels
=
initLevelMapper
(
int
(
lvl_min
),
int
(
lvl_max
),
canonical_scale
=
canonical_scale
,
canonical_level
=
canonical_level
,
)
return
scales
,
map_levels
@
torch
.
fx
.
wrap
def
_filter_input
(
x
:
Dict
[
str
,
Tensor
],
featmap_names
:
List
[
str
])
->
List
[
Tensor
]:
x_filtered
=
[]
for
k
,
v
in
x
.
items
():
if
k
in
featmap_names
:
x_filtered
.
append
(
v
)
return
x_filtered
@
torch
.
fx
.
wrap
def
_multiscale_roi_align
(
x_filtered
:
List
[
Tensor
],
boxes
:
List
[
Tensor
],
output_size
:
List
[
int
],
sampling_ratio
:
int
,
scales
:
Optional
[
List
[
float
]],
mapper
:
Optional
[
LevelMapper
],
)
->
Tensor
:
"""
Args:
x_filtered (List[Tensor]): List of input tensors.
boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
(x1, y1, x2, y2) format and in the image reference size, not the feature map
reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
output_size (Union[List[Tuple[int, int]], List[int]]): size of the output
sampling_ratio (int): sampling ratio for ROIAlign
scales (Optional[List[float]]): If None, scales will be automatically infered. Default value is None.
mapper (Optional[LevelMapper]): If none, mapper will be automatically infered. Default value is None.
Returns:
result (Tensor)
"""
assert
scales
is
not
None
assert
mapper
is
not
None
num_levels
=
len
(
x_filtered
)
rois
=
_convert_to_roi_format
(
boxes
)
if
num_levels
==
1
:
return
roi_align
(
x_filtered
[
0
],
rois
,
output_size
=
output_size
,
spatial_scale
=
scales
[
0
],
sampling_ratio
=
sampling_ratio
,
)
levels
=
mapper
(
boxes
)
num_rois
=
len
(
rois
)
num_channels
=
x_filtered
[
0
].
shape
[
1
]
dtype
,
device
=
x_filtered
[
0
].
dtype
,
x_filtered
[
0
].
device
result
=
torch
.
zeros
(
(
num_rois
,
num_channels
,
)
+
output_size
,
dtype
=
dtype
,
device
=
device
,
)
tracing_results
=
[]
for
level
,
(
per_level_feature
,
scale
)
in
enumerate
(
zip
(
x_filtered
,
scales
)):
idx_in_level
=
torch
.
where
(
levels
==
level
)[
0
]
rois_per_level
=
rois
[
idx_in_level
]
result_idx_in_level
=
roi_align
(
per_level_feature
,
rois_per_level
,
output_size
=
output_size
,
spatial_scale
=
scale
,
sampling_ratio
=
sampling_ratio
,
)
if
torchvision
.
_is_tracing
():
tracing_results
.
append
(
result_idx_in_level
.
to
(
dtype
))
else
:
# result and result_idx_in_level's dtypes are based on dtypes of different
# elements in x_filtered. x_filtered contains tensors output by different
# layers. When autocast is active, it may choose different dtypes for
# different layers' outputs. Therefore, we defensively match result's dtype
# before copying elements from result_idx_in_level in the following op.
# We need to cast manually (can't rely on autocast to cast for us) because
# the op acts on result in-place, and autocast only affects out-of-place ops.
result
[
idx_in_level
]
=
result_idx_in_level
.
to
(
result
.
dtype
)
if
torchvision
.
_is_tracing
():
result
=
_onnx_merge_levels
(
levels
,
tracing_results
)
return
result
class
MultiScaleRoIAlign
(
nn
.
Module
):
class
MultiScaleRoIAlign
(
nn
.
Module
):
"""
"""
Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
...
@@ -165,31 +287,24 @@ class MultiScaleRoIAlign(nn.Module):
...
@@ -165,31 +287,24 @@ class MultiScaleRoIAlign(nn.Module):
self
.
canonical_scale
=
canonical_scale
self
.
canonical_scale
=
canonical_scale
self
.
canonical_level
=
canonical_level
self
.
canonical_level
=
canonical_level
def
setup_scales
(
def
convert_to_roi_format
(
self
,
boxes
:
List
[
Tensor
])
->
Tensor
:
# TODO: deprecate eventually
warnings
.
warn
(
"`convert_to_roi_format` will no loger be public in future releases."
,
FutureWarning
)
return
_convert_to_roi_format
(
boxes
)
def
infer_scale
(
self
,
feature
:
Tensor
,
original_size
:
List
[
int
])
->
float
:
# TODO: deprecate eventually
warnings
.
warn
(
"`infer_scale` will no loger be public in future releases."
,
FutureWarning
)
return
_infer_scale
(
feature
,
original_size
)
def
setup_setup_scales
(
self
,
self
,
features
:
List
[
Tensor
],
features
:
List
[
Tensor
],
image_shapes
:
List
[
Tuple
[
int
,
int
]],
image_shapes
:
List
[
Tuple
[
int
,
int
]],
)
->
None
:
)
->
None
:
assert
len
(
image_shapes
)
!=
0
# TODO: deprecate eventually
max_x
=
0
warnings
.
warn
(
"`setup_setup_scales` will no loger be public in future releases."
,
FutureWarning
)
max_y
=
0
self
.
scales
,
self
.
map_levels
=
_setup_scales
(
features
,
image_shapes
,
self
.
canonical_scale
,
self
.
canonical_level
)
for
shape
in
image_shapes
:
max_x
=
max
(
shape
[
0
],
max_x
)
max_y
=
max
(
shape
[
1
],
max_y
)
original_input_shape
=
(
max_x
,
max_y
)
scales
=
[
_infer_scale
(
feat
,
original_input_shape
)
for
feat
in
features
]
# get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level.
lvl_min
=
-
torch
.
log2
(
torch
.
tensor
(
scales
[
0
],
dtype
=
torch
.
float32
)).
item
()
lvl_max
=
-
torch
.
log2
(
torch
.
tensor
(
scales
[
-
1
],
dtype
=
torch
.
float32
)).
item
()
self
.
scales
=
scales
self
.
map_levels
=
initLevelMapper
(
int
(
lvl_min
),
int
(
lvl_max
),
canonical_scale
=
self
.
canonical_scale
,
canonical_level
=
self
.
canonical_level
,
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -210,76 +325,21 @@ class MultiScaleRoIAlign(nn.Module):
...
@@ -210,76 +325,21 @@ class MultiScaleRoIAlign(nn.Module):
Returns:
Returns:
result (Tensor)
result (Tensor)
"""
"""
x_filtered
=
[]
x_filtered
=
_filter_input
(
x
,
self
.
featmap_names
)
for
k
,
v
in
x
.
items
():
if
self
.
scales
is
None
or
self
.
map_levels
is
None
:
if
k
in
self
.
featmap_names
:
self
.
scales
,
self
.
map_levels
=
_setup_scales
(
x_filtered
.
append
(
v
)
x_filtered
,
image_shapes
,
self
.
canonical_scale
,
self
.
canonical_level
num_levels
=
len
(
x_filtered
)
rois
=
_convert_to_roi_format
(
boxes
)
if
self
.
scales
is
None
:
self
.
setup_scales
(
x_filtered
,
image_shapes
)
scales
=
self
.
scales
assert
scales
is
not
None
if
num_levels
==
1
:
return
roi_align
(
x_filtered
[
0
],
rois
,
output_size
=
self
.
output_size
,
spatial_scale
=
scales
[
0
],
sampling_ratio
=
self
.
sampling_ratio
,
)
)
mapper
=
self
.
map_levels
return
_multiscale_roi_align
(
assert
mapper
is
not
None
x_filtered
,
boxes
,
levels
=
mapper
(
boxes
)
self
.
output_size
,
self
.
sampling_ratio
,
num_rois
=
len
(
rois
)
self
.
scales
,
num_channels
=
x_filtered
[
0
].
shape
[
1
]
self
.
map_levels
,
dtype
,
device
=
x_filtered
[
0
].
dtype
,
x_filtered
[
0
].
device
result
=
torch
.
zeros
(
(
num_rois
,
num_channels
,
)
+
self
.
output_size
,
dtype
=
dtype
,
device
=
device
,
)
)
tracing_results
=
[]
for
level
,
(
per_level_feature
,
scale
)
in
enumerate
(
zip
(
x_filtered
,
scales
)):
idx_in_level
=
torch
.
where
(
levels
==
level
)[
0
]
rois_per_level
=
rois
[
idx_in_level
]
result_idx_in_level
=
roi_align
(
per_level_feature
,
rois_per_level
,
output_size
=
self
.
output_size
,
spatial_scale
=
scale
,
sampling_ratio
=
self
.
sampling_ratio
,
)
if
torchvision
.
_is_tracing
():
tracing_results
.
append
(
result_idx_in_level
.
to
(
dtype
))
else
:
# result and result_idx_in_level's dtypes are based on dtypes of different
# elements in x_filtered. x_filtered contains tensors output by different
# layers. When autocast is active, it may choose different dtypes for
# different layers' outputs. Therefore, we defensively match result's dtype
# before copying elements from result_idx_in_level in the following op.
# We need to cast manually (can't rely on autocast to cast for us) because
# the op acts on result in-place, and autocast only affects out-of-place ops.
result
[
idx_in_level
]
=
result_idx_in_level
.
to
(
result
.
dtype
)
if
torchvision
.
_is_tracing
():
result
=
_onnx_merge_levels
(
levels
,
tracing_results
)
return
result
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
return
(
f
"
{
self
.
__class__
.
__name__
}
(featmap_names=
{
self
.
featmap_names
}
, "
f
"
{
self
.
__class__
.
__name__
}
(featmap_names=
{
self
.
featmap_names
}
, "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment