Unverified commit a7a36756, authored by Joao Gomes, committed by GitHub

Feature extraction default arguments - ops (#4810)

making torchvision ops leaf nodes by default
parent 39cf02a6
test/test_ops.py
@@ -7,12 +7,54 @@ from typing import Tuple
import numpy as np
import pytest
import torch
import torch.fx
from common_utils import needs_cuda, cpu_and_gpu, assert_equal
from PIL import Image
from torch import nn, Tensor
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import models, ops
from torchvision.models.feature_extraction import get_graph_node_names
class RoIOpTesterModuleWrapper(nn.Module):
def __init__(self, obj):
super().__init__()
self.layer = obj
self.n_inputs = 2
def forward(self, a, b):
self.layer(a, b)
class MultiScaleRoIAlignModuleWrapper(nn.Module):
def __init__(self, obj):
super().__init__()
self.layer = obj
self.n_inputs = 3
def forward(self, a, b, c):
self.layer(a, b, c)
class DeformConvModuleWrapper(nn.Module):
def __init__(self, obj):
super().__init__()
self.layer = obj
self.n_inputs = 3
def forward(self, a, b, c):
self.layer(a, b, c)
class StochasticDepthWrapper(nn.Module):
def __init__(self, obj):
super().__init__()
self.layer = obj
self.n_inputs = 1
def forward(self, a):
self.layer(a)
class RoIOpTester(ABC):
@@ -46,6 +88,15 @@ class RoIOpTester(ABC):
tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_is_leaf_node(self, device):
op_obj = self.make_obj(wrap=True).to(device=device)
graph_node_names = get_graph_node_names(op_obj)
assert len(graph_node_names) == 2
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
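For intuition, a quick sketch (not part of the diff) of what these assertions amount to: with the op kept as a leaf, the traced graph contains one placeholder per input plus a single node for the op itself, identically in train and eval mode.

```python
# Sketch: leaf-node check for RoIPool, using the RoIOpTesterModuleWrapper above.
import torch
from torchvision import ops
from torchvision.models.feature_extraction import get_graph_node_names

wrapper = RoIOpTesterModuleWrapper(ops.RoIPool((5, 5), 1.0))
train_nodes, eval_nodes = get_graph_node_names(wrapper)
# One placeholder per input plus the op node, e.g. ['a', 'b', 'layer'];
# the op is a single leaf node rather than a subgraph of its internals.
assert len(train_nodes) == len(eval_nodes) == 1 + wrapper.n_inputs
```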
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
@@ -91,6 +142,10 @@ class RoIOpTester(ABC):
def fn(*args, **kwargs):
pass
@abstractmethod
def make_obj(*args, **kwargs):
pass
@abstractmethod
def get_script_fn(*args, **kwargs):
pass
@@ -104,6 +159,10 @@ class TestRoiPool(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)
def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
obj = ops.RoIPool((pool_h, pool_w), spatial_scale)
return RoIOpTesterModuleWrapper(obj) if wrap else obj
def get_script_fn(self, rois, pool_size):
scripted = torch.jit.script(ops.roi_pool)
return lambda x: scripted(x, rois, pool_size)
@@ -144,6 +203,10 @@ class TestPSRoIPool(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, wrap=False):
obj = ops.PSRoIPool((pool_h, pool_w), spatial_scale)
return RoIOpTesterModuleWrapper(obj) if wrap else obj
def get_script_fn(self, rois, pool_size):
scripted = torch.jit.script(ops.ps_roi_pool)
return lambda x: scripted(x, rois, pool_size)
@@ -223,6 +286,12 @@ class TestRoIAlign(RoIOpTester):
(pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
)(x, rois)
def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, aligned=False, wrap=False):
obj = ops.RoIAlign(
(pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
)
return RoIOpTesterModuleWrapper(obj) if wrap else obj
def get_script_fn(self, rois, pool_size):
scripted = torch.jit.script(ops.roi_align)
return lambda x: scripted(x, rois, pool_size)
@@ -374,6 +443,10 @@ class TestPSRoIAlign(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
def make_obj(self, pool_h=5, pool_w=5, spatial_scale=1, sampling_ratio=-1, wrap=False):
obj = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)
return RoIOpTesterModuleWrapper(obj) if wrap else obj
def get_script_fn(self, rois, pool_size):
scripted = torch.jit.script(ops.ps_roi_align)
return lambda x: scripted(x, rois, pool_size)
@@ -422,12 +495,18 @@ class TestPSRoIAlign(RoIOpTester):
class TestMultiScaleRoIAlign:
def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
if fmap_names is None:
fmap_names = ["0"]
obj = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
return MultiScaleRoIAlignModuleWrapper(obj) if wrap else obj
def test_msroialign_repr(self):
fmap_names = ["0"]
output_size = (7, 7)
sampling_ratio = 2
# Pass mock feature map names
t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio)
t = self.make_obj(fmap_names, output_size, sampling_ratio, wrap=False)
# Check integrity of object __repr__ attribute
expected_string = (
@@ -436,6 +515,15 @@ class TestMultiScaleRoIAlign:
)
assert repr(t) == expected_string
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_is_leaf_node(self, device):
op_obj = self.make_obj(wrap=True).to(device=device)
graph_node_names = get_graph_node_names(op_obj)
assert len(graph_node_names) == 2
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
class TestNMS:
def _reference_nms(self, boxes, scores, iou_threshold):
@@ -693,6 +781,21 @@ class TestDeformConv:
return x, weight, offset, mask, bias, stride, pad, dilation
def make_obj(self, in_channels=6, out_channels=2, kernel_size=(3, 2), groups=2, wrap=False):
obj = ops.DeformConv2d(
in_channels, out_channels, kernel_size, stride=(2, 1), padding=(1, 0), dilation=(2, 1), groups=groups
)
return DeformConvModuleWrapper(obj) if wrap else obj
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_is_leaf_node(self, device):
op_obj = self.make_obj(wrap=True).to(device=device)
graph_node_names = get_graph_node_names(op_obj)
assert len(graph_node_names) == 2
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33))
@@ -705,9 +808,9 @@ class TestDeformConv:
groups = 2
tol = 2e-3 if dtype is torch.half else 1e-5
layer = ops.DeformConv2d(
in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
).to(device=x.device, dtype=dtype)
layer = self.make_obj(in_channels, out_channels, kernel_size, groups, wrap=False).to(
device=x.device, dtype=dtype
)
res = layer(x, offset, mask)
weight = layer.weight.data
@@ -1200,6 +1303,20 @@ class TestStochasticDepth:
elif p == 1:
assert out.equal(torch.zeros_like(x))
def make_obj(self, p, mode, wrap=False):
obj = ops.StochasticDepth(p, mode)
return StochasticDepthWrapper(obj) if wrap else obj
@pytest.mark.parametrize("p", (0, 1))
@pytest.mark.parametrize("mode", ["batch", "row"])
def test_is_leaf_node(self, p, mode):
op_obj = self.make_obj(p, mode, wrap=True)
graph_node_names = get_graph_node_names(op_obj)
assert len(graph_node_names) == 2
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
class TestUtils:
@pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
torchvision/models/feature_extraction.py
import inspect
import math
import re
import warnings
from collections import OrderedDict
from copy import deepcopy
from itertools import chain
from typing import Dict, Callable, List, Union, Optional, Tuple
from typing import Dict, Callable, List, Union, Optional, Tuple, Any
import torch
import torchvision
from torch import fx
from torch import nn
from torch.fx.graph_module import _copy_attr
@@ -172,8 +175,19 @@ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathT
warnings.warn(msg + suggestion_msg)
def _get_leaf_modules_for_ops() -> List[type]:
members = inspect.getmembers(torchvision.ops)
result = []
for _, obj in members:
if inspect.isclass(obj) and issubclass(obj, torch.nn.Module):
result.append(obj)
return result
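For context, a minimal sketch of how a list like this is typically consumed (the PR's actual tracer plumbing is outside this hunk): a `torch.fx.Tracer` subclass can override `is_leaf_module` so that these classes stay single `call_module` nodes instead of being traced into.

```python
import torch
from torch import fx, nn

class LeafAwareTracer(fx.Tracer):
    # Hypothetical sketch: treat every class in `leaf_modules` as a leaf.
    def __init__(self, leaf_modules, **kwargs):
        super().__init__(**kwargs)
        self.leaf_modules = tuple(leaf_modules)

    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        # Keep torchvision ops as single call_module nodes.
        if isinstance(m, self.leaf_modules):
            return True
        return super().is_leaf_module(m, module_qualified_name)
```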
def get_graph_node_names(
model: nn.Module, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False
model: nn.Module,
tracer_kwargs: Optional[Dict[str, Any]] = None,
suppress_diff_warning: bool = False,
) -> Tuple[List[str], List[str]]:
"""
Dev utility to return node names in order of execution. See note on node
@@ -198,6 +212,7 @@ def get_graph_node_names(
tracer_kwargs (dict, optional): a dictionary of keyword arguments for
``NodePathTracer`` (they are eventually passed onto
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
By default, it is set so that all torchvision ops are wrapped and treated as leaf nodes.
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
@@ -211,6 +226,14 @@ def get_graph_node_names(
>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)
"""
if tracer_kwargs is None:
tracer_kwargs = {
"autowrap_modules": (
math,
torchvision.ops,
),
"leaf_modules": _get_leaf_modules_for_ops(),
}
is_training = model.training
train_tracer = NodePathTracer(**tracer_kwargs)
train_tracer.trace(model.train())
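A usage sketch of the new default (the toy module here is hypothetical): a torchvision op embedded in a model now shows up as one named node.

```python
import torch
from torch import nn
import torchvision
from torchvision.models.feature_extraction import get_graph_node_names

class Tiny(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.sd = torchvision.ops.StochasticDepth(p=0.5, mode="row")

    def forward(self, x):
        return self.sd(x)

train_nodes, eval_nodes = get_graph_node_names(Tiny())
# With tracer_kwargs left as None, 'sd' is a single leaf node rather than
# a trace of its internals, so train/eval node names stay aligned.
print(train_nodes)  # e.g. ['x', 'sd']
```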
@@ -294,7 +317,7 @@ def create_feature_extractor(
return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
tracer_kwargs: Dict = {},
tracer_kwargs: Optional[Dict[str, Any]] = None,
suppress_diff_warning: bool = False,
) -> fx.GraphModule:
"""
@@ -353,6 +376,7 @@ def create_feature_extractor(
tracer_kwargs (dict, optional): a dictionary of keyword arguments for
``NodePathTracer`` (which passes them onto its parent class
`torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
By default, it is set so that all torchvision ops are wrapped and treated as leaf nodes.
suppress_diff_warning (bool, optional): whether to suppress a warning
when there are discrepancies between the train and eval version of
the graph. Defaults to False.
@@ -397,6 +421,14 @@ def create_feature_extractor(
>>> 'autowrap_functions': [leaf_function]})
"""
if tracer_kwargs is None:
tracer_kwargs = {
"autowrap_modules": (
math,
torchvision.ops,
),
"leaf_modules": _get_leaf_modules_for_ops(),
}
is_training = model.training
assert any(
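A usage sketch of why this default matters for `create_feature_extractor` (the model and node name below are illustrative assumptions, not taken from the diff): models such as EfficientNet contain `StochasticDepth` blocks, which are now traced as single leaf nodes, so their internals don't clutter the graph and their train/eval differences don't trigger graph-difference warnings.

```python
import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

model = torchvision.models.efficientnet_b0()
extractor = create_feature_extractor(model, return_nodes={"features.4": "mid"})
out = extractor(torch.rand(1, 3, 224, 224))
print(out["mid"].shape)  # an intermediate feature map
```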
torchvision/ops/poolers.py
import warnings
from typing import Optional, List, Dict, Tuple, Union
import torch
import torch.fx
import torchvision
from torch import nn, Tensor
from torchvision.ops.boxes import box_area
@@ -106,6 +108,126 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
return possible_scales[0]
@torch.fx.wrap
def _setup_scales(
features: List[Tensor], image_shapes: List[Tuple[int, int]], canonical_scale: int, canonical_level: int
) -> Tuple[List[float], LevelMapper]:
assert len(image_shapes) != 0
max_x = 0
max_y = 0
for shape in image_shapes:
max_x = max(shape[0], max_x)
max_y = max(shape[1], max_y)
original_input_shape = (max_x, max_y)
scales = [_infer_scale(feat, original_input_shape) for feat in features]
# get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level.
lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
map_levels = initLevelMapper(
int(lvl_min),
int(lvl_max),
canonical_scale=canonical_scale,
canonical_level=canonical_level,
)
return scales, map_levels
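A worked example of the level arithmetic above (numbers are illustrative): for feature maps downsampled by factors of 4, 8, 16, and 32, the inferred scales map to FPN levels 2 through 5.

```python
import torch

scales = [1 / 4, 1 / 8, 1 / 16, 1 / 32]  # e.g. FPN strides 4, 8, 16, 32
lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()   # 2.0
lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()  # 5.0
# initLevelMapper(2, 5, ...) then assigns each RoI to one of levels 2..5.
```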
@torch.fx.wrap
def _filter_input(x: Dict[str, Tensor], featmap_names: List[str]) -> List[Tensor]:
x_filtered = []
for k, v in x.items():
if k in featmap_names:
x_filtered.append(v)
return x_filtered
@torch.fx.wrap
def _multiscale_roi_align(
x_filtered: List[Tensor],
boxes: List[Tensor],
output_size: List[int],
sampling_ratio: int,
scales: Optional[List[float]],
mapper: Optional[LevelMapper],
) -> Tensor:
"""
Args:
x_filtered (List[Tensor]): List of input tensors.
boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in
(x1, y1, x2, y2) format and in the image reference size, not the feature map
reference. The coordinate must satisfy ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
output_size (Union[List[Tuple[int, int]], List[int]]): size of the output
sampling_ratio (int): sampling ratio for ROIAlign
scales (Optional[List[float]]): If None, scales will be automatically inferred. Default value is None.
mapper (Optional[LevelMapper]): If None, the mapper will be automatically inferred. Default value is None.
Returns:
result (Tensor)
"""
assert scales is not None
assert mapper is not None
num_levels = len(x_filtered)
rois = _convert_to_roi_format(boxes)
if num_levels == 1:
return roi_align(
x_filtered[0],
rois,
output_size=output_size,
spatial_scale=scales[0],
sampling_ratio=sampling_ratio,
)
levels = mapper(boxes)
num_rois = len(rois)
num_channels = x_filtered[0].shape[1]
dtype, device = x_filtered[0].dtype, x_filtered[0].device
result = torch.zeros(
(
num_rois,
num_channels,
)
+ output_size,
dtype=dtype,
device=device,
)
tracing_results = []
for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
idx_in_level = torch.where(levels == level)[0]
rois_per_level = rois[idx_in_level]
result_idx_in_level = roi_align(
per_level_feature,
rois_per_level,
output_size=output_size,
spatial_scale=scale,
sampling_ratio=sampling_ratio,
)
if torchvision._is_tracing():
tracing_results.append(result_idx_in_level.to(dtype))
else:
# result and result_idx_in_level's dtypes are based on dtypes of different
# elements in x_filtered. x_filtered contains tensors output by different
# layers. When autocast is active, it may choose different dtypes for
# different layers' outputs. Therefore, we defensively match result's dtype
# before copying elements from result_idx_in_level in the following op.
# We need to cast manually (can't rely on autocast to cast for us) because
# the op acts on result in-place, and autocast only affects out-of-place ops.
result[idx_in_level] = result_idx_in_level.to(result.dtype)
if torchvision._is_tracing():
result = _onnx_merge_levels(levels, tracing_results)
return result
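For reference, the level assignment performed by `mapper` follows the FPN heuristic of Lin et al. (2017). A standalone sketch with torchvision's default canonical values:

```python
import math

def fpn_level(box_area, k_min=2, k_max=5, canonical_scale=224, canonical_level=4):
    # k = floor(k0 + log2(sqrt(area) / 224)), clamped to the available levels.
    k = canonical_level + math.log2(math.sqrt(box_area) / canonical_scale)
    return int(min(max(math.floor(k), k_min), k_max))

print(fpn_level(224 * 224))  # 4: a canonically sized box maps to the canonical level
print(fpn_level(56 * 56))    # 2: small boxes map to the finest level
```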
class MultiScaleRoIAlign(nn.Module):
"""
Multi-scale RoIAlign pooling, which is useful for detection with or without FPN.
@@ -165,31 +287,24 @@ class MultiScaleRoIAlign:
self.canonical_scale = canonical_scale
self.canonical_level = canonical_level
def setup_scales(
def convert_to_roi_format(self, boxes: List[Tensor]) -> Tensor:
# TODO: deprecate eventually
warnings.warn("`convert_to_roi_format` will no loger be public in future releases.", FutureWarning)
return _convert_to_roi_format(boxes)
def infer_scale(self, feature: Tensor, original_size: List[int]) -> float:
# TODO: deprecate eventually
warnings.warn("`infer_scale` will no loger be public in future releases.", FutureWarning)
return _infer_scale(feature, original_size)
def setup_scales(
self,
features: List[Tensor],
image_shapes: List[Tuple[int, int]],
) -> None:
assert len(image_shapes) != 0
max_x = 0
max_y = 0
for shape in image_shapes:
max_x = max(shape[0], max_x)
max_y = max(shape[1], max_y)
original_input_shape = (max_x, max_y)
scales = [_infer_scale(feat, original_input_shape) for feat in features]
# get the levels in the feature map by leveraging the fact that the network always
# downsamples by a factor of 2 at each level.
lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item()
lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item()
self.scales = scales
self.map_levels = initLevelMapper(
int(lvl_min),
int(lvl_max),
canonical_scale=self.canonical_scale,
canonical_level=self.canonical_level,
)
# TODO: deprecate eventually
warnings.warn("`setup_setup_scales` will no loger be public in future releases.", FutureWarning)
self.scales, self.map_levels = _setup_scales(features, image_shapes, self.canonical_scale, self.canonical_level)
def forward(
self,
@@ -210,76 +325,21 @@ class MultiScaleRoIAlign:
Returns:
result (Tensor)
"""
x_filtered = []
for k, v in x.items():
if k in self.featmap_names:
x_filtered.append(v)
num_levels = len(x_filtered)
rois = _convert_to_roi_format(boxes)
if self.scales is None:
self.setup_scales(x_filtered, image_shapes)
scales = self.scales
assert scales is not None
if num_levels == 1:
return roi_align(
x_filtered[0],
rois,
output_size=self.output_size,
spatial_scale=scales[0],
sampling_ratio=self.sampling_ratio,
x_filtered = _filter_input(x, self.featmap_names)
if self.scales is None or self.map_levels is None:
self.scales, self.map_levels = _setup_scales(
x_filtered, image_shapes, self.canonical_scale, self.canonical_level
)
mapper = self.map_levels
assert mapper is not None
levels = mapper(boxes)
num_rois = len(rois)
num_channels = x_filtered[0].shape[1]
dtype, device = x_filtered[0].dtype, x_filtered[0].device
result = torch.zeros(
(
num_rois,
num_channels,
)
+ self.output_size,
dtype=dtype,
device=device,
return _multiscale_roi_align(
x_filtered,
boxes,
self.output_size,
self.sampling_ratio,
self.scales,
self.map_levels,
)
tracing_results = []
for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
idx_in_level = torch.where(levels == level)[0]
rois_per_level = rois[idx_in_level]
result_idx_in_level = roi_align(
per_level_feature,
rois_per_level,
output_size=self.output_size,
spatial_scale=scale,
sampling_ratio=self.sampling_ratio,
)
if torchvision._is_tracing():
tracing_results.append(result_idx_in_level.to(dtype))
else:
# result and result_idx_in_level's dtypes are based on dtypes of different
# elements in x_filtered. x_filtered contains tensors output by different
# layers. When autocast is active, it may choose different dtypes for
# different layers' outputs. Therefore, we defensively match result's dtype
# before copying elements from result_idx_in_level in the following op.
# We need to cast manually (can't rely on autocast to cast for us) because
# the op acts on result in-place, and autocast only affects out-of-place ops.
result[idx_in_level] = result_idx_in_level.to(result.dtype)
if torchvision._is_tracing():
result = _onnx_merge_levels(levels, tracing_results)
return result
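End-to-end, the refactored forward behaves as before; a usage sketch (shapes chosen for illustration):

```python
import torch
from torchvision.ops.poolers import MultiScaleRoIAlign

pooler = MultiScaleRoIAlign(featmap_names=["feat1", "feat2"], output_size=(7, 7), sampling_ratio=2)
# Two feature maps at strides 8 and 16 for a 512x512 image.
x = {"feat1": torch.rand(1, 5, 64, 64), "feat2": torch.rand(1, 5, 32, 32)}
boxes = [torch.tensor([[0.0, 0.0, 100.0, 100.0], [10.0, 10.0, 50.0, 50.0]])]
out = pooler(x, boxes, image_shapes=[(512, 512)])
print(out.shape)  # torch.Size([2, 5, 7, 7])
```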
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(featmap_names={self.featmap_names}, "