Unverified Commit f56e6f63 authored by inisis's avatar inisis Committed by GitHub
Browse files

update roipool to make it torch fx traceable (#6501)



* update roipool to make in torch fx traceable

* update ps_roi_align, ps_roi_pool and roi_align to make them torch fx traceable

* add unittest

* Remove `.to(output_fx)` from test
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 56fb8411
...@@ -118,6 +118,25 @@ class RoIOpTester(ABC): ...@@ -118,6 +118,25 @@ class RoIOpTester(ABC):
assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
op_obj = self.make_obj().to(device=device)
graph_module = torch.fx.symbolic_trace(op_obj)
pool_size = 5
n_channels = 2 * (pool_size**2)
x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
rois = torch.tensor(
[[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy)
dtype=rois_dtype,
device=device,
)
output_gt = op_obj(x, rois)
assert output_gt.dtype == x.dtype
output_fx = graph_module(x, rois)
assert output_fx.dtype == x.dtype
tol = 1e-5
torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)
@pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
......
import torch import torch
import torch.fx
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops from torchvision.extension import _assert_has_ops
...@@ -7,6 +8,7 @@ from ..utils import _log_api_usage_once ...@@ -7,6 +8,7 @@ from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def ps_roi_align( def ps_roi_align(
input: Tensor, input: Tensor,
boxes: Tensor, boxes: Tensor,
......
import torch import torch
import torch.fx
from torch import nn, Tensor from torch import nn, Tensor
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops from torchvision.extension import _assert_has_ops
...@@ -7,6 +8,7 @@ from ..utils import _log_api_usage_once ...@@ -7,6 +8,7 @@ from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def ps_roi_pool( def ps_roi_pool(
input: Tensor, input: Tensor,
boxes: Tensor, boxes: Tensor,
......
from typing import List, Union from typing import List, Union
import torch import torch
import torch.fx
from torch import nn, Tensor from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2 from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
...@@ -10,6 +11,7 @@ from ..utils import _log_api_usage_once ...@@ -10,6 +11,7 @@ from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def roi_align( def roi_align(
input: Tensor, input: Tensor,
boxes: Union[Tensor, List[Tensor]], boxes: Union[Tensor, List[Tensor]],
......
from typing import List, Union from typing import List, Union
import torch import torch
import torch.fx
from torch import nn, Tensor from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2 from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
...@@ -10,6 +11,7 @@ from ..utils import _log_api_usage_once ...@@ -10,6 +11,7 @@ from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def roi_pool( def roi_pool(
input: Tensor, input: Tensor,
boxes: Union[Tensor, List[Tensor]], boxes: Union[Tensor, List[Tensor]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment