Commit 896d7ec7 authored by Lukas Bommes, committed by Francisco Massa

Implementation for Position-sensitive ROI Pool/Align [updated] (#1410)

* added PSRoiAlign and PSRoiPool with C++ autograd and torch ops

* fixed linter errors

* fixed linter errors 2

* fixed linter errors 3
parent 63128cb4
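For orientation, a minimal usage sketch of the two new ops, assembled from the tests below (illustrative only; shapes and the (batch_idx, x1, y1, x2, y2) roi layout follow the test file):

import torch
import torchvision.ops as ops

pool_size = 3
# input channels must be divisible by pool_h * pool_w (position-sensitive split)
x = torch.rand(1, 2 * pool_size ** 2, 7, 7)
rois = torch.tensor([[0, 0, 0, 5, 5]], dtype=x.dtype)  # (batch_idx, x1, y1, x2, y2)

ps_align = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1, sampling_ratio=2)
ps_pool = ops.PSRoIPool((pool_size, pool_size), 1)
y_align = ps_align(x, rois)  # shape (num_rois, 2, pool_size, pool_size)
y_pool = ps_pool(x, rois)    # same output shape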
@@ -487,6 +487,685 @@ class RoIAlignTester(unittest.TestCase):
        assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA'
def bilinear_interpolate(data, height, width, y, x):
if y < -1.0 or y > height or x < -1.0 or x > width:
return 0.
if y <= 0:
y = 0.
if x <= 0:
x = 0.
y_low, x_low = int(y), int(x)
y_high, x_high = 0, 0
if y_low >= height - 1:
y_high = y_low = height - 1
y = float(y_low)
else:
y_high = y_low + 1
if x_low >= width - 1:
x_high = x_low = width - 1
x = float(x_low)
else:
x_high = x_low + 1
ly = y - y_low
lx = x - x_low
hy, hx = 1. - ly, 1. - lx
v1 = data[y_low * width + x_low]
v2 = data[y_low * width + x_high]
v3 = data[y_high * width + x_low]
v4 = data[y_high * width + x_high]
w1, w2, w3, w4 = hy * hx, hy * lx, ly * hx, ly * lx
return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
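A quick sanity check of the bilinear_interpolate helper above (illustrative values): on a 2x2 map the center value is the mean of the four corners, grid points reproduce their own values, and points more than one pixel outside the map give 0.

_data = [1.0, 2.0, 3.0, 4.0]  # 2x2 feature map, row-major
assert bilinear_interpolate(_data, 2, 2, 0.5, 0.5) == 2.5   # mean of the four corners
assert bilinear_interpolate(_data, 2, 2, 0.0, 0.0) == 1.0   # exact grid point
assert bilinear_interpolate(_data, 2, 2, -1.5, 0.0) == 0.0  # out of bounds -> 0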
class PSRoIAlignTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dtype = torch.float64
def slow_ps_roi_align(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
sampling_ratio=-1, dtype=torch.float64):
if device is None:
device = torch.device("cpu")
num_input_channels = in_data.size(1)
assert num_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
num_output_channels = int(num_input_channels / (pool_h * pool_w))
out_data = torch.zeros(rois.size(0), num_output_channels, pool_h, pool_w, dtype=dtype, device=device)
for n in range(0, in_data.size(0)):
for r, roi in enumerate(rois):
if roi[0] != n:
continue
                roi = roi.clone()  # copy so the caller's rois tensor is not mutated
                roi[1:] = (roi[1:] * spatial_scale) - 0.5
c_in = 0
roi_height = float(roi[4].item() - roi[2].item())
roi_width = float(roi[3].item() - roi[1].item())
bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)
for c_out in range(0, num_output_channels):
for j in range(0, pool_h):
start_h = float(j) * bin_h + roi[2].item()
for i in range(0, pool_w):
start_w = float(i) * bin_w + roi[1].item()
roi_bin_grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(roi_height / pool_h))
roi_bin_grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(roi_width / pool_w))
val = 0.
for iy in range(0, roi_bin_grid_h):
y = start_h + (iy + 0.5) * bin_h / float(roi_bin_grid_h)
for ix in range(0, roi_bin_grid_w):
x = start_w + (ix + 0.5) * bin_w / float(roi_bin_grid_w)
val += bilinear_interpolate(
in_data[n, c_in, :, :].flatten(),
in_data.size(-2),
in_data.size(-1),
y, x
)
count = roi_bin_grid_h * roi_bin_grid_w
out_data[r, c_out, j, i] = val / count
c_in += 1
return out_data
def test_ps_roi_align_basic_cpu(self):
device = torch.device('cpu')
pool_size = 3
x = torch.rand(1, 2 * (pool_size ** 2), 7, 7, dtype=self.dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 5, 5]],  # format is (batch_idx, x1, y1, x2, y2)
dtype=self.dtype, device=device)
pool_h, pool_w = (pool_size, pool_size)
ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2)
y = ps_roi_align(x, rois)
gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device,
spatial_scale=1, sampling_ratio=2,
dtype=self.dtype)
assert torch.allclose(gt_y, y), 'PSRoIAlign layer incorrect on CPU'
y = ps_roi_align(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device,
                                      spatial_scale=1, sampling_ratio=2,
dtype=self.dtype)
assert torch.allclose(gt_y, y), 'PSRoIAlign layer incorrect on CPU'
def test_ps_roi_align_cpu(self):
device = torch.device('cpu')
pool_size = 5
x = torch.rand(2, 2 * (pool_size ** 2), 10, 10, dtype=self.dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (batch_idx, x1, y1, x2, y2)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=self.dtype, device=device)
pool_h, pool_w = (pool_size, pool_size)
ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2)
y = ps_roi_align(x, rois)
gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device,
spatial_scale=1, sampling_ratio=2,
dtype=self.dtype)
assert torch.allclose(gt_y, y), 'PSRoIAlign layer incorrect on CPU'
y = ps_roi_align(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w,
device, spatial_scale=1, sampling_ratio=2,
dtype=self.dtype)
assert torch.allclose(gt_y, y), 'PSRoIAlign layer incorrect on CPU'
def test_ps_roi_align_gradient_cpu(self):
device = torch.device('cpu')
pool_size = 3
layer = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1,
sampling_ratio=-1).to(dtype=self.dtype, device=device)
x = torch.ones(1, pool_size ** 2, 5, 5, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 4, 4],
[0, 0, 3, 5, 5],
[0, 1, 0, 2, 4]],
dtype=self.dtype, device=device)
y = layer(x, rois)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[8.125e-01, 6.875e-01, 0.0, 0.0, 0.0, ],
[2.7083333333e-01, 2.2916666667e-01, 0.0, 0.0, 0.0, ],
[1.0416666667e-01, 6.25e-02, 0.0, 0.0, 0.0, ],
[5.2083333333e-01, 3.125e-01, 0.0, 0.0, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ]],
[[8.3266726847e-17, 1.125e00, 3.750e-01, 0.0, 0.0, ],
[2.7755575616e-17, 3.750e-01, 1.250e-01, 0.0, 0.0, ],
[0.0, 3.4722222222e-02, 9.7222222222e-02, 3.4722222222e-02, 0.0, ],
[0.0, 1.7361111111e-01, 4.8611111111e-01, 1.7361111111e-01, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ]],
[[0.0, 5.000e-01, 4.375e-01, 5.000e-01, 6.25e-02, ],
[0.0, 1.6666666667e-01, 1.4583333333e-01, 1.6666666667e-01, 2.0833333333e-02, ],
[0.0, 0.0, 0.0, 6.25e-02, 1.0416666667e-01, ],
[0.0, 0.0, 0.0, 3.125e-01, 5.2083333333e-01, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[5.4166666667e-01, 4.5833333333e-01, 0.0, 0.0, 0.0, ],
[5.4166666667e-01, 4.5833333333e-01, 0.0, 0.0, 0.0, ],
[3.125e-01, 1.875e-01, 0.0, 0.0, 0.0, ],
[3.125e-01, 1.875e-01, 0.0, 0.0, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[5.5511151231e-17, 7.500e-01, 2.500e-01, 0.0, 0.0, ],
[5.5511151231e-17, 7.500e-01, 2.500e-01, 0.0, 0.0, ],
[0.0, 1.0416666667e-01, 2.9166666667e-01, 1.0416666667e-01, 0.0, ],
[0.0, 1.0416666667e-01, 2.9166666667e-01, 1.0416666667e-01, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 3.3333333333e-01, 2.9166666667e-01, 3.3333333333e-01, 4.1666666667e-02, ],
[0.0, 3.3333333333e-01, 2.9166666667e-01, 3.3333333333e-01, 4.1666666667e-02, ],
[0.0, 0.0, 0.0, 1.875e-01, 3.125e-01, ],
[0.0, 0.0, 0.0, 1.875e-01, 3.125e-01, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ],
[2.7083333333e-01, 2.2916666667e-01, 0.0, 0.0, 0.0, ],
[7.2222222222e-01, 6.1111111111e-01, 0.0, 0.0, 0.0, ],
[7.1527777778e-01, 4.5138888889e-01, 0.0, 0.0, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ],
[2.7755575616e-17, 3.750e-01, 1.250e-01, 0.0, 0.0, ],
[7.4014868308e-17, 1.000e00, 3.3333333333e-01, 0.0, 0.0, ],
[9.2518585385e-18, 3.3333333333e-01, 6.25e-01, 2.0833333333e-01, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 1.6666666667e-01, 1.4583333333e-01, 1.6666666667e-01, 2.0833333333e-02, ],
[0.0, 4.4444444444e-01, 3.8888888889e-01, 4.4444444444e-01, 5.5555555556e-02, ],
[0.0, 5.5555555556e-02, 4.8611111111e-02, 4.3055555556e-01, 6.3194444444e-01, ]]]],
device=device, dtype=self.dtype)
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for PSRoIAlign on CPU'
def test_ps_roi_align_gradcheck_cpu(self):
device = torch.device('cpu')
pool_size = 5
x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 5, 9],
[0, 5, 5, 9, 9]], dtype=self.dtype, device=device)
m = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1,
sampling_ratio=2).to(dtype=self.dtype, device=device)
def func(input):
return m(input, rois)
assert gradcheck(func, (x,)), 'gradcheck failed for PSRoIAlign on CPU'
assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for PSRoIAlign on CPU'
@torch.jit.script
def script_func(input, rois):
return ops.ps_roi_align(input, rois, 5, 2.0, 1)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted ps_roi_align on CPU'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_align_basic_cuda(self):
device = torch.device('cuda')
pool_size = 3
x = torch.rand(1, 2 * (pool_size ** 2), 7, 7, dtype=self.dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 5, 5]],  # format is (batch_idx, x1, y1, x2, y2)
dtype=self.dtype, device=device)
pool_h, pool_w = (pool_size, pool_size)
ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2)
y = ps_roi_align(x, rois)
gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device,
spatial_scale=1, sampling_ratio=2,
dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect'
y = ps_roi_align(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device,
                                      spatial_scale=1, sampling_ratio=2,
dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_align_cuda(self):
        device = torch.device('cuda')
pool_size = 5
x = torch.rand(2, 2 * (pool_size ** 2), 10, 10, dtype=self.dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (batch_idx, x1, y1, x2, y2)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=self.dtype, device=device)
pool_h, pool_w = (pool_size, pool_size)
ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2)
y = ps_roi_align(x, rois)
gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device,
spatial_scale=1, sampling_ratio=2,
dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect'
y = ps_roi_align(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w,
device, spatial_scale=1, sampling_ratio=2,
dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_align_gradient_cuda(self):
device = torch.device('cuda')
pool_size = 3
layer = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1,
sampling_ratio=-1).to(dtype=self.dtype, device=device)
x = torch.ones(1, pool_size ** 2, 5, 5, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 4, 4],
[0, 0, 3, 5, 5],
[0, 1, 0, 2, 4]],
dtype=self.dtype, device=device)
y = layer(x, rois)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[8.125e-01, 6.875e-01, 0.0, 0.0, 0.0, ],
[2.7083333333e-01, 2.2916666667e-01, 0.0, 0.0, 0.0, ],
[1.0416666667e-01, 6.25e-02, 0.0, 0.0, 0.0, ],
[5.2083333333e-01, 3.125e-01, 0.0, 0.0, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ]],
[[8.3266726847e-17, 1.125e00, 3.750e-01, 0.0, 0.0, ],
[2.7755575616e-17, 3.750e-01, 1.250e-01, 0.0, 0.0, ],
[0.0, 3.4722222222e-02, 9.7222222222e-02, 3.4722222222e-02, 0.0, ],
[0.0, 1.7361111111e-01, 4.8611111111e-01, 1.7361111111e-01, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ]],
[[0.0, 5.000e-01, 4.375e-01, 5.000e-01, 6.25e-02, ],
[0.0, 1.6666666667e-01, 1.4583333333e-01, 1.6666666667e-01, 2.0833333333e-02, ],
[0.0, 0.0, 0.0, 6.25e-02, 1.0416666667e-01, ],
[0.0, 0.0, 0.0, 3.125e-01, 5.2083333333e-01, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[5.4166666667e-01, 4.5833333333e-01, 0.0, 0.0, 0.0, ],
[5.4166666667e-01, 4.5833333333e-01, 0.0, 0.0, 0.0, ],
[3.125e-01, 1.875e-01, 0.0, 0.0, 0.0, ],
[3.125e-01, 1.875e-01, 0.0, 0.0, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[5.5511151231e-17, 7.500e-01, 2.500e-01, 0.0, 0.0, ],
[5.5511151231e-17, 7.500e-01, 2.500e-01, 0.0, 0.0, ],
[0.0, 1.0416666667e-01, 2.9166666667e-01, 1.0416666667e-01, 0.0, ],
[0.0, 1.0416666667e-01, 2.9166666667e-01, 1.0416666667e-01, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 3.3333333333e-01, 2.9166666667e-01, 3.3333333333e-01, 4.1666666667e-02, ],
[0.0, 3.3333333333e-01, 2.9166666667e-01, 3.3333333333e-01, 4.1666666667e-02, ],
[0.0, 0.0, 0.0, 1.875e-01, 3.125e-01, ],
[0.0, 0.0, 0.0, 1.875e-01, 3.125e-01, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ],
[2.7083333333e-01, 2.2916666667e-01, 0.0, 0.0, 0.0, ],
[7.2222222222e-01, 6.1111111111e-01, 0.0, 0.0, 0.0, ],
[7.1527777778e-01, 4.5138888889e-01, 0.0, 0.0, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ],
[2.7755575616e-17, 3.750e-01, 1.250e-01, 0.0, 0.0, ],
[7.4014868308e-17, 1.000e00, 3.3333333333e-01, 0.0, 0.0, ],
[9.2518585385e-18, 3.3333333333e-01, 6.25e-01, 2.0833333333e-01, 0.0, ]],
[[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 0.0, 0.0, 0.0, 0.0, ],
[0.0, 1.6666666667e-01, 1.4583333333e-01, 1.6666666667e-01, 2.0833333333e-02, ],
[0.0, 4.4444444444e-01, 3.8888888889e-01, 4.4444444444e-01, 5.5555555556e-02, ],
[0.0, 5.5555555556e-02, 4.8611111111e-02, 4.3055555556e-01, 6.3194444444e-01, ]]]],
device=device, dtype=self.dtype)
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for PSRoIAlign'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_align_gradcheck_cuda(self):
device = torch.device('cuda')
pool_size = 5
x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 5, 9],
[0, 5, 5, 9, 9]], dtype=self.dtype, device=device)
m = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1,
sampling_ratio=2).to(dtype=self.dtype, device=device)
def func(input):
return m(input, rois)
assert gradcheck(func, (x,)), 'gradcheck failed for PSRoIAlign CUDA'
assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for PSRoIAlign CUDA'
@torch.jit.script
def script_func(input, rois):
return ops.ps_roi_align(input, rois, 5, 2.0, 1)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted ps_roi_align on CUDA'
class PSRoIPoolTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dtype = torch.float64
def slow_ps_roi_pooling(self, x, rois, pool_h, pool_w, device, spatial_scale=1,
dtype=torch.float64):
if device is None:
device = torch.device("cpu")
num_input_channels = x.size(1)
assert num_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw"
num_output_channels = int(num_input_channels / (pool_h * pool_w))
y = torch.zeros(rois.size(0), num_output_channels, pool_h, pool_w, dtype=dtype, device=device)
rois = torch.round(rois * spatial_scale).int()
for n in range(0, x.size(0)):
for r, roi in enumerate(rois):
if roi[0] != n:
continue
c_in = 0
for c_out in range(0, num_output_channels):
roi_height = max(roi[4].item() - roi[2].item(), 1)
roi_width = max(roi[3].item() - roi[1].item(), 1)
bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w)
for j in range(0, pool_h):
start_h = int(np.floor(j * bin_h)) + roi[2].item()
                        end_h = int(np.ceil((j + 1) * bin_h)) + roi[2].item()
# range-check
start_h = min(max(start_h, 0), x.size(2))
end_h = min(max(end_h, 0), x.size(2))
for i in range(0, pool_w):
start_w = int(np.floor(i * bin_w)) + roi[1].item()
end_w = int(np.ceil((i + 1) * bin_w)) + roi[1].item()
# range-check
start_w = min(max(start_w, 0), x.size(3))
end_w = min(max(end_w, 0), x.size(3))
is_empty = (end_h <= start_h) or (end_w <= start_w)
area = (end_h - start_h) * (end_w - start_w)
if not is_empty:
t = torch.sum(x[n, c_in, slice(start_h, end_h), slice(start_w, end_w)])
y[r, c_out, j, i] = t / area
c_in += 1
return y
def test_ps_roi_pool_basic_cpu(self):
device = torch.device('cpu')
pool_size = 3
x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 4, 4]],  # format is (batch_idx, x1, y1, x2, y2)
dtype=self.dtype, device=device)
pool_h, pool_w = (pool_size, pool_size)
ps_roi_pool = ops.PSRoIPool((pool_h, pool_w), 1)
y = ps_roi_pool(x, rois)
gt_y = self.slow_ps_roi_pooling(x, rois, pool_h, pool_w, device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'PSRoIPool layer incorrect on CPU'
y = ps_roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_ps_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'PSRoIPool layer incorrect on CPU'
def test_ps_roi_pool_cpu(self):
device = torch.device('cpu')
pool_size = 5
x = torch.rand(2, 2 * (pool_size ** 2), 10, 10, dtype=self.dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (batch_idx, x1, y1, x2, y2)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=self.dtype, device=device)
pool_h, pool_w = (pool_size, pool_size)
ps_roi_pool = ops.PSRoIPool((pool_h, pool_w), 1)
y = ps_roi_pool(x, rois)
gt_y = self.slow_ps_roi_pooling(x, rois, pool_h, pool_w, device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'PSRoIPool layer incorrect on CPU'
y = ps_roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_ps_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'PSRoIPool layer incorrect on CPU'
def test_ps_roi_pool_gradient_cpu(self):
device = torch.device('cpu')
pool_size = 3
layer = ops.PSRoIPool((pool_size, pool_size), 1).to(dtype=self.dtype, device=device)
x = torch.ones(1, pool_size ** 2, 5, 5, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 4, 4],
[0, 0, 3, 5, 5],
[0, 1, 0, 2, 4]],
dtype=self.dtype, device=device)
y = layer(x, rois)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1. / 3, 1. / 3, 1. / 3, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.5000, 0.5000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.2500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.2500, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 1. / 6, 1. / 6, 1. / 6, 0.0000],
[0.0000, 1. / 6, 1. / 6, 1. / 6, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.0000, 0.0000, 0.2500, 0.2500],
[0.0000, 0.0000, 0.0000, 0.2500, 0.2500]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.5000, 0.5000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 1. / 3, 1. / 3, 1. / 3, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.0000, 0.0000, 0.5000, 0.5000]]]],
device=device, dtype=self.dtype)
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for PSRoIPool on CPU'
def test_ps_roi_pool_gradcheck_cpu(self):
device = torch.device('cpu')
pool_size = 5
x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 5, 9],
[0, 5, 5, 9, 9]], dtype=self.dtype, device=device)
m = ops.PSRoIPool((pool_size, pool_size), 1).to(dtype=self.dtype, device=device)
def func(input):
return m(input, rois)
assert gradcheck(func, (x,)), 'gradcheck failed for PSRoIPool on CPU'
assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for PSRoIPool on CPU'
@torch.jit.script
def script_func(input, rois):
return ops.ps_roi_pool(input, rois, 5, 1.0)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted ps_roi_pool on CPU'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_pool_basic_cuda(self):
device = torch.device('cuda')
pool_size = 3
x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 4, 4]],  # format is (batch_idx, x1, y1, x2, y2)
dtype=self.dtype, device=device)
pool_h, pool_w = (pool_size, pool_size)
ps_roi_pool = ops.PSRoIPool((pool_h, pool_w), 1)
y = ps_roi_pool(x, rois)
gt_y = self.slow_ps_roi_pooling(x, rois, pool_h, pool_w, device, dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'PSRoIPool layer incorrect'
y = ps_roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_ps_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'PSRoIPool layer incorrect'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_pool_cuda(self):
        device = torch.device('cuda')
pool_size = 5
x = torch.rand(2, 2 * (pool_size ** 2), 10, 10, dtype=self.dtype, device=device)
        rois = torch.tensor([[0, 0, 0, 9, 9],  # format is (batch_idx, x1, y1, x2, y2)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=self.dtype, device=device)
pool_h, pool_w = (pool_size, pool_size)
ps_roi_pool = ops.PSRoIPool((pool_h, pool_w), 1)
y = ps_roi_pool(x, rois)
gt_y = self.slow_ps_roi_pooling(x, rois, pool_h, pool_w, device, dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'PSRoIPool layer incorrect'
y = ps_roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_ps_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'PSRoIPool layer incorrect'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_pool_gradient_cuda(self):
device = torch.device('cuda')
pool_size = 3
layer = ops.PSRoIPool((pool_size, pool_size), 1).to(dtype=self.dtype, device=device)
x = torch.ones(1, pool_size ** 2, 5, 5, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 4, 4],
[0, 0, 3, 5, 5],
[0, 1, 0, 2, 4]],
dtype=self.dtype, device=device)
y = layer(x, rois)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1. / 3, 1. / 3, 1. / 3, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.5000, 0.5000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.2500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.2500, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 1. / 6, 1. / 6, 1. / 6, 0.0000],
[0.0000, 1. / 6, 1. / 6, 1. / 6, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.0000, 0.0000, 0.2500, 0.2500],
[0.0000, 0.0000, 0.0000, 0.2500, 0.2500]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.2500, 0.7500, 0.0000, 0.0000, 0.0000],
[0.5000, 0.5000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 0.7500, 0.2500, 0.0000, 0.0000],
[0.0000, 1. / 3, 1. / 3, 1. / 3, 0.0000]],
[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.5000, 0.2500, 0.2500, 0.0000],
[0.0000, 0.0000, 0.0000, 0.5000, 0.5000]]]],
device=device, dtype=self.dtype)
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for PSRoIPool'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_pool_gradcheck_cuda(self):
device = torch.device('cuda')
pool_size = 5
x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 5, 9],
[0, 5, 5, 9, 9]], dtype=self.dtype, device=device)
m = ops.PSRoIPool((pool_size, pool_size), 1).to(dtype=self.dtype, device=device)
def func(input):
return m(input, rois)
assert gradcheck(func, (x,)), 'gradcheck failed for PSRoIPool CUDA'
assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for PSRoIPool CUDA'
@torch.jit.script
def script_func(input, rois):
return ops.ps_roi_pool(input, rois, 5, 1.0)[0]
assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted ps_roi_pool on CUDA'
class NMSTester(unittest.TestCase):
    def reference_nms(self, boxes, scores, iou_threshold):
        """
...
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
#include <iostream>
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return PSROIAlign_forward_cuda(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return PSROIAlign_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}
at::Tensor PSROIAlign_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
return PSROIAlign_backward_cuda(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return PSROIAlign_backward_cpu(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
}
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class PSROIAlignFunction
: public torch::autograd::Function<PSROIAlignFunction> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
auto result = PSROIAlign_forward(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = PSROIAlign_backward(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
ctx->saved_data["sampling_ratio"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {
grad_in, Variable(), Variable(), Variable(), Variable(), Variable()};
}
};
std::tuple<Tensor, Tensor> ps_roi_align(
const Tensor& input,
const Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width,
const int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
return std::tuple<Tensor, Tensor>(result[0], result[1]);
}
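ps_roi_align returns both tensors from the autograd Function: the pooled output plus the integer channel_mapping that is saved for backward and marked non-differentiable. On the Python side the scripted tests above keep element [0]; a sketch, assuming the same positional argument order (output_size, spatial_scale, sampling_ratio) as those tests use:

import torch
import torchvision.ops as ops

x = torch.rand(1, 2 * 3 ** 2, 7, 7)
rois = torch.tensor([[0, 0, 0, 5, 5]], dtype=x.dtype)
result = ops.ps_roi_align(x, rois, 3, 1.0, 2)
output = result[0]           # pooled features, differentiable
channel_mapping = result[1]  # int tensor, used only by the backward pass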
#pragma once
#include "cpu/vision_cpu.h"
#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return PSROIPool_forward_cuda(
input, rois, spatial_scale, pooled_height, pooled_width);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return PSROIPool_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width);
}
at::Tensor PSROIPool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
return PSROIPool_backward_cuda(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return PSROIPool_backward_cpu(
grad,
rois,
mapping_channel,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
}
using namespace at;
using torch::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class PSROIPoolFunction : public torch::autograd::Function<PSROIPoolFunction> {
public:
static variable_list forward(
AutogradContext* ctx,
Variable input,
Variable rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["input_shape"] = input.sizes();
auto result = PSROIPool_forward(
input, rois, spatial_scale, pooled_height, pooled_width);
auto output = std::get<0>(result);
auto channel_mapping = std::get<1>(result);
ctx->save_for_backward({rois, channel_mapping});
ctx->mark_non_differentiable({channel_mapping});
return {output, channel_mapping};
}
static variable_list backward(
AutogradContext* ctx,
variable_list grad_output) {
// Use data saved in forward
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = PSROIPool_backward(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
std::tuple<Tensor, Tensor> ps_roi_pool(
const Tensor& input,
const Tensor& rois,
const double spatial_scale,
const int64_t pooled_height,
const int64_t pooled_width) {
auto result = PSROIPoolFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width);
return std::tuple<Tensor, Tensor>(result[0], result[1]);
}
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
template <typename T>
T bilinear_interpolate(
const T* input,
const int height,
const int width,
T y,
T x,
const int index /* index for debug only*/) {
  // deal with cases where the sampling point falls outside the feature map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
return 0;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
void PSROIAlignForwardCPU(
const int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const T* rois,
const int channels_out,
T* output,
int* channel_mapping) {
int num_rois = nthreads / channels_out / pooled_width / pooled_height;
for (int n = 0; n < num_rois; n++) {
// [start, end) interval for spatial sampling
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
    // Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int c_in = 0;
for (int c_out = 0; c_out < channels_out; ++c_out) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int index =
((n * channels_out + c_out) * pooled_height + ph) * pooled_width +
pw;
          // Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
          // We use roi_bin_grid to sample the grid and mimic the integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
const T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(
offset_input, height, width, y, x, index);
out_sum += val;
}
}
out_sum /= count;
output[index] = out_sum;
channel_mapping[index] = c_in;
c_in++;
}
}
}
}
}
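The position-sensitive behavior is the c_in bookkeeping above: each (output channel, pooled bin) pair reads its own input channel. A small illustrative check that the sequential c_in++ in this CPU kernel agrees with the closed form c_in = (c_out * pooled_height + ph) * pooled_width + pw used by the CUDA kernel later in this commit:

pooled_height, pooled_width, channels_out = 3, 3, 2
c_in = 0
for c_out in range(channels_out):
    for ph in range(pooled_height):
        for pw in range(pooled_width):
            assert c_in == (c_out * pooled_height + ph) * pooled_width + pw
            c_in += 1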
template <typename T>
void bilinear_interpolate_gradient(
const int height,
const int width,
T y,
T x,
T& w1,
T& w2,
T& w3,
T& w4,
int& x_low,
int& x_high,
int& y_low,
int& y_high,
const int index /* index for debug only*/) {
  // deal with cases where the sampling point falls outside the feature map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
return;
}
template <class T>
inline void add(T* address, const T& val) {
*address += val;
}
template <typename T>
void PSROIAlignBackwardCPU(
const int nthreads,
const T* grad_output,
const int* channel_mapping,
const int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int channels_out,
T* grad_input,
const T* rois) {
for (int index = 0; index < nthreads; index++) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int n = index / pooled_width / pooled_height / channels_out;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
    // Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int c_in = channel_mapping[index];
T* grad_input_offset =
grad_input + (roi_batch_ind * channels + c_in) * height * width;
    // Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
const T grad_output_this_bin = grad_output[index];
    // We use roi_bin_grid to sample the grid and mimic the integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
add(grad_input_offset + y_low * width + x_low, g1);
add(grad_input_offset + y_low * width + x_high, g2);
add(grad_input_offset + y_high * width + x_low, g3);
add(grad_input_offset + y_high * width + x_high, g4);
} // if
} // ix
} // iy
}
}
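Both kernels above shift the scaled roi coordinates by half a pixel instead of rounding, so box extents are preserved exactly. Illustrative arithmetic, matching the slow Python reference in the tests:

spatial_scale = 1.0
x1, x2 = 0.0, 5.0
roi_start_w = x1 * spatial_scale - 0.5  # -0.5
roi_end_w = x2 * spatial_scale - 0.5    #  4.5
roi_width = roi_end_w - roi_start_w     #  5.0: width preserved, no rounding error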
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
// Check if input tensors are CPU tensors
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "PSROIAlign_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});
int num_rois = rois.size(0);
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
AT_ASSERTM(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros(
{num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping =
at::zeros(output.sizes(), input.options().dtype(at::kInt));
auto output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "PSROIAlign_forward", [&] {
PSROIAlignForwardCPU<scalar_t>(
output_size,
input.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
rois.contiguous().data<scalar_t>(),
channels_out,
output.data<scalar_t>(),
channel_mapping.data<int>());
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CPU tensors
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
AT_ASSERTM(
channel_mapping.device().is_cpu(),
"channel_mapping must be a CPU tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "PSROIAlign_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});
auto num_rois = rois.size(0);
auto grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
// handle possibly empty gradients
if (grad.numel() == 0) {
return grad_input;
}
int channels_out = channels / (pooled_height * pooled_width);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "PSROIAlign_backward", [&] {
PSROIAlignBackwardCPU<scalar_t>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
channel_mapping.data<int>(),
num_rois,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
channels_out,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>());
});
return grad_input;
}
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include <algorithm>
template <class T>
inline void add(T* address, const T& val) {
*address += val;
}
template <typename T>
void PSROIPoolForward(
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const T* rois,
const int channels_out,
const int num_rois,
T* output,
int* channel_mapping) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
int roi_start_w = round(offset_rois[1] * spatial_scale);
int roi_start_h = round(offset_rois[2] * spatial_scale);
int roi_end_w = round(offset_rois[3] * spatial_scale);
int roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
int roi_width = std::max(roi_end_w - roi_start_w, 1);
int roi_height = std::max(roi_end_h - roi_start_h, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int c_in = 0;
for (int c_out = 0; c_out < channels_out; ++c_out) {
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend =
static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend =
static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1);
hend = std::min(std::max(hend + roi_start_h, 0), height - 1);
wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1);
wend = std::min(std::max(wend + roi_start_w, 0), width - 1);
bool is_empty = (hend <= hstart) || (wend <= wstart);
const T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_index = h * width + w;
out_sum += offset_input[input_index];
}
}
int index =
((n * channels_out + c_out) * pooled_height + ph) * pooled_width +
pw;
T bin_area = (hend - hstart) * (wend - wstart);
output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
channel_mapping[index] = c_in;
c_in++;
}
}
}
}
}
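The floor/ceil bin boundaries above let adjacent bins overlap rather than leave gaps when the roi size is not a multiple of the pooled size. For example (illustrative numbers), a roi_height of 5 pooled to 3 bins:

import math

roi_height, pooled_height = 5, 3
bin_size_h = roi_height / pooled_height  # 5/3
for ph in range(pooled_height):
    hstart = math.floor(ph * bin_size_h)
    hend = math.ceil((ph + 1) * bin_size_h)
    print(ph, hstart, hend)  # (0, 0, 2), (1, 1, 4), (2, 3, 5): rows 1 and 3 fall in two bins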
template <typename T>
void PSROIPoolBackward(
const T* grad_output,
const int* channel_mapping,
const int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int channels_out,
T* grad_input,
const T* rois) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
    int roi_start_w = round(offset_rois[1] * spatial_scale);
    int roi_start_h = round(offset_rois[2] * spatial_scale);
    int roi_end_w = round(offset_rois[3] * spatial_scale);
    int roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
int roi_width = std::max(roi_end_w - roi_start_w, 1);
int roi_height = std::max(roi_end_h - roi_start_h, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart + roi_start_h, 0), height);
hend = std::min(std::max(hend + roi_start_h, 0), height);
wstart = std::min(std::max(wstart + roi_start_w, 0), width);
wend = std::min(std::max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
for (int c_out = 0; c_out < channels_out; ++c_out) {
int index =
((n * channels_out + c_out) * pooled_height + ph) * pooled_width +
pw;
int c_in = channel_mapping[index];
T* grad_input_offset =
grad_input + (roi_batch_ind * channels + c_in) * height * width;
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val =
is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int grad_input_index = h * width + w;
add(grad_input_offset + grad_input_index, diff_val);
}
}
}
}
}
}
}
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
// Check if input tensors are CPU tensors
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "PSROIPool_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});
int num_rois = rois.size(0);
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
AT_ASSERTM(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros(
{num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping =
at::zeros(output.sizes(), input.options().dtype(at::kInt));
auto output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "PSROIPool_forward", [&] {
PSROIPoolForward<scalar_t>(
input.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
rois.contiguous().data<scalar_t>(),
channels_out,
num_rois,
output.data<scalar_t>(),
channel_mapping.data<int>());
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor PSROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CPU tensors
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
AT_ASSERTM(
channel_mapping.device().is_cpu(),
"channel_mapping must be a CPU tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "PSROIPool_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});
auto num_rois = rois.size(0);
auto grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
// handle possibly empty gradients
if (grad.numel() == 0) {
return grad_input;
}
int channels_out = channels / (pooled_height * pooled_width);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "PSROIPool_backward", [&] {
PSROIPoolBackward<scalar_t>(
grad.contiguous().data<scalar_t>(),
channel_mapping.data<int>(),
num_rois,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
channels_out,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>());
});
return grad_input;
}
@@ -40,6 +40,46 @@ at::Tensor ROIAlign_backward_cpu(
    const int width,
    const int sampling_ratio);
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width);
at::Tensor PSROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width);
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
at::Tensor PSROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width);
at::Tensor nms_cpu(
    const at::Tensor& dets,
    const at::Tensor& scores,
...
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "cuda_helpers.h"
template <typename T>
__device__ T bilinear_interpolate(
const T* input,
const int height,
const int width,
T y,
T x,
const int index /* index for debug only*/) {
  // deal with cases where the sampling point falls outside the feature map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
return 0;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
__global__ void PSROIAlignForwardCUDA(
const int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const T* rois,
const int channels_out,
T* output,
int* channel_mapping) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c_out, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c_out = (index / pooled_width / pooled_height) % channels_out;
int n = index / pooled_width / pooled_height / channels_out;
// (n, c_in, ph, pw) is the associated element in the input
int c_in = (c_out * pooled_height + ph) * pooled_width + pw;
// [start, end) interval for spatial sampling
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
    // Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
    // Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
    // We use roi_bin_grid to sample the grid and mimic the integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
const T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
out_sum += val;
}
}
out_sum /= count;
output[index] = out_sum;
channel_mapping[index] = c_in;
}
}
template <typename T>
__device__ void bilinear_interpolate_gradient(
const int height,
const int width,
T y,
T x,
T& w1,
T& w2,
T& w3,
T& w4,
int& x_low,
int& x_high,
int& y_low,
int& y_high,
const int index /* index for debug only*/) {
  // deal with cases where the sampling point falls outside the feature map
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
return;
}
template <typename T>
__global__ void PSROIAlignBackwardCUDA(
const int nthreads,
const T* grad_output,
const int* channel_mapping,
const int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int channels_out,
T* grad_input,
const T* rois) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, *, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int n = index / pooled_width / pooled_height / channels_out;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
    // Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
// ROI extents stay continuous here (no clamping to 1x1), matching the forward pass
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
int c_in = channel_mapping[index];
T* grad_input_offset =
grad_input + (roi_batch_ind * channels + c_in) * height * width;
// Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
const T grad_output_this_bin = grad_output[index];
// We use a roi_bin_grid of sample points per bin to mimic an integral over the bin
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(grad_input_offset + y_low * width + x_low, g1);
atomicAdd(grad_input_offset + y_low * width + x_high, g2);
atomicAdd(grad_input_offset + y_high * width + x_low, g3);
atomicAdd(grad_input_offset + y_high * width + x_high, g4);
} // if
} // ix
} // iy
}
}
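The backward kernel scatters grad_output * w_k / count into the four corner pixels of every sample point, using atomicAdd because threads from different bins may touch the same input pixel. A single-threaded Python sketch of one sample's scatter, assuming the bilinear_interpolate_gradient sketch above (plain += stands in for atomicAdd):

def scatter_gradient(grad_plane, height, width, y, x, grad_output_this_bin, count):
    # grad_plane: 2D list for one (batch, channel) slice of grad_input.
    (w1, w2, w3, w4), (y_low, y_high), (x_low, x_high) = \
        bilinear_interpolate_gradient(height, width, y, x)
    if x_low >= 0 and x_high >= 0 and y_low >= 0 and y_high >= 0:
        grad_plane[y_low][x_low] += grad_output_this_bin * w1 / count
        grad_plane[y_low][x_high] += grad_output_this_bin * w2 / count
        grad_plane[y_high][x_low] += grad_output_this_bin * w3 / count
        grad_plane[y_high][x_high] += grad_output_this_bin * w4 / count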
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
// Check if input tensors are CUDA tensors
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "PSROIAlign_forward_cuda";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
at::cuda::CUDAGuard device_guard(input.device());
auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
AT_ASSERTM(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros(
{num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping =
at::zeros(output.sizes(), input.options().dtype(at::kInt));
auto output_size = output.numel();
if (output_size == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, channel_mapping);
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
dim3 block(512);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "PSROIAlign_forward", [&] {
PSROIAlignForwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
rois.contiguous().data<scalar_t>(),
channels_out,
output.data<scalar_t>(),
channel_mapping.data<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, channel_mapping);
}
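The launch configuration is one thread per output element, 512 threads per block, and at most 4096 blocks; CUDA_1D_KERNEL_LOOP then grid-strides over any remainder. The same arithmetic in Python, with hypothetical sizes:

num_rois, channels_out, pooled_h, pooled_w = 128, 10, 7, 7
output_size = num_rois * channels_out * pooled_h * pooled_w   # 62720 elements
threads = 512
blocks = min((output_size + threads - 1) // threads, 4096)    # -> 123 blocks
print(blocks)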
at::Tensor PSROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CUDA tensors
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
AT_ASSERTM(
channel_mapping.type().is_cuda(),
"channel_mapping must be a CUDA tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "PSROIAlign_backward_cuda";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});
at::cuda::CUDAGuard device_guard(grad.device());
auto num_rois = rois.size(0);
auto grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
dim3 block(512);
// handle possibly empty gradients
if (grad.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
int channels_out = channels / (pooled_height * pooled_width);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "PSROIAlign_backward", [&] {
PSROIAlignBackwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
channel_mapping.data<int>(),
num_rois,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
channels_out,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "cuda_helpers.h"
template <typename T>
__global__ void PSROIPoolForward(
const int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const T* rois,
const int channels_out,
T* output,
int* channel_mapping) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c_out, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c_out = (index / pooled_width / pooled_height) % channels_out;
int n = index / pooled_width / pooled_height / channels_out;
// (n, c_in, ph, pw) is the associated element in the input
int c_in = (c_out * pooled_height + ph) * pooled_width + pw;
// [start, end) interval for spatial sampling
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
int roi_start_w = roundf(offset_rois[1] * spatial_scale);
int roi_start_h = roundf(offset_rois[2] * spatial_scale);
int roi_end_w = roundf(offset_rois[3] * spatial_scale);
int roi_end_h = roundf(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
int roi_width = max(roi_end_w - roi_start_w, 1);
int roi_height = max(roi_end_h - roi_start_h, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
// (hend/wend are exclusive bounds, so clip to height/width as in the backward kernel)
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
const T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_index = h * width + w;
out_sum += offset_input[input_index];
}
}
T bin_area = (hend - hstart) * (wend - wstart);
output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
channel_mapping[index] = c_in;
}
}
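Unlike the align kernel, PSROIPoolForward rounds the ROI to integer coordinates and averages raw pixels over each integer bin, with no interpolation. A Python sketch of one output bin under those rules (plane stands for the single position-sensitive input channel c_in selected for this bin; all names here are ours):

import math

def pool_one_bin(plane, height, width, ph, pw,
                 roi_start_h, roi_start_w, bin_size_h, bin_size_w):
    # Integer bin bounds: floor/ceil of the fractional bin, shifted by the
    # (rounded) ROI origin and clipped to the feature map.
    hstart = min(max(int(math.floor(ph * bin_size_h)) + roi_start_h, 0), height)
    hend = min(max(int(math.ceil((ph + 1) * bin_size_h)) + roi_start_h, 0), height)
    wstart = min(max(int(math.floor(pw * bin_size_w)) + roi_start_w, 0), width)
    wend = min(max(int(math.ceil((pw + 1) * bin_size_w)) + roi_start_w, 0), width)
    if hend <= hstart or wend <= wstart:
        return 0.0  # empty bin
    vals = [plane[h][w] for h in range(hstart, hend) for w in range(wstart, wend)]
    return sum(vals) / len(vals)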
template <typename T>
__global__ void PSROIPoolBackward(
const int nthreads,
const T* grad_output,
const int* channel_mapping,
const int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int channels_out,
T* grad_input,
const T* rois) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, *, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int n = index / pooled_width / pooled_height / channels_out;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
int roi_start_w = roundf(offset_rois[1] * spatial_scale);
int roi_start_h = roundf(offset_rois[2] * spatial_scale);
int roi_end_w = roundf(offset_rois[3] * spatial_scale);
int roi_end_h = roundf(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
int roi_width = max(roi_end_w - roi_start_w, 1);
int roi_height = max(roi_end_h - roi_start_h, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
int c_in = channel_mapping[index];
T* grad_input_offset =
grad_input + (roi_batch_ind * channels + c_in) * height * width;
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int grad_input_index = h * width + w;
atomicAdd(grad_input_offset + grad_input_index, diff_val);
}
}
}
}
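The pool backward is the exact adjoint of that average: every input pixel inside the bin receives an equal share of the bin's gradient. A single-threaded Python sketch (plain += stands in for atomicAdd):

def spread_bin_gradient(grad_plane, hstart, hend, wstart, wend, grad_output_this_bin):
    bin_area = (hend - hstart) * (wend - wstart)
    if bin_area <= 0:
        return  # empty bin: no gradient
    diff_val = grad_output_this_bin / bin_area
    for h in range(hstart, hend):
        for w in range(wstart, wend):
            grad_plane[h][w] += diff_val  # atomicAdd in the CUDA kernel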
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
// Check if input tensors are CUDA tensors
AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "PSROIPool_forward_cuda";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
at::cuda::CUDAGuard device_guard(input.device());
auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
AT_ASSERTM(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros(
{num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping =
at::zeros(output.sizes(), input.options().dtype(at::kInt));
auto output_size = output.numel();
if (output_size == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, channel_mapping);
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
dim3 block(512);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "PSROIPool_forward", [&] {
PSROIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
rois.contiguous().data<scalar_t>(),
channels_out,
output.data<scalar_t>(),
channel_mapping.data<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, channel_mapping);
}
at::Tensor PSROIPool_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CUDA tensors
AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
AT_ASSERTM(
channel_mapping.type().is_cuda(),
"channel_mapping must be a CUDA tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "PSROIPool_backward_cuda";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});
at::cuda::CUDAGuard device_guard(grad.device());
auto num_rois = rois.size(0);
auto grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
dim3 block(512);
// handle possibly empty gradients
if (grad.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
int channels_out = channels / (pooled_height * pooled_width);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "PSROIPool_backward", [&] {
PSROIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
channel_mapping.data<int>(),
num_rois,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
channels_out,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
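Both ops share the channel contract enforced by the AT_ASSERTM checks above: the input must supply one channel per (output channel, bin) pair. A concrete example of the bookkeeping, with hypothetical sizes:

pooled_height, pooled_width = 7, 7
channels_out = 10                                        # e.g. 10 score maps
channels = channels_out * pooled_height * pooled_width   # 490 input channels
assert channels % (pooled_height * pooled_width) == 0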
@@ -41,6 +41,46 @@ at::Tensor ROIPool_backward_cuda(
const int height,
const int width);
std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width);
at::Tensor PSROIPool_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width);
std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
at::Tensor PSROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& mapping_channel,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const int batch_size,
const int channels,
const int height,
const int width);
at::Tensor nms_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
...
@@ -5,6 +5,8 @@
#include <cuda.h>
#endif
#include "PSROIAlign.h"
#include "PSROIPool.h"
#include "ROIAlign.h" #include "ROIAlign.h"
#include "ROIPool.h" #include "ROIPool.h"
#include "nms.h" #include "nms.h"
...@@ -41,4 +43,6 @@ static auto registry = ...@@ -41,4 +43,6 @@ static auto registry =
.op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor", .op("torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor",
&roi_align) &roi_align)
.op("torchvision::roi_pool", &roi_pool) .op("torchvision::roi_pool", &roi_pool)
.op("torchvision::ps_roi_align", &ps_roi_align)
.op("torchvision::ps_roi_pool", &ps_roi_pool)
.op("torchvision::_cuda_version", &_cuda_version); .op("torchvision::_cuda_version", &_cuda_version);
from .boxes import nms, box_iou from .boxes import nms, box_iou
from .roi_align import roi_align, RoIAlign from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool from .roi_pool import roi_pool, RoIPool
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
from .poolers import MultiScaleRoIAlign
from .feature_pyramid_network import FeaturePyramidNetwork
@@ -11,5 +13,6 @@ _register_custom_op()
__all__ = [
'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool',
'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', 'PSRoIPool',
'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
]
import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from ._utils import convert_boxes_to_roi_format
def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
# type: (Tensor, Tensor, int, float, int) -> Tensor
"""
Performs the Position-Sensitive Region of Interest (RoI) Align operator
introduced in Light-Head R-CNN.
Arguments:
input (Tensor[N, C, H, W]): input tensor
boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
format where the regions will be taken from. If a single Tensor is passed,
then the first column should contain the batch index. If a list of Tensors
is passed, then each Tensor will correspond to the boxes for an element i
in a batch
output_size (int or Tuple[int, int]): the size of the output after the
pooling is performed, as (height, width)
spatial_scale (float): a scaling factor that maps the input coordinates to
the box coordinates. Default: 1.0
sampling_ratio (int): number of sampling points in the interpolation grid
used to compute the output value of each pooled output bin. If > 0,
then exactly sampling_ratio x sampling_ratio grid points are used.
If <= 0, an adaptive number of grid points is used (computed as
ceil(roi_width / pooled_w), and likewise for height). Default: -1
Returns:
output (Tensor[K, C / (output_size[0] * output_size[1]), output_size[0], output_size[1]])
"""
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
output, _ = torch.ops.torchvision.ps_roi_align(input, rois, spatial_scale,
output_size[0],
output_size[1],
sampling_ratio)
return output
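A usage sketch for the functional form (all shapes here are hypothetical): with a 2x2 grid, 16 input channels yield 4 output channels.

import torch
from torchvision.ops import ps_roi_align

x = torch.rand(1, 4 * 2 * 2, 8, 8)           # 16 channels -> 4 output channels
rois = torch.tensor([[0., 1., 1., 6., 6.]])  # (batch_idx, x1, y1, x2, y2)
out = ps_roi_align(x, rois, output_size=(2, 2), spatial_scale=1.0, sampling_ratio=2)
print(out.shape)                             # torch.Size([1, 4, 2, 2])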
class PSRoIAlign(nn.Module):
"""
See ps_roi_align
"""
def __init__(self, output_size, spatial_scale, sampling_ratio):
super(PSRoIAlign, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
self.sampling_ratio = sampling_ratio
def forward(self, input, rois):
return ps_roi_align(input, rois, self.output_size, self.spatial_scale,
self.sampling_ratio)
def __repr__(self):
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
tmpstr += ', sampling_ratio=' + str(self.sampling_ratio)
tmpstr += ')'
return tmpstr
import torch
from torch import nn, Tensor
from torch.nn.modules.utils import _pair
from torch.jit.annotations import List
from ._utils import convert_boxes_to_roi_format
def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
# type: (Tensor, Tensor, int, float) -> Tensor
"""
Performs the Position-Sensitive Region of Interest (RoI) Pool operator
described in R-FCN.
Arguments:
input (Tensor[N, C, H, W]): input tensor
boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
format where the regions will be taken from. If a single Tensor is passed,
then the first column should contain the batch index. If a list of Tensors
is passed, then each Tensor will correspond to the boxes for an element i
in a batch
output_size (int or Tuple[int, int]): the size of the output after the
pooling is performed, as (height, width)
spatial_scale (float): a scaling factor that maps the input coordinates to
the box coordinates. Default: 1.0
Returns:
output (Tensor[K, C / (output_size[0] * output_size[1]), output_size[0], output_size[1]])
"""
rois = boxes
output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale,
output_size[0],
output_size[1])
return output
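A usage sketch for the List[Tensor] box format, where per-image boxes carry no batch-index column and convert_boxes_to_roi_format adds it (shapes are hypothetical):

import torch
from torchvision.ops import ps_roi_pool

x = torch.rand(2, 9, 8, 8)                  # 9 = 1 output channel * 3 * 3
boxes = [torch.tensor([[0., 0., 4., 4.]]),  # boxes for batch element 0
         torch.tensor([[1., 1., 7., 7.]])]  # boxes for batch element 1
out = ps_roi_pool(x, boxes, output_size=3)
print(out.shape)                            # torch.Size([2, 1, 3, 3])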
class PSRoIPool(nn.Module):
"""
See ps_roi_pool
"""
def __init__(self, output_size, spatial_scale):
super(PSRoIPool, self).__init__()
self.output_size = output_size
self.spatial_scale = spatial_scale
def forward(self, input, rois):
return ps_roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self):
tmpstr = self.__class__.__name__ + '('
tmpstr += 'output_size=' + str(self.output_size)
tmpstr += ', spatial_scale=' + str(self.spatial_scale)
tmpstr += ')'
return tmpstr
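A module-style usage sketch, e.g. as the scoring stage of an R-FCN-like head; the layer sizes and feature stride are hypothetical:

import torch
from torchvision.ops import PSRoIPool

pool = PSRoIPool(output_size=(3, 3), spatial_scale=1.0 / 16)  # features at stride 16
features = torch.rand(1, 9 * 21, 32, 32)   # 21 classes * 3 * 3 position-sensitive maps
rois = torch.tensor([[0., 16., 16., 256., 256.]])
scores = pool(features, rois)              # -> torch.Size([1, 21, 3, 3])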