Unverified Commit dc3ac290 authored by Francisco Massa, committed by GitHub

Add C++ ops to torchvision (#826)

* Initial layout for layers with cpp extensions

* Move files around

* Fix import after move

* Add support for multiple types to ROIAlign

* Different organization

CUDA extensions work now

* Cleanups

* Reduce memory requirements for backwards

* Replace runtime_error by AT_ERROR

* Add nms test

* Add support for compilation using CPP extensions

* Change folder structure

* Add ROIPool cuda

* Cleanups

* Add roi_pool.py

* Fix lint

* Add initial structures folder for bounding boxes

* Assertion macros compatible with pytorch master (#540)

* Support for ROI Pooling (#592)

* ROI Pooling with tests. Fix for cuda context in ROI Align.

* renamed bottom and top to follow torch conventions

* remove .type().tensor() calls in favor of the new approach to tensor initialization (#626)

* Consistent naming for rois variable (#627)

* remove .type().tensor() calls in favor of the new approach to tensor initialization

* Consistent naming for rois variable in ROIPool

* ROIPool: Support for all datatypes (#632)

* Use of torch7 naming scheme for ROIAlign forward and backward

* use common cuda helpers in ROIAlign

* use .options() in favor of .type() where applicable

* Added tests for forward pass of ROIAlign, as well as more consistent naming scheme for CPU vs CUDA

* working ROIAlign cuda backwards pass

* working ROIAlign backwards pass for CPU

* added relevant headers for ROIAlign backwards

* tests for ROIAlign layer

* replace .type() with .options() for tensor initialization in ROIAlign layers

* support for Half types in ROIAlign

* gradcheck tests for ROIAlign

* updated ROIPool on CPU to work with all datatypes

* updated and cleaned tests for ROI Pooling

* Fix rebase problem

* Remove structures folder

* Improve cleanup and bugfix in test_layers

* Update C++ headers

* Add CUDAGuard to cu files

* Add more checks to layers

* Add CUDA NMS and tests

* Add multi-type support for NMS CUDA

* Avoid using THCudaMalloc

* Add clang-format and reformat c++ code

* Remove THC includes

* Rename layers to ops

* Add documentation and rename functions

* Improve the documentation a bit

* Fix some lint errors

* Fix remaining lint issues

* Area computation doesn't add +1 in NMS

* Update CI to use PyTorch nightly

* Make NMS return indices sorted according to the score

* Address reviewer comments

* Lint fixes

* Improve doc for roi_align and roi_pool

* move to xenial

* Fix bug pointed out by @lopuhin

* Fix RoIPool reference implementation in Python 2

Also fixes a bug in the clip_boxes_to_image -- this function needs a test!

* Remove change in .travis
parent 0564df43
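
For orientation, here is a minimal usage sketch of the new torchvision.ops interface, mirroring the calls exercised by the tests in this diff (illustrative only, not part of the patch):

import torch
from torchvision import ops

x = torch.rand(1, 3, 10, 10)
# RoIs are rows of (batch_index, x1, y1, x2, y2)
rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)

pooled = ops.RoIPool((5, 5), 1.0)(x, rois)                                    # max pooling per bin
aligned = ops.RoIAlign((5, 5), spatial_scale=1.0, sampling_ratio=2)(x, rois)  # bilinear sampling

boxes = torch.rand(100, 4) * 100
boxes[:, 2:] += boxes[:, :2]        # ensure x2 >= x1 and y2 >= y1
scores = torch.rand(100)
keep = ops.nms(boxes, scores, 0.5)  # indices of kept boxes, sorted by descending score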
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
#CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
- Regex: '^<.*\.h(pp)?>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 2000000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
@@ -3,10 +3,19 @@ dist/
torchvision.egg-info/
torchvision/version.py
*/**/__pycache__
*/__pycache__
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/*~
*~
docs/build
.coverage
htmlcov
.*.swp
*.so*
*.dylib*
*/*.so*
*/*.dylib*
*.swp
*.swo
@@ -6,6 +6,12 @@ import sys
from setuptools import setup, find_packages
from pkg_resources import get_distribution, DistributionNotFound
import subprocess
import distutils.command.clean
import glob
import shutil
import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
def read(*names, **kwargs):
@@ -69,6 +75,55 @@ pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow'
requirements.append(pillow_req + pillow_ver)
def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc')
main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
sources = main_file + source_cpu
extension = CppExtension
define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None:
extension = CUDAExtension
sources += source_cuda
define_macros += [('WITH_CUDA', None)]
sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]
ext_modules = [
extension(
'torchvision._C',
sources,
include_dirs=include_dirs,
define_macros=define_macros,
)
]
return ext_modules
class clean(distutils.command.clean.clean):
def run(self):
with open('.gitignore', 'r') as f:
ignores = f.read()
for wildcard in filter(None, ignores.split('\n')):
for filename in glob.glob(wildcard):
try:
os.remove(filename)
except OSError:
shutil.rmtree(filename, ignore_errors=True)
# It's an old-style class in Python 2.7...
distutils.command.clean.clean.run(self)
setup(
# Metadata
name=package_name,
@@ -88,4 +143,6 @@ setup(
extras_require={
"scipy": ["scipy"],
},
ext_modules=get_extensions(),
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension, 'clean': clean}
)
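
A quick post-build sanity check (a hedged sketch; the extension module name torchvision._C comes from get_extensions() above):

import torch        # torch must be imported first so its shared libraries are loaded
import torchvision

# get_extensions() registers the compiled ops as the 'torchvision._C' module;
# importing it confirms that the C++ (and, if built, CUDA) extension loads.
from torchvision import _C  # noqa: F401

# The CUDA kernels were compiled only if torch.cuda.is_available() and
# CUDA_HOME were both set when setup.py ran.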
import numpy as np
import torch
from torch.autograd import gradcheck
from torchvision import ops
from itertools import product
import unittest
class RoIPoolTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dtype = torch.float64
def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1,
device=torch.device('cpu'), dtype=torch.float64):
c = x.size(1)
y = torch.zeros(rois.size(0), c, pool_h, pool_w, dtype=dtype, device=device)
rois = torch.round(rois * spatial_scale)
for n in range(0, y.size(0)):
for r, roi in enumerate(rois):
if roi[0] == n:
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
roi_x = x[roi[0].long(), :, start_h:end_h, start_w:end_w]
bin_h, bin_w = roi_x.size(-2) / float(pool_h), roi_x.size(-1) / float(pool_w)
for j in range(0, pool_h):
cj = slice(int(np.floor(j * bin_h)), int(np.ceil((j + 1) * bin_h)))
for i in range(0, pool_w):
ci = slice(int(np.floor(i * bin_w)), int(np.ceil((i + 1) * bin_w)))
t = roi_x[:, cj, ci].reshape(c, -1)
if t.numel() > 0:
y[r, :, j, i] = torch.max(t, 1)[0]
return y
def test_roi_pool_basic_cpu(self):
device = torch.device('cpu')
x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device)
rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy)
dtype=self.dtype, device=device)
pool_h, pool_w = (5, 5)
roi_pool = ops.RoIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)
gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU'
# non-contiguous
y = roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU'
def test_roi_pool_cpu(self):
device = torch.device('cpu')
x = torch.rand(2, 1, 10, 10, dtype=self.dtype, device=device)
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=self.dtype, device=device)
pool_h, pool_w = (5, 5)
roi_pool = ops.RoIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)
gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU for batch > 1'
# non-contiguous
y = roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU for batch > 1'
def test_roi_pool_cpu_empty_rois(self):
device = torch.device('cpu')
x = torch.tensor(
[[[[0.1767, 1.2851, 4.2325, 4.8645, 7.1496]],
[[2.5916, 4.3361, 3.8143, 6.1329, 2.0230]],
[[1.4492, 3.3384, 4.0816, 6.3116, 5.1068]]]],
dtype=self.dtype, device=device)
rois = torch.tensor(
[[0., 1., 0., 4., 0.],
[0., 2., 0., 3., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 2., 0., 2., 0.]],
dtype=self.dtype, device=device)
pool_h, pool_w = (1, 2)
roi_pool = ops.RoIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)
gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU empty rois'
# non-contiguous
y = roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU for empty rois non-contiguous'
def test_roi_pool_gradient_cpu(self):
device = torch.device('cpu')
x = torch.ones(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 4, 9],
[0, 0, 0, 4, 4]],
dtype=self.dtype, device=device)
layer = ops.RoIPool((5, 5), 1).to(dtype=self.dtype, device=device)
y = layer(x, rois)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]],
device=device, dtype=self.dtype)
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for roi_pool'
def test_roi_pool_gradcheck_cpu(self):
device = torch.device('cpu')
x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 5, 9],
[0, 5, 5, 9, 9]], dtype=self.dtype, device=device)
m = ops.RoIPool((5, 5), 1).to(dtype=self.dtype, device=device)
def func(input):
return m(input, rois)
assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CPU'
assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for roi_pool CPU'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_basic_cuda(self):
device = torch.device('cuda')
x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device)
rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy)
dtype=self.dtype, device=device)
pool_h, pool_w = (5, 5)
roi_pool = ops.RoIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)
gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'RoIPool layer incorrect'
y = roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'RoIPool layer incorrect'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_cuda(self):
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
x = torch.rand(2, 1, 10, 10, dtype=self.dtype, device=device)
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=self.dtype, device=device)
pool_h, pool_w = (5, 5)
roi_pool = ops.RoIPool((pool_h, pool_w), 1)
y = roi_pool(x, rois)
gt_y = self.slow_roi_pooling(x, rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'RoIPool layer incorrect'
y = roi_pool(x.permute(0, 1, 3, 2), rois)
gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
assert torch.allclose(gt_y.cuda(), y), 'RoIPool layer incorrect'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_gradient_cuda(self):
device = torch.device('cuda')
layer = ops.RoIPool((5, 5), 1).to(dtype=self.dtype, device=device)
x = torch.ones(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 4, 9],
[0, 0, 0, 4, 4]],
dtype=self.dtype, device=device)
y = layer(x, rois)
s = y.sum()
s.backward()
gt_grad = torch.tensor([[[[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[2., 1., 2., 1., 2., 0., 1., 0., 1., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]]]],
device=device, dtype=self.dtype)
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for roi_pool'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_pool_gradcheck_cuda(self):
device = torch.device('cuda')
x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 5, 9],
[0, 5, 5, 9, 9]], dtype=self.dtype, device=device)
m = ops.RoIPool((5, 5), 1).to(dtype=self.dtype, device=device)
def func(input):
return m(input, rois)
assert gradcheck(func, (x,)), 'gradcheck failed for roi_pool CUDA'
assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for roi_pool CUDA'
class RoIAlignTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
torch.manual_seed(123)
cls.dtype = torch.float32
cls.x = torch.rand(1, 1, 10, 10, dtype=cls.dtype)
cls.single_roi = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy)
dtype=cls.dtype)
cls.rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9]],
dtype=cls.dtype)
cls.gt_y_single = torch.tensor(
[[[[0.41617328, 0.5040753, 0.25266218, 0.4296828, 0.29928464],
[0.5210769, 0.57222337, 0.2524979, 0.32063985, 0.32635176],
[0.73108256, 0.6114335, 0.62033176, 0.8188273, 0.5562218],
[0.83115816, 0.70803946, 0.7084047, 0.74928707, 0.7769296],
[0.54266506, 0.45964524, 0.5780159, 0.80522037, 0.7321807]]]], dtype=cls.dtype)
cls.gt_y_multiple = torch.tensor(
[[[[0.49311584, 0.35972416, 0.40843594, 0.3638034, 0.49751836],
[0.70881474, 0.75481665, 0.5826779, 0.34767765, 0.46865487],
[0.4740328, 0.69306874, 0.3617804, 0.47145438, 0.66130304],
[0.6861706, 0.17634538, 0.47194335, 0.42473823, 0.37930614],
[0.62666404, 0.49973848, 0.37911576, 0.5842756, 0.7176864]]],
[[[0.67499936, 0.6607055, 0.42656037, 0.46134934, 0.42144877],
[0.7471722, 0.7235433, 0.14512213, 0.13031253, 0.289369],
[0.8443615, 0.6659734, 0.23614208, 0.14719573, 0.4268827],
[0.69429564, 0.5621515, 0.5019923, 0.40678093, 0.34556213],
[0.51315194, 0.7177093, 0.6494485, 0.6775592, 0.43865064]]],
[[[0.24465509, 0.36108392, 0.64635646, 0.4051828, 0.33956185],
[0.49006107, 0.42982674, 0.34184104, 0.15493104, 0.49633422],
[0.54400194, 0.5265246, 0.22381854, 0.3929715, 0.6757667],
[0.32961223, 0.38482672, 0.68877804, 0.71822757, 0.711909],
[0.561259, 0.71047884, 0.84651315, 0.8541089, 0.644432]]]], dtype=cls.dtype)
cls.x_grad = torch.tensor(
[[[[0.075625, 0.15125, 0.15124999, 0.15125002, 0.15812504,
0.15812503, 0.15124999, 0.15124999, 0.15125006, 0.0756249],
[0.15125, 0.30250007, 0.3025, 0.30250007, 0.31625012,
0.31625003, 0.3025, 0.3025, 0.30250013, 0.1512498],
[0.15124999, 0.3025, 0.30249995, 0.3025, 0.31625006,
0.31625, 0.30249995, 0.30249995, 0.30250007, 0.15124978],
[0.15125002, 0.30250007, 0.3025, 0.30250007, 0.31625012,
0.3162501, 0.3025, 0.3025, 0.30250013, 0.15124981],
[0.15812504, 0.31625012, 0.31625006, 0.31625012, 0.33062524,
0.3306251, 0.31625006, 0.31625006, 0.3162502, 0.15812483],
[0.5181251, 1.0962502, 1.0362502, 1.0962503, 0.69062525, 0.6906252,
1.0962502, 1.0362502, 1.0962503, 0.5181248],
[0.93125, 1.9925, 1.8624997, 1.9925, 1.0962502, 1.0962502,
1.9925, 1.8624998, 1.9925, 0.9312496],
[0.8712501, 1.8625, 1.7425002, 1.8625001, 1.0362502, 1.0362502,
1.8625, 1.7425001, 1.8625002, 0.8712497],
[0.93125004, 1.9925, 1.8625002, 1.9925, 1.0962503, 1.0962503,
1.9925001, 1.8625001, 1.9925001, 0.93124974],
[0.43562484, 0.9312497, 0.8712497, 0.9312497, 0.5181249, 0.5181248,
0.9312496, 0.8712497, 0.93124974, 0.43562466]]]], dtype=cls.dtype)
def test_roi_align_basic_cpu(self):
device = torch.device('cpu')
x = self.x.to(device)
single_roi = self.single_roi.to(device)
gt_y_single = self.gt_y_single.to(device)
pool_h, pool_w = (5, 5)
roi_align = ops.RoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
y = roi_align(x, single_roi)
assert torch.allclose(gt_y_single, y), 'RoIAlign layer incorrect for single ROI on CPU'
y = roi_align(x.transpose(2, 3).contiguous().transpose(2, 3), single_roi)
assert torch.allclose(gt_y_single, y), 'RoIAlign layer incorrect for single ROI on CPU'
def test_roi_align_cpu(self):
device = torch.device('cpu')
x = self.x.to(device)
rois = self.rois.to(device)
gt_y_multiple = self.gt_y_multiple.to(device)
pool_h, pool_w = (5, 5)
roi_align = ops.RoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
y = roi_align(x, rois)
assert torch.allclose(gt_y_multiple, y), 'RoIAlign layer incorrect for multiple ROIs on CPU'
y = roi_align(x.transpose(2, 3).contiguous().transpose(2, 3), rois)
assert torch.allclose(gt_y_multiple, y), 'RoIAlign layer incorrect for multiple ROIs on CPU'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_align_basic_cuda(self):
device = torch.device('cuda')
x = self.x.to(device)
single_roi = self.single_roi.to(device)
gt_y_single = self.gt_y_single.to(device)
pool_h, pool_w = (5, 5)
roi_align = ops.RoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
y = roi_align(x, single_roi)
assert torch.allclose(gt_y_single, y), 'RoIAlign layer incorrect for single ROI on CUDA'
y = roi_align(x.transpose(2, 3).contiguous().transpose(2, 3), single_roi)
assert torch.allclose(gt_y_single, y), 'RoIAlign layer incorrect for single ROI on CUDA'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_align_cuda(self):
device = torch.device('cuda')
x = self.x.to(device)
rois = self.rois.to(device)
gt_y_multiple = self.gt_y_multiple.to(device)
pool_h, pool_w = (5, 5)
roi_align = ops.RoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
y = roi_align(x, rois)
assert torch.allclose(gt_y_multiple, y), 'RoIAlign layer incorrect for multiple ROIs on CUDA'
y = roi_align(x.transpose(2, 3).contiguous().transpose(2, 3), rois)
assert torch.allclose(gt_y_multiple, y), 'RoIAlign layer incorrect for multiple ROIs on CUDA'
def test_roi_align_gradient_cpu(self):
"""
Compute gradients for RoIAlign with multiple bounding boxes on CPU
"""
device = torch.device('cpu')
pool_h, pool_w = (5, 5)
roi_align = ops.RoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
x = self.x.to(device).clone()
rois = self.rois.to(device)
gt_grad = self.x_grad.to(device)
x.requires_grad = True
y = roi_align(x, rois)
s = y.sum()
s.backward()
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for RoIAlign CPU'
def test_roi_align_gradcheck_cpu(self):
dtype = torch.float64
device = torch.device('cpu')
m = ops.RoIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device)
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
rois = self.rois.to(device=device, dtype=dtype)
def func(input):
return m(input, rois)
assert gradcheck(func, (x,)), 'gradcheck failed for RoIAlign CPU'
assert gradcheck(func, (x.transpose(2, 3),)), 'gradcheck failed for RoIAlign CPU'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_align_gradient_cuda(self):
"""
Compute gradients for RoIAlign with multiple bounding boxes on the GPU
"""
device = torch.device('cuda')
pool_h, pool_w = (5, 5)
roi_align = ops.RoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2).to(device=device)
x = self.x.to(device).clone()
rois = self.rois.to(device)
gt_grad = self.x_grad.to(device)
x.requires_grad = True
y = roi_align(x, rois)
s = y.sum()
s.backward()
assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for RoIAlign CUDA'
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_align_gradcheck_cuda(self):
dtype = torch.float64
device = torch.device('cuda')
m = ops.RoIAlign((5, 5), 0.5, 1).to(dtype=dtype, device=device)
x = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=True)
rois = self.rois.to(device=device, dtype=dtype)
def func(input):
return m(input, rois)
assert gradcheck(func, (x,)), 'gradcheck failed for RoIAlign CUDA'
assert gradcheck(func, (x.transpose(2, 3),)), 'gradcheck failed for RoIAlign CUDA'
class NMSTester(unittest.TestCase):
def reference_nms(self, boxes, scores, iou_threshold):
"""
Args:
boxes (N, 4): boxes in corner (xyxy) form.
scores (N): score for each box.
iou_threshold: intersection-over-union threshold above which boxes are suppressed.
Returns:
picked: a list of indices of the kept boxes
"""
picked = []
_, indexes = scores.sort(descending=True)
while len(indexes) > 0:
current = indexes[0]
picked.append(current.item())
if len(indexes) == 1:
break
current_box = boxes[current, :]
indexes = indexes[1:]
rest_boxes = boxes[indexes, :]
iou = ops.box_iou(rest_boxes, current_box.unsqueeze(0)).squeeze(1)
indexes = indexes[iou <= iou_threshold]
return torch.as_tensor(picked)
def _create_tensors(self, N):
boxes = torch.rand(N, 4) * 100
boxes[:, 2:] += torch.rand(N, 2) * 100
scores = torch.rand(N)
return boxes, scores
def test_nms(self):
boxes, scores = self._create_tensors(1000)
err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}'
for iou in [0.2, 0.5, 0.8]:
keep_ref = self.reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou)
assert torch.allclose(keep, keep_ref), err_msg.format(iou)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda(self):
boxes, scores = self._create_tensors(1000)
err_msg = 'NMS incompatible between CPU and CUDA for IoU={}'
for iou in [0.2, 0.5, 0.8]:
r_cpu = ops.nms(boxes, scores, iou)
r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
assert torch.allclose(r_cpu, r_cuda.cpu()), err_msg.format(iou)
if __name__ == '__main__':
unittest.main()
from torchvision import models
from torchvision import datasets
from torchvision import ops
from torchvision import transforms
from torchvision import utils
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
// Interface for Python
at::Tensor ROIAlign_forward(
const at::Tensor& input, // Input feature map.
const at::Tensor& rois, // List of ROIs to pool over.
const float spatial_scale, // The scale of the image features. ROIs will be
// scaled to this.
const int pooled_height, // The height of the pooled feature map.
const int pooled_width, // The width of the pooled feature map.
const int sampling_ratio) // The number of points to sample in each bin
// along each axis.
{
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return ROIAlign_forward_cuda(
input,
rois,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIAlign_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}
at::Tensor ROIAlign_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio) {
if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
return ROIAlign_backward_cuda(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIAlign_backward_cpu(
grad,
rois,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width,
sampling_ratio);
}
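
To show how this dispatcher is consumed from Python, here is a hedged autograd sketch. The binding names _C.roi_align_forward and _C.roi_align_backward are assumptions made for illustration (they do not appear in this hunk); the argument lists mirror ROIAlign_forward and ROIAlign_backward above.

import torch
from torch.autograd import Function
from torchvision import _C  # compiled extension declared in setup.py


class _RoIAlignFunction(Function):
    @staticmethod
    def forward(ctx, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
        ctx.save_for_backward(rois)
        ctx.params = (spatial_scale, pooled_height, pooled_width, sampling_ratio)
        ctx.input_shape = input.shape
        # hypothetical binding of ROIAlign_forward(input, rois, scale, ph, pw, sampling_ratio)
        return _C.roi_align_forward(input, rois, spatial_scale,
                                    pooled_height, pooled_width, sampling_ratio)

    @staticmethod
    def backward(ctx, grad_output):
        rois, = ctx.saved_tensors
        spatial_scale, pooled_height, pooled_width, sampling_ratio = ctx.params
        bs, ch, h, w = ctx.input_shape
        # hypothetical binding of ROIAlign_backward(grad, rois, scale, ph, pw, bs, ch, h, w, sampling_ratio)
        grad_input = _C.roi_align_backward(grad_output, rois, spatial_scale,
                                           pooled_height, pooled_width,
                                           bs, ch, h, w, sampling_ratio)
        return grad_input, None, None, None, None, None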
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
std::tuple<at::Tensor, at::Tensor> ROIPool_forward(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
if (input.type().is_cuda()) {
#ifdef WITH_CUDA
return ROIPool_forward_cuda(
input, rois, spatial_scale, pooled_height, pooled_width);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIPool_forward_cpu(
input, rois, spatial_scale, pooled_height, pooled_width);
}
at::Tensor ROIPool_backward(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
if (grad.type().is_cuda()) {
#ifdef WITH_CUDA
return ROIPool_backward_cuda(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIPool_backward_cpu(
grad,
rois,
argmax,
spatial_scale,
pooled_height,
pooled_width,
batch_size,
channels,
height,
width);
}
\ No newline at end of file
#include <ATen/TensorUtils.h>
#include "cpu/vision.h"
// implementation taken from Caffe2
template <typename T>
struct PreCalc {
int pos1;
int pos2;
int pos3;
int pos4;
T w1;
T w2;
T w3;
T w4;
};
template <typename T>
void pre_calc_for_bilinear_interpolate(
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int iy_upper,
const int ix_upper,
T roi_start_h,
T roi_start_w,
T bin_size_h,
T bin_size_w,
int roi_bin_grid_h,
int roi_bin_grid_w,
std::vector<PreCalc<T>>& pre_calc) {
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
for (int iy = 0; iy < iy_upper; iy++) {
const T yy = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < ix_upper; ix++) {
const T xx = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T x = xx;
T y = yy;
// deal with: inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
PreCalc<T> pc;
pc.pos1 = 0;
pc.pos2 = 0;
pc.pos3 = 0;
pc.pos4 = 0;
pc.w1 = 0;
pc.w2 = 0;
pc.w3 = 0;
pc.w4 = 0;
pre_calc[pre_calc_index] = pc;
pre_calc_index += 1;
continue;
}
if (y <= 0) {
y = 0;
}
if (x <= 0) {
x = 0;
}
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
// save weights and indices
PreCalc<T> pc;
pc.pos1 = y_low * width + x_low;
pc.pos2 = y_low * width + x_high;
pc.pos3 = y_high * width + x_low;
pc.pos4 = y_high * width + x_high;
pc.w1 = w1;
pc.w2 = w2;
pc.w3 = w3;
pc.w4 = w4;
pre_calc[pre_calc_index] = pc;
pre_calc_index += 1;
}
}
}
}
}
template <typename T>
void ROIAlignForward(
const int nthreads,
const T* input,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const T* rois,
T* output) {
int n_rois = nthreads / channels / pooled_width / pooled_height;
// (n, c, ph, pw) is an element in the pooled output
// can be parallelized using omp
// #pragma omp parallel for num_threads(32)
for (int n = 0; n < n_rois; n++) {
int index_n = n * channels * pooled_width * pooled_height;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
// T roi_start_w = round(offset_rois[0] * spatial_scale);
// T roi_start_h = round(offset_rois[1] * spatial_scale);
// T roi_end_w = round(offset_rois[2] * spatial_scale);
// T roi_end_h = round(offset_rois[3] * spatial_scale);
// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
// we want to precalculate indices and weights shared by all channels;
// this is the key point of optimization
std::vector<PreCalc<T>> pre_calc(
roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
pre_calc_for_bilinear_interpolate(
height,
width,
pooled_height,
pooled_width,
roi_bin_grid_h,
roi_bin_grid_w,
roi_start_h,
roi_start_w,
bin_size_h,
bin_size_w,
roi_bin_grid_h,
roi_bin_grid_w,
pre_calc);
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * pooled_width * pooled_height;
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
int pre_calc_index = 0;
for (int ph = 0; ph < pooled_height; ph++) {
for (int pw = 0; pw < pooled_width; pw++) {
int index = index_n_c + ph * pooled_width + pw;
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
PreCalc<T> pc = pre_calc[pre_calc_index];
output_val += pc.w1 * offset_input[pc.pos1] +
pc.w2 * offset_input[pc.pos2] +
pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
pre_calc_index += 1;
}
}
output_val /= count;
output[index] = output_val;
} // for pw
} // for ph
} // for c
} // for n
}
template <typename T>
void bilinear_interpolate_gradient(
const int height,
const int width,
T y,
T x,
T& w1,
T& w2,
T& w3,
T& w4,
int& x_low,
int& x_high,
int& y_low,
int& y_high,
const int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
return;
}
template <class T>
inline void add(T* address, const T& val) {
*address += val;
}
template <typename T>
void ROIAlignBackward(
const int nthreads,
const T* grad_output,
const T& spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
for (int index = 0; index < nthreads; index++) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
T* offset_grad_input =
grad_input + ((roi_batch_ind * channels + c) * height * width);
int output_offset = n * n_stride + c * c_stride;
const T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
for (int iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
// atomic add is not needed for now since it is single threaded
add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
} // for
} // ROIAlignBackward
at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIAlign_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});
auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());
auto output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0)
return output;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
ROIAlignForward<scalar_t>(
output_size,
input.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
rois.contiguous().data<scalar_t>(),
output.data<scalar_t>());
});
return output;
}
at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio) {
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIAlign_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});
at::Tensor grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
// handle possibly empty gradients
if (grad.numel() == 0) {
return grad_input;
}
// get stride values to ensure indexing into gradients is correct.
int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
ROIAlignBackward<scalar_t>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>(),
n_stride,
c_stride,
h_stride,
w_stride);
});
return grad_input;
}
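
The core of ROIAlignForward above is the per-sample-point bilinear weight computation in pre_calc_for_bilinear_interpolate. A minimal Python sketch of that step, mirroring the C++ logic (the function name below is illustrative, not part of the patch):

def bilinear_weights(y, x, height, width):
    """Positions and weights for one RoIAlign sampling point (mirrors the C++ above)."""
    # sample points more than one pixel outside the feature map contribute nothing
    if y < -1.0 or y > height or x < -1.0 or x > width:
        return (0, 0, 0, 0), (0.0, 0.0, 0.0, 0.0)
    y, x = max(y, 0.0), max(x, 0.0)
    y_low, x_low = int(y), int(x)
    if y_low >= height - 1:
        y_high = y_low = height - 1
        y = float(y_low)
    else:
        y_high = y_low + 1
    if x_low >= width - 1:
        x_high = x_low = width - 1
        x = float(x_low)
    else:
        x_high = x_low + 1
    ly, lx = y - y_low, x - x_low
    hy, hx = 1.0 - ly, 1.0 - lx
    positions = (y_low * width + x_low, y_low * width + x_high,
                 y_high * width + x_low, y_high * width + x_high)
    weights = (hy * hx, hy * lx, ly * hx, ly * lx)  # w1, w2, w3, w4
    return positions, weights

# e.g. the point (y, x) = (1.5, 2.25) on a 10x10 map interpolates between rows 1-2
# and columns 2-3 with weights (0.375, 0.125, 0.375, 0.125), which sum to 1.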
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <TH/TH.h>
#include <algorithm>
template <class T>
inline void add(T* address, const T& val) {
*address += val;
}
template <typename T>
void RoIPoolForward(
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const T* rois,
const int num_rois,
T* output,
int* argmax_data) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
int roi_start_w = round(offset_rois[1] * spatial_scale);
int roi_start_h = round(offset_rois[2] * spatial_scale);
int roi_end_w = round(offset_rois[3] * spatial_scale);
int roi_end_h = round(offset_rois[4] * spatial_scale);
// Force malformed ROIs to be 1x1
int roi_width = std::max(roi_end_w - roi_start_w + 1, 1);
int roi_height = std::max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = std::min(std::max(hstart + roi_start_h, 0), height);
hend = std::min(std::max(hend + roi_start_h, 0), height);
wstart = std::min(std::max(wstart + roi_start_w, 0), width);
wend = std::min(std::max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
for (int c = 0; c < channels; ++c) {
// Define an empty pooling region to be zero
T maxval = is_empty ? 0 : -FLT_MAX;
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
int maxidx = -1;
const T* input_offset =
input + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_index = h * width + w;
if (input_offset[input_index] > maxval) {
maxval = input_offset[input_index];
maxidx = input_index;
}
}
}
int index =
((n * channels + c) * pooled_height + ph) * pooled_width + pw;
output[index] = maxval;
argmax_data[index] = maxidx;
} // channels
} // pooled_width
} // pooled_height
} // num_rois
}
template <typename T>
void RoIPoolBackward(
const T* grad_output,
const int* argmax_data,
const int num_rois,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
for (int n = 0; n < num_rois; ++n) {
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
for (int c = 0; c < channels; ++c) {
T* grad_input_offset =
grad_input + ((roi_batch_ind * channels + c) * height * width);
const int* argmax_data_offset =
argmax_data + (n * channels + c) * pooled_height * pooled_width;
for (int ph = 0; ph < pooled_height; ++ph) {
for (int pw = 0; pw < pooled_width; ++pw) {
int output_offset = n * n_stride + c * c_stride;
int argmax = argmax_data_offset[ph * pooled_width + pw];
if (argmax != -1) {
add(grad_input_offset + argmax,
static_cast<T>(
grad_output
[output_offset + ph * h_stride + pw * w_stride]));
}
} // pooled_width
} // pooled_height
} // channels
} // num_rois
}
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIPool_forward_cpu";
at::checkAllSameType(c, {input_t, rois_t});
int num_rois = rois.size(0);
int channels = input.size(1);
int height = input.size(2);
int width = input.size(3);
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());
at::Tensor argmax = at::zeros(
{num_rois, channels, pooled_height, pooled_width},
input.options().dtype(at::kInt));
if (output.numel() == 0) {
return std::make_tuple(output, argmax);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIPool_forward", [&] {
RoIPoolForward<scalar_t>(
input.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
rois.contiguous().data<scalar_t>(),
num_rois,
output.data<scalar_t>(),
argmax.data<int>());
});
return std::make_tuple(output, argmax);
}
at::Tensor ROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CPU tensors
AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
AT_ASSERTM(argmax.device().is_cpu(), "argmax must be a CPU tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIPool_backward_cpu";
at::checkAllSameType(c, {grad_t, rois_t});
auto num_rois = rois.size(0);
at::Tensor grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
// handle possibly empty gradients
if (grad.numel() == 0) {
return grad_input;
}
// get stride values to ensure indexing into gradients is correct.
int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] {
RoIPoolBackward<scalar_t>(
grad.contiguous().data<scalar_t>(),
argmax.data<int>(),
num_rois,
channels,
height,
width,
pooled_height,
pooled_width,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>(),
n_stride,
c_stride,
h_stride,
w_stride);
});
return grad_input;
}
#include "cpu/vision.h"
template <typename scalar_t>
at::Tensor nms_cpu_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
const float threshold) {
AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor");
AT_ASSERTM(
dets.type() == scores.type(), "dets should have the same type as scores");
if (dets.numel() == 0)
return at::empty({0}, dets.options().dtype(at::kLong));
auto x1_t = dets.select(1, 0).contiguous();
auto y1_t = dets.select(1, 1).contiguous();
auto x2_t = dets.select(1, 2).contiguous();
auto y2_t = dets.select(1, 3).contiguous();
at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t);
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto ndets = dets.size(0);
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
auto suppressed = suppressed_t.data<uint8_t>();
auto keep = keep_t.data<int64_t>();
auto order = order_t.data<int64_t>();
auto x1 = x1_t.data<scalar_t>();
auto y1 = y1_t.data<scalar_t>();
auto x2 = x2_t.data<scalar_t>();
auto y2 = y2_t.data<scalar_t>();
auto areas = areas_t.data<scalar_t>();
int64_t num_to_keep = 0;
for (int64_t _i = 0; _i < ndets; _i++) {
auto i = order[_i];
if (suppressed[i] == 1)
continue;
keep[num_to_keep++] = i;
auto ix1 = x1[i];
auto iy1 = y1[i];
auto ix2 = x2[i];
auto iy2 = y2[i];
auto iarea = areas[i];
for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1)
continue;
auto xx1 = std::max(ix1, x1[j]);
auto yy1 = std::max(iy1, y1[j]);
auto xx2 = std::min(ix2, x2[j]);
auto yy2 = std::min(iy2, y2[j]);
auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1);
auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1);
auto inter = w * h;
auto ovr = inter / (iarea + areas[j] - inter);
if (ovr >= threshold)
suppressed[j] = 1;
}
}
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}
at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const float threshold) {
auto result = at::empty({0}, dets.options());
AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
result = nms_cpu_kernel<scalar_t>(dets, scores, threshold);
});
return result;
}
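
A worked example of the suppression test inside the loop above, in plain Python (a sketch, not part of the patch). Note that areas are computed without the legacy +1 offset, matching the commit message:

# two boxes in (x1, y1, x2, y2) corner form
box_i = (0.0, 0.0, 10.0, 10.0)  # already-kept box (higher score)
box_j = (5.0, 0.0, 15.0, 10.0)  # candidate being tested against it

def area(b):
    return (b[2] - b[0]) * (b[3] - b[1])  # no +1, matching areas_t above

xx1, yy1 = max(box_i[0], box_j[0]), max(box_i[1], box_j[1])
xx2, yy2 = min(box_i[2], box_j[2]), min(box_i[3], box_j[3])
inter = max(0.0, xx2 - xx1) * max(0.0, yy2 - yy1)   # 5 * 10 = 50
ovr = inter / (area(box_i) + area(box_j) - inter)   # 50 / 150 = 0.333...

threshold = 0.5
suppressed = ovr >= threshold  # False, so box_j survives this round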
#pragma once
#include <torch/extension.h>
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width);
at::Tensor ROIPool_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width);
at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio);
at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const float threshold);
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "cuda_helpers.h"
template <typename T>
__device__ T bilinear_interpolate(
const T* input,
const int height,
const int width,
T y,
T x,
const int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
return 0;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
__global__ void RoIAlignForward(
const int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
const T* rois,
T* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
T output_val = 0.;
for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
}
output_val /= count;
output[index] = output_val;
}
}
template <typename T>
__device__ void bilinear_interpolate_gradient(
const int height,
const int width,
T y,
T x,
T& w1,
T& w2,
T& w3,
T& w4,
int& x_low,
int& x_high,
int& y_low,
int& y_high,
const int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
return;
}
template <typename T>
__global__ void RoIAlignBackward(
const int nthreads,
const T* grad_output,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale;
T roi_start_h = offset_rois[2] * spatial_scale;
T roi_end_w = offset_rois[3] * spatial_scale;
T roi_end_h = offset_rois[4] * spatial_scale;
// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
T* offset_grad_input =
grad_input + ((roi_batch_ind * channels + c) * height * width);
// We need to index the gradient using the tensor strides to access the
// correct values.
int output_offset = n * n_stride + c * c_stride;
const T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (int ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
int x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(
offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
atomicAdd(
offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
atomicAdd(
offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
atomicAdd(
offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
} // CUDA_1D_KERNEL_LOOP
} // RoIAlignBackward
at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIAlign_forward_cuda";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
at::cuda::CUDAGuard device_guard(input.device());
auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());
auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
dim3 block(512);
if (output.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return output;
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIAlign_forward", [&] {
RoIAlignForward<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
rois.contiguous().data<scalar_t>(),
output.data<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
return output;
}
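// Launch-configuration sketch for the call above (illustrative numbers only):
// output_size = num_rois * channels * pooled_height * pooled_width, e.g.
// 2 RoIs * 256 channels * 7 * 7 = 25088 elements. With 512 threads per block
// the grid is min(ceil(25088 / 512), 4096) = 49 blocks; if output_size exceeds
// 4096 * 512, the grid-stride loop in CUDA_1D_KERNEL_LOOP lets each thread
// process several output elements.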
at::Tensor ROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio) {
AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIAlign_backward_cuda";
at::checkAllSameGPU(c, {grad_t, rois_t});
at::checkAllSameType(c, {grad_t, rois_t});
at::cuda::CUDAGuard device_guard(grad.device());
at::Tensor grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
dim3 block(512);
// handle possibly empty gradients
if (grad.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
sampling_ratio,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>(),
n_stride,
c_stride,
h_stride,
w_stride);
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <cfloat> // for FLT_MAX used in RoIPoolForward
#include "cuda_helpers.h"
template <typename T>
__global__ void RoIPoolForward(
const int nthreads,
const T* input,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
const T* rois,
T* output,
int* argmax_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
int roi_start_w = round(offset_rois[1] * spatial_scale);
int roi_start_h = round(offset_rois[2] * spatial_scale);
int roi_end_w = round(offset_rois[3] * spatial_scale);
int roi_end_h = round(offset_rois[4] * spatial_scale);
// Force malformed ROIs to be 1x1
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, 0), height);
hend = min(max(hend + roi_start_h, 0), height);
wstart = min(max(wstart + roi_start_w, 0), width);
wend = min(max(wend + roi_start_w, 0), width);
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Define an empty pooling region to be zero
T maxval = is_empty ? 0 : -FLT_MAX;
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
int maxidx = -1;
const T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_index = h * width + w;
if (offset_input[input_index] > maxval) {
maxval = offset_input[input_index];
maxidx = input_index;
}
}
}
output[index] = maxval;
argmax_data[index] = maxidx;
}
}
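// Note on the forward kernel above: each pooled cell (n, c, ph, pw) takes the
// max over its bin in the input feature map and records the winning flat index
// (h * width + w) in argmax_data, which the backward pass uses to route the
// gradient. Illustrative bin arithmetic (not part of the kernel): for a
// 14-pixel-wide RoI pooled to pooled_width = 7, bin_size_w = 2, so pw = 3
// covers input columns [roi_start_w + 6, roi_start_w + 8) before clipping.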
template <typename T>
__global__ void RoIPoolBackward(
const int nthreads,
const T* grad_output,
const int* argmax_data,
const int num_rois,
const T spatial_scale,
const int channels,
const int height,
const int width,
const int pooled_height,
const int pooled_width,
T* grad_input,
const T* rois,
const int n_stride,
const int c_stride,
const int h_stride,
const int w_stride) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
// (n, c, ph, pw) is an element in the pooled output
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
const T* offset_rois = rois + n * 5;
int roi_batch_ind = offset_rois[0];
T* grad_input_offset =
grad_input + ((roi_batch_ind * channels + c) * height * width);
int output_offset = n * n_stride + c * c_stride;
const int* argmax_data_offset =
argmax_data + (n * channels + c) * pooled_height * pooled_width;
int argmax = argmax_data_offset[ph * pooled_width + pw];
if (argmax != -1) {
atomicAdd(
grad_input_offset + argmax,
static_cast<T>(
grad_output[output_offset + ph * h_stride + pw * w_stride]));
}
}
}
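// Note on the backward kernel above: max pooling is only differentiable with
// respect to the element that won the max, so each pooled cell adds its
// gradient to exactly one input location (the cached argmax); cells whose bin
// was empty carry argmax == -1 and contribute nothing. atomicAdd is still
// needed because different RoIs may select the same input element.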
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width) {
AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor");
AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ROIPool_forward_cuda";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
at::cuda::CUDAGuard device_guard(input.device());
auto num_rois = rois.size(0);
auto channels = input.size(1);
auto height = input.size(2);
auto width = input.size(3);
at::Tensor output = at::zeros(
{num_rois, channels, pooled_height, pooled_width}, input.options());
at::Tensor argmax = at::zeros(
{num_rois, channels, pooled_height, pooled_width},
input.options().dtype(at::kInt));
auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
dim3 block(512);
if (output.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, argmax);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "ROIPool_forward", [&] {
RoIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
output_size,
input.contiguous().data<scalar_t>(),
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
rois.contiguous().data<scalar_t>(),
output.data<scalar_t>(),
argmax.data<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(output, argmax);
}
at::Tensor ROIPool_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width) {
// Check if input tensors are CUDA tensors
AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor");
AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor");
AT_ASSERTM(argmax.device().is_cuda(), "argmax must be a CUDA tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
argmax_t{argmax, "argmax", 3};
at::CheckedFrom c = "ROIPool_backward_cuda";
at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
at::checkAllSameType(c, {grad_t, rois_t});
at::cuda::CUDAGuard device_guard(grad.device());
auto num_rois = rois.size(0);
at::Tensor grad_input =
at::zeros({batch_size, channels, height, width}, grad.options());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
dim3 block(512);
// handle possibly empty gradients
if (grad.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
int n_stride = grad.stride(0);
int c_stride = grad.stride(1);
int h_stride = grad.stride(2);
int w_stride = grad.stride(3);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] {
RoIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
argmax.contiguous().data<int>(),
num_rois,
spatial_scale,
channels,
height,
width,
pooled_height,
pooled_width,
grad_input.data<scalar_t>(),
rois.contiguous().data<scalar_t>(),
n_stride,
c_stride,
h_stride,
w_stride);
});
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}
#pragma once
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
i += (blockDim.x * gridDim.x))
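// The macro above implements a standard grid-stride loop: thread i starts at
// its global index and advances by the total number of launched threads until
// it passes n, so a fixed-size launch covers any n. A rough sketch of how a
// kernel body uses it (hypothetical kernel, for illustration only):
//
//   __global__ void Scale(const int n, const float* in, float* out) {
//     CUDA_1D_KERNEL_LOOP(i, n) {
//       out[i] = 2.f * in[i];
//     }
//   }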
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "cuda_helpers.h"
#include <iostream>
#include <vector>
int const threadsPerBlock = sizeof(unsigned long long) * 8;
template <typename T>
__device__ inline float devIoU(T const* const a, T const* const b) {
T left = max(a[0], b[0]), right = min(a[2], b[2]);
T top = max(a[1], b[1]), bottom = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(bottom - top, (T)0);
T interS = width * height;
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
return interS / (Sa + Sb - interS);
}
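// devIoU computes intersection-over-union for two boxes in (x1, y1, x2, y2)
// form, clamping the intersection to zero when the boxes do not overlap:
// IoU = interS / (Sa + Sb - interS). Matching the rest of this op, no +1 is
// added to widths and heights, and the return type is float for every T,
// which is enough for comparison against the float threshold.
// Worked example (illustrative): a = (0, 0, 2, 2), b = (1, 1, 3, 3) gives
// interS = 1, Sa = Sb = 4, IoU = 1 / 7 ~= 0.143.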
template <typename T>
__global__ void nms_kernel(
const int n_boxes,
const float nms_overlap_thresh,
const T* dev_boxes,
unsigned long long* dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
__shared__ T block_boxes[threadsPerBlock * 5];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU<T>(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
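// Note on nms_kernel above: boxes are processed in chunks of threadsPerBlock
// (= 64, the number of bits in an unsigned long long). The block at
// (blockIdx.x = col, blockIdx.y = row) loads the 64 boxes of column-chunk
// `col` into shared memory, and each thread owns one box of row-chunk `row`;
// bit k of dev_mask[i * col_blocks + col] is set when box i overlaps box
// (col * 64 + k) with IoU above the threshold. The `start = threadIdx.x + 1`
// case avoids comparing a box with itself or with earlier boxes of the same
// chunk, which the CPU-side sweep never needs.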
// boxes is an N x 5 tensor of (x1, y1, x2, y2, score) rows
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
using scalar_t = float;
AT_ASSERTM(boxes.device().is_cuda(), "boxes must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(boxes.device());
auto scores = boxes.select(1, 4);
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto boxes_sorted = boxes.index_select(0, order_t);
int boxes_num = boxes.size(0);
const int col_blocks = at::cuda::ATenCeilDiv(boxes_num, threadsPerBlock);
at::Tensor mask =
at::empty({boxes_num * col_blocks}, boxes.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
boxes_sorted.type(), "nms_kernel_cuda", [&] {
nms_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
boxes_num,
nms_overlap_thresh,
boxes_sorted.data<scalar_t>(),
(unsigned long long*)mask.data<int64_t>());
});
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host = (unsigned long long*)mask_cpu.data<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep =
at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
    {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
         .to(order_t.device(), keep.scalar_type())});
}
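// Note on the host-side sweep above: the GPU only produces the pairwise
// overlap bitmask; the greedy part of NMS runs on the CPU. Boxes are visited
// in descending score order (order_t); a box is kept unless a previously kept
// box has already set its bit in `remv`, and each kept box ORs its own mask
// row into `remv` to suppress the lower-scored boxes it overlaps. The returned
// tensor holds indices into the original, unsorted `boxes`, ordered by
// decreasing score.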
#pragma once
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio);
at::Tensor ROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio);
std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width);
at::Tensor ROIPool_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width);
at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
at::Tensor nms(
const at::Tensor& dets,
const at::Tensor& scores,
const float threshold) {
if (dets.device().is_cuda()) {
#ifdef WITH_CUDA
if (dets.numel() == 0) {
at::cuda::CUDAGuard device_guard(dets.device());
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto b = at::cat({dets, scores.unsqueeze(1)}, 1);
return nms_cuda(b, threshold);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
at::Tensor result = nms_cpu(dets, scores, threshold);
return result;
}
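// Rough usage sketch for the dispatcher above (illustrative only; `dets` is an
// [N, 4] tensor of (x1, y1, x2, y2) boxes and `scores` an [N] tensor on the
// same device):
//
//   at::Tensor keep = nms(dets, scores, /*threshold=*/0.5f);
//   at::Tensor kept_boxes = dets.index_select(0, keep);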
#include "ROIAlign.h"
#include "ROIPool.h"
#include "nms.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "non-maximum suppression");
m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward");
}
from .boxes import nms, box_iou
from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool
__all__ = [
    'nms', 'box_iou', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool'
]
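# Illustrative usage of the public API (not part of this module; see boxes.py
# and roi_align.py for the exact wrapper signatures). Assuming `boxes` is an
# [N, 4] float tensor and `scores` an [N] tensor on the same device:
#   keep = nms(boxes, scores, 0.5)
#   pooled = roi_align(features, [boxes[keep]], (7, 7))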
import torch
def _cat(tensors, dim=0):
"""
Efficient version of torch.cat that avoids a copy if there is only a single element in a list
"""
assert isinstance(tensors, (list, tuple))
if len(tensors) == 1:
return tensors[0]
return torch.cat(tensors, dim)
def convert_boxes_to_roi_format(boxes):
concat_boxes = _cat([b for b in boxes], dim=0)
ids = _cat(
[
torch.full_like(b[:, :1], i)
for i, b in enumerate(boxes)
],
dim=0,
)
rois = torch.cat([ids, concat_boxes], dim=1)
return rois
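# Worked example (illustrative): for boxes = [t0, t1] with t0 of shape [2, 4]
# and t1 of shape [3, 4], convert_boxes_to_roi_format returns a [5, 5] tensor
# whose first column is the image index (0, 0, 1, 1, 1) and whose remaining
# four columns are the concatenated box coordinates, i.e. the
# (batch_idx, x1, y1, x2, y2) layout the C++ ops expect.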