Unverified Commit f56e6f63 authored by inisis's avatar inisis Committed by GitHub
Browse files

update roipool to make it torch fx traceable (#6501)



* update roipool to make in torch fx traceable

* update ps_roi_align, ps_roi_pool and roi_align to make them torch fx traceable

* add unittest

* Remove `.to(output_fx)` from test
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 56fb8411
......@@ -118,6 +118,25 @@ class RoIOpTester(ABC):
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.float):
op_obj = self.make_obj().to(device=device)
graph_module = torch.fx.symbolic_trace(op_obj)
pool_size = 5
n_channels = 2 * (pool_size**2)
x = torch.rand(2, n_channels, 5, 5, dtype=x_dtype, device=device)
rois = torch.tensor(
[[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy)
dtype=rois_dtype,
device=device,
)
output_gt = op_obj(x, rois)
assert output_gt.dtype == x.dtype
output_fx = graph_module(x, rois)
assert output_fx.dtype == x.dtype
tol = 1e-5
torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
......
import torch
import torch.fx
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
......@@ -7,6 +8,7 @@ from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def ps_roi_align(
input: Tensor,
boxes: Tensor,
......
import torch
import torch.fx
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops
......@@ -7,6 +8,7 @@ from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def ps_roi_pool(
input: Tensor,
boxes: Tensor,
......
from typing import List, Union
import torch
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
......@@ -10,6 +11,7 @@ from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def roi_align(
input: Tensor,
boxes: Union[Tensor, List[Tensor]],
......
from typing import List, Union
import torch
import torch.fx
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
......@@ -10,6 +11,7 @@ from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
@torch.fx.wrap
def roi_pool(
input: Tensor,
boxes: Union[Tensor, List[Tensor]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment