Unverified Commit bd5f1e39 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Fix typing jit issue on RoIPool and RoIAlign (#6397)

* Fix typing jit issue on RoIPool and RoIAlign

* Fix nit.

* Address code review comments.
parent cc0a8385
......@@ -69,6 +69,15 @@ class DropBlockWrapper(nn.Module):
self.layer(a)
class PoolWrapper(nn.Module):
def __init__(self, pool: nn.Module):
super().__init__()
self.pool = pool
def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
return self.pool(imgs, boxes)
class RoIOpTester(ABC):
dtype = torch.float64
......@@ -150,6 +159,14 @@ class RoIOpTester(ABC):
boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype)
ops.roi_pool(a, [boxes], output_size=(2, 2))
def _helper_jit_boxes_list(self, model):
x = torch.rand(2, 1, 10, 10)
roi = torch.tensor([[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], dtype=torch.float).t()
rois = [roi, roi]
scriped = torch.jit.script(model)
y = scriped(x, rois)
assert y.shape == (10, 1, 3, 3)
@abstractmethod
def fn(*args, **kwargs):
pass
......@@ -210,6 +227,10 @@ class TestRoiPool(RoIOpTester):
def test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_pool)
def test_jit_boxes_list(self):
model = PoolWrapper(ops.RoIPool(output_size=[3, 3], spatial_scale=1.0))
self._helper_jit_boxes_list(model)
class TestPSRoIPool(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
......@@ -450,6 +471,10 @@ class TestRoIAlign(RoIOpTester):
with pytest.raises(RuntimeError, match="Only one image per batch is allowed"):
ops.roi_align(qx, qrois, output_size=5)
def test_jit_boxes_list(self):
model = PoolWrapper(ops.RoIAlign(output_size=[3, 3], spatial_scale=1.0, sampling_ratio=-1))
self._helper_jit_boxes_list(model)
class TestPSRoIAlign(RoIOpTester):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
......
......@@ -82,7 +82,7 @@ class RoIAlign(nn.Module):
self.sampling_ratio = sampling_ratio
self.aligned = aligned
def forward(self, input: Tensor, rois: Tensor) -> Tensor:
def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)
def __repr__(self) -> str:
......
......@@ -62,7 +62,7 @@ class RoIPool(nn.Module):
self.output_size = output_size
self.spatial_scale = spatial_scale
def forward(self, input: Tensor, rois: Tensor) -> Tensor:
def forward(self, input: Tensor, rois: Union[Tensor, List[Tensor]]) -> Tensor:
return roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self) -> str:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment