Unverified Commit 3428a7de authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Added test for aligned=True (#3540)

parent 01398088
......@@ -54,7 +54,7 @@ class OpTester(object):
class RoIOpTester(OpTester):
def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None):
def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
pool_size = 5
......@@ -70,11 +70,11 @@ class RoIOpTester(OpTester):
dtype=rois_dtype, device=device)
pool_h, pool_w = pool_size, pool_size
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
# the following should be true whether we're running an autocast test or not.
self.assertTrue(y.dtype == x.dtype)
gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1,
sampling_ratio=-1, device=device, dtype=self.dtype)
sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs)
tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))
......@@ -304,6 +304,10 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)
def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
for aligned in (True, False):
super()._test_forward(device, contiguous, x_dtype, rois_dtype, aligned=aligned)
class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment