Unverified Commit a7a36756 authored by Joao Gomes, committed by GitHub

Feature extraction default arguments - ops (#4810)

making torchvision ops leaf nodes by default
parent 39cf02a6
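
In practice this means the feature extraction utilities now work out of the box on modules that call torchvision ops. A minimal sketch of the new default behavior (the `Tiny` module and its node names are illustrative, not part of this commit):

import torch
from torch import nn
from torchvision import ops
from torchvision.models.feature_extraction import get_graph_node_names

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.sd = ops.StochasticDepth(p=0.5, mode="row")

    def forward(self, x):
        return self.sd(self.conv(x))

# With this commit, 'sd' shows up as a single leaf node instead of
# StochasticDepth's internals being traced through.
train_nodes, eval_nodes = get_graph_node_names(Tiny())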
test/test_ops.py

@@ -7,12 +7,54 @@ from typing import Tuple
import numpy as np
import pytest
import torch
import torch.fx
from common_utils import needs_cuda, cpu_and_gpu, assert_equal
from PIL import Image
from torch import nn, Tensor
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import models, ops
from torchvision.models.feature_extraction import get_graph_node_names
class RoIOpTesterModuleWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 2

    def forward(self, a, b):
        self.layer(a, b)


class MultiScaleRoIAlignModuleWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 3

    def forward(self, a, b, c):
        self.layer(a, b, c)


class DeformConvModuleWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 3

    def forward(self, a, b, c):
        self.layer(a, b, c)


class StochasticDepthWrapper(nn.Module):
    def __init__(self, obj):
        super().__init__()
        self.layer = obj
        self.n_inputs = 1

    def forward(self, a):
        self.layer(a)
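These wrappers exist only so each op can be symbolically traced with a known number of inputs; their forward methods deliberately discard the op's output, since the leaf-node tests below inspect only the names of the traced graph nodes, not computed values.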
class RoIOpTester(ABC):
@@ -46,6 +88,15 @@ class RoIOpTester(ABC):
        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
        torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_is_leaf_node(self, device):
op_obj = self.make_obj(wrap=True).to(device=device)
graph_node_names = get_graph_node_names(op_obj)
assert len(graph_node_names) == 2
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
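A note on the counts being asserted: get_graph_node_names returns a (train_nodes, eval_nodes) pair, which is what the first two assertions check. For the third, when the op is kept as a leaf the traced graph contains one placeholder node per input plus a single call_module node for the op itself, hence 1 + n_inputs names per list. The same check is repeated below for MultiScaleRoIAlign, DeformConv2d, and StochasticDepth.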
@pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
@@ -91,6 +142,10 @@ class RoIOpTester(ABC):
    def fn(*args, **kwargs):
        pass
    @abstractmethod
    def make_obj(*args, **kwargs):
        pass

    @abstractmethod
    def get_script_fn(*args, **kwargs):
        pass
@@ -104,6 +159,10 @@ class TestRoiPool(RoIOpTester):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj
    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_pool)
        return lambda x: scripted(x, rois, pool_size)
@@ -144,6 +203,10 @@ class TestPSRoIPool(RoIOpTester):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
        obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj
    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_pool)
        return lambda x: scripted(x, rois, pool_size)
@@ -223,6 +286,12 @@ class TestRoIAlign(RoIOpTester):
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )(x, rois)
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
        obj = ops.RoIAlign(
            (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
        )
        return RoIOpTesterModuleWrapper(obj) if wrap else obj
    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.roi_align)
        return lambda x: scripted(x, rois, pool_size)
@@ -374,6 +443,10 @@ class TestPSRoIAlign(RoIOpTester):
    def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
        return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
    def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
        obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
        return RoIOpTesterModuleWrapper(obj) if wrap else obj
    def get_script_fn(self, rois, pool_size):
        scripted = torch.jit.script(ops.ps_roi_align)
        return lambda x: scripted(x, rois, pool_size)
@@ -422,12 +495,18 @@ class TestPSRoIAlign(RoIOpTester):

class TestMultiScaleRoIAlign:
    def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
        if fmap_names is None:
            fmap_names = ["0"]
        obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
        return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj
    def test_msroialign_repr(self):
        fmap_names = ["0"]
        output_size = (7, 7)
        sampling_ratio = 2
        # Pass mock feature map names
        t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)

        # Check integrity of object __repr__ attribute
        expected_string = (
@@ -436,6 +515,15 @@ class TestMultiScaleRoIAlign:
        )
        assert repr(t) == expected_string
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_is_leaf_node(self, device):
op_obj = self.make_obj(wrap=True).to(device=device)
graph_node_names = get_graph_node_names(op_obj)
assert len(graph_node_names) == 2
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
class TestNMS:
    def _reference_nms(self, boxes, scores, iou_threshold):
@@ -693,6 +781,21 @@ class TestDeformConv:
        return x, weight, offset, mask, bias, stride, pad, dilation
    def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
        obj = ops.DeformConv2d(
            in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
        )
        return DeformConvModuleWrapper(obj) if wrap else obj
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_is_leaf_node(self, device):
op_obj = self.make_obj(wrap=True).to(device=device)
graph_node_names = get_graph_node_names(op_obj)
assert len(graph_node_names) == 2
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.parametrize("batch_sz", (0, 33))
@@ -705,9 +808,9 @@ class TestDeformConv:
        groups = 2
        tol = 2e-3 if dtype is torch.half else 1e-5
        layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
            device=x.device, dtype=dtype
        )
        res = layer(x, offset, mask)

        weight = layer.weight.data
@@ -1200,6 +1303,20 @@ class TestStochasticDepth:
        elif p == 1:
            assert out.equal(torch.zeros_like(x))
    def make_obj(self, p, mode, wrap=False):
        obj = ops.StochasticDepth(p, mode)
        return StochasticDepthWrapper(obj) if wrap else obj

    @pytest.mark.parametrize("p", (0, 1))
    @pytest.mark.parametrize("mode", ["batch", "row"])
    def test_is_leaf_node(self, p, mode):
        op_obj = self.make_obj(p, mode, wrap=True)
        graph_node_names = get_graph_node_names(op_obj)

        assert len(graph_node_names) == 2
        assert len(graph_node_names[0]) == len(graph_node_names[1])
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
class TestUtils:
    @pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
...

torchvision/models/feature_extraction.py

import inspect
import math
import re
import warnings
from collections import OrderedDict
from copy import deepcopy
from itertools import chain
from typing import Dict, Callable, List, Union, Optional, Tuple, Any

import torch
import torchvision
from torch import fx
from torch import nn
from torch.fx.graph_module import _copy_attr
@@ -172,8 +175,19 @@ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathT
    warnings.warn(msg + suggestion_msg)
def _get_leaf_modules_for_ops() -> List[type]:
    members = inspect.getmembers(torchvision.ops)
    result = []
    for _, obj in members:
        if inspect.isclass(obj) and issubclass(obj, torch.nn.Module):
            result.append(obj)
    return result
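For intuition, this helper simply collects every nn.Module subclass that the torchvision.ops namespace exposes. A quick sanity check (illustrative only; note it imports a private helper, which is an assumption about module layout rather than supported API):

from torchvision import ops
from torchvision.models.feature_extraction import _get_leaf_modules_for_ops

leaves = _get_leaf_modules_for_ops()
# nn.Module subclasses exported by torchvision.ops, e.g. RoIAlign,
# DeformConv2d, StochasticDepth and MultiScaleRoIAlign, all qualify.
assert ops.RoIAlign in leaves and ops.StochasticDepth in leaves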
def get_graph_node_names(
    model: nn.Module,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warning: bool = False,
) -> Tuple[List[str], List[str]]:
    """
    Dev utility to return node names in order of execution. See note on node
@@ -198,6 +212,7 @@ def get_graph_node_names(
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            ``NodePathTracer`` (they are eventually passed onto
            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
            By default it is set so that all torchvision ops are wrapped or kept as leaf nodes.
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.
@@ -211,6 +226,14 @@ def get_graph_node_names(
        >>> model = torchvision.models.resnet18()
        >>> train_nodes, eval_nodes = get_graph_node_names(model)
    """
    if tracer_kwargs is None:
        tracer_kwargs = {
            "autowrap_modules": (
                math,
                torchvision.ops,
            ),
            "leaf_modules": _get_leaf_modules_for_ops(),
        }
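The default combines two complementary mechanisms: ``autowrap_modules`` tells torch.fx to record free functions defined in those modules (e.g. ops.roi_align, math.sqrt) as single call_function nodes, while the tracer's ``leaf_modules`` list stops tracing from descending into the listed nn.Module classes. Callers can still opt out explicitly; a sketch:

import torchvision
from torchvision.models.feature_extraction import get_graph_node_names

model = torchvision.models.resnet18()
# Passing an explicit empty dict overrides the new default and restores the
# old behavior (no torchvision-specific autowrapping or leaf modules).
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs={})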
    is_training = model.training
    train_tracer = NodePathTracer(**tracer_kwargs)
    train_tracer.trace(model.train())
@@ -294,7 +317,7 @@ def create_feature_extractor(
    return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
    tracer_kwargs: Optional[Dict[str, Any]] = None,
    suppress_diff_warning: bool = False,
) -> fx.GraphModule:
    """
@@ -353,6 +376,7 @@ def create_feature_extractor(
        tracer_kwargs (dict, optional): a dictionary of keyword arguments for
            ``NodePathTracer`` (which passes them onto its parent class
            `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
            By default it is set so that all torchvision ops are wrapped or kept as leaf nodes.
        suppress_diff_warning (bool, optional): whether to suppress a warning
            when there are discrepancies between the train and eval version of
            the graph. Defaults to False.
@@ -397,6 +421,14 @@ def create_feature_extractor(
        >>> 'autowrap_functions': [leaf_function]})
    """
    if tracer_kwargs is None:
        tracer_kwargs = {
            "autowrap_modules": (
                math,
                torchvision.ops,
            ),
            "leaf_modules": _get_leaf_modules_for_ops(),
        }
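As a consequence, models whose forward passes run through torchvision ops can be fed to create_feature_extractor without hand-rolled tracer_kwargs. A minimal sketch under assumed names (the `TinyDeform` module and its node names are illustrative):

import torch
from torch import nn
from torchvision import ops
from torchvision.models.feature_extraction import create_feature_extractor

class TinyDeform(nn.Module):
    def __init__(self):
        super().__init__()
        self.offsets = nn.Conv2d(3, 2 * 3 * 3, 3, padding=1)  # 2*kh*kw offset channels
        self.deform = ops.DeformConv2d(3, 8, 3, padding=1)

    def forward(self, x):
        return self.deform(x, self.offsets(x))

# DeformConv2d is kept as a leaf module by default, so tracing no longer
# descends into its Python implementation.
extractor = create_feature_extractor(TinyDeform(), return_nodes={"offsets": "off"})
out = extractor(torch.rand(1, 3, 16, 16))["off"]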
    is_training = model.training
    assert any(
...

torchvision/ops/poolers.py

import warnings
from typing import Optional, List, Dict, Tuple, Union

import torch
import torch.fx
import torchvision
from torch import nn, Tensor
from torchvision.ops.boxes import box_area
@@ -106,6 +108,126 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
    return possible_scales[0]
@torch.fx.wrap
def _setup_scales(
    features: List[Tensor], image_shapes: List[Tuple[int, int]], canonical_scale: int, canonical_level: int
) -> Tuple[List[float], LevelMapper]:
    assert len(image_shapes) != 0
    max_x = 0
    max_y = 0
    for shape in image_shapes:
        max_x = max(shape[0], max_x)
        max_y = max(shape[1], max_y)
    original_input_shape = (max_x, max_y)

    scales = [_infer_scale(feat, original_input_shape) for feat in features]
    # get the levels in the feature map by leveraging the fact that the network always
    # downsamples by a factor of 2 at each level.
    lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
    lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()

    map_levels = initLevelMapper(
        int(lvl_min),
        int(lvl_max),
        canonical_scale=canonical_scale,
        canonical_level=canonical_level,
    )
    return scales, map_levels
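To make the log2 arithmetic concrete: for a typical FPN, feature maps with strides 4, 8, 16 and 32 relative to the padded input give scales [1/4, 1/8, 1/16, 1/32], so lvl_min = -log2(1/4) = 2 and lvl_max = -log2(1/32) = 5, i.e. pyramid levels P2 through P5.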
@torch.fx.wrap
def _filter_input(x: Dict[str, Tensor], featmap_names: List[str]) -> List[Tensor]:
    x_filtered = []
    for k, v in x.items():
        if k in featmap_names:
            x_filtered.append(v)
    return x_filtered
@torch.fx.wrap
def _multiscale_roi_align(
    x_filtered: List[Tensor],
    boxes: List[Tensor],
    output_size: List[int],
    sampling_ratio: int,
    scales: Optional[List[float]],
    mapper: Optional[LevelMapper],
) -> Tensor:
    """
    Args:
        x_filtered (List[Tensor]): List of input tensors.
        boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
            (x1, y1, x2, y2) format and in the image reference size, not the feature map
            reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
        output_size (Union[List[Tuple[int, int]], List[int]]): size of the output
        sampling_ratio (int): sampling ratio for ROIAlign
        scales (Optional[List[float]]): If None, scales will be automatically inferred. Default value is None.
        mapper (Optional[LevelMapper]): If None, mapper will be automatically inferred. Default value is None.
    Returns:
        result (Tensor)
    """
    assert scales is not None
    assert mapper is not None

    num_levels = len(x_filtered)
    rois = _convert_to_roi_format(boxes)

    if num_levels == 1:
        return roi_align(
            x_filtered[0],
            rois,
            output_size=output_size,
            spatial_scale=scales[0],
            sampling_ratio=sampling_ratio,
        )

    levels = mapper(boxes)

    num_rois = len(rois)
    num_channels = x_filtered[0].shape[1]

    dtype, device = x_filtered[0].dtype, x_filtered[0].device
    result = torch.zeros(
        (
            num_rois,
            num_channels,
        )
        + output_size,
        dtype=dtype,
        device=device,
    )

    tracing_results = []
    for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
        idx_in_level = torch.where(levels == level)[0]
        rois_per_level = rois[idx_in_level]

        result_idx_in_level = roi_align(
            per_level_feature,
            rois_per_level,
            output_size=output_size,
            spatial_scale=scale,
            sampling_ratio=sampling_ratio,
        )

        if torchvision._is_tracing():
            tracing_results.append(result_idx_in_level.to(dtype))
        else:
            # result and result_idx_in_level's dtypes are based on dtypes of different
            # elements in x_filtered. x_filtered contains tensors output by different
            # layers. When autocast is active, it may choose different dtypes for
            # different layers' outputs. Therefore, we defensively match result's dtype
            # before copying elements from result_idx_in_level in the following op.
            # We need to cast manually (can't rely on autocast to cast for us) because
            # the op acts on result in-place, and autocast only affects out-of-place ops.
            result[idx_in_level] = result_idx_in_level.to(result.dtype)

    if torchvision._is_tracing():
        result = _onnx_merge_levels(levels, tracing_results)

    return result
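The @torch.fx.wrap decorator on these three helpers registers them so that symbolic tracing records each as a single call_function node instead of tracing through its body. That is what keeps the data-dependent logic above (the per-level loop, torch.where, the ONNX tracing branch) from breaking fx tracing of MultiScaleRoIAlign, and it is why the module-level refactor below can delegate to them.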
class MultiScaleRoIAlign(nn.Module):
    """
    Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
@@ -165,31 +287,24 @@ class MultiScaleRoIAlign(nn.Module):
        self.canonical_scale = canonical_scale
        self.canonical_level = canonical_level
    def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
        # TODO: deprecate eventually
        warnings.warn("`convert_to_roi_format` will no longer be public in future releases.", FutureWarning)
        return _convert_to_roi_format(boxes)

    def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
        # TODO: deprecate eventually
        warnings.warn("`infer_scale` will no longer be public in future releases.", FutureWarning)
        return _infer_scale(feature, original_size)

    def setup_scales(
        self,
        features: List[Tensor],
        image_shapes: List[Tuple[int, int]],
    ) -> None:
        # TODO: deprecate eventually
        warnings.warn("`setup_scales` will no longer be public in future releases.", FutureWarning)
        self.scales, self.map_levels = _setup_scales(features, image_shapes, self.canonical_scale, self.canonical_level)
    def forward(
        self,
@@ -210,75 +325,20 @@ class MultiScaleRoIAlign(nn.Module):
        Returns:
            result (Tensor)
        """
        x_filtered = _filter_input(x, self.featmap_names)
        if self.scales is None or self.map_levels is None:
            self.scales, self.map_levels = _setup_scales(
                x_filtered, image_shapes, self.canonical_scale, self.canonical_level
            )
        return _multiscale_roi_align(
            x_filtered,
            boxes,
            self.output_size,
            self.sampling_ratio,
            self.scales,
            self.map_levels,
        )
    def __repr__(self) -> str:
        return (
...
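
End-user behavior is unchanged by the refactor: scales are still computed lazily on the first forward pass and cached on the module. A usage sketch (shapes chosen arbitrarily):

import torch
from torchvision import ops

m = ops.MultiScaleRoIAlign(["feat1", "feat2"], output_size=(3, 3), sampling_ratio=2)
x = {
    "feat1": torch.rand(1, 5, 64, 64),  # stride-8 map for a 512x512 image
    "feat2": torch.rand(1, 5, 32, 32),  # stride-16 map
}
boxes = [torch.tensor([[10.0, 10.0, 100.0, 100.0], [40.0, 40.0, 120.0, 110.0]])]
out = m(x, boxes, image_shapes=[(512, 512)])
print(out.shape)  # torch.Size([2, 5, 3, 3]): one 5x3x3 feature per box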