Unverified Commit 7992eb5d authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

simplify _get_script_fn (#3541)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 3428a7de
......@@ -135,11 +135,8 @@ class RoIPoolTester(RoIOpTester, unittest.TestCase):
return ops.RoIPool((pool_h, pool_w), spatial_scale)(x, rois)
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (Tensor, Tensor, int) -> Tensor
return ops.roi_pool(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
scriped = torch.jit.script(ops.roi_pool)
return lambda x: scriped(x, rois, pool_size)
def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
device=None, dtype=torch.float64):
......@@ -177,11 +174,8 @@ class PSRoIPoolTester(RoIOpTester, unittest.TestCase):
return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (Tensor, Tensor, int) -> Tensor
return ops.ps_roi_pool(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
scriped = torch.jit.script(ops.ps_roi_pool)
return lambda x: scriped(x, rois, pool_size)
def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1,
device=None, dtype=torch.float64):
......@@ -257,11 +251,8 @@ class RoIAlignTester(RoIOpTester, unittest.TestCase):
sampling_ratio=sampling_ratio, aligned=aligned)(x, rois)
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (Tensor, Tensor, int) -> Tensor
return ops.roi_align(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
scriped = torch.jit.script(ops.roi_align)
return lambda x: scriped(x, rois, pool_size)
def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False,
device=None, dtype=torch.float64):
......@@ -315,11 +306,8 @@ class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
sampling_ratio=sampling_ratio)(x, rois)
def get_script_fn(self, rois, pool_size):
@torch.jit.script
def script_fn(input, rois, pool_size):
# type: (Tensor, Tensor, int) -> Tensor
return ops.ps_roi_align(input, rois, pool_size, 1.0)[0]
return lambda x: script_fn(x, rois, pool_size)
scriped = torch.jit.script(ops.ps_roi_align)
return lambda x: scriped(x, rois, pool_size)
def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
sampling_ratio=-1, dtype=torch.float64):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment