Commit 16d62e30 authored by Li-Huai (Allan) Lin, committed by GitHub

Add MPS kernels for nms and roi ops (#7643)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent f524cd3a
@@ -48,7 +48,7 @@ except ImportError:
 DEVNULL = open(os.devnull, "wb")
-DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu,mm"
 class ExitStatus:
...
@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
 file(STRINGS version.txt TORCHVISION_VERSION)
 option(WITH_CUDA "Enable CUDA support" OFF)
+option(WITH_MPS "Enable MPS support" OFF)
 option(WITH_PNG "Enable features requiring LibPNG." ON)
 option(WITH_JPEG "Enable features requiring LibJPEG." ON)
 option(USE_PYTHON "Link to Python when building" OFF)
@@ -15,6 +16,11 @@ if(WITH_CUDA)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 endif()
+if(WITH_MPS)
+  enable_language(OBJC OBJCXX)
+  add_definitions(-DWITH_MPS)
+endif()
 find_package(Torch REQUIRED)
 if (WITH_PNG)
@@ -79,6 +85,9 @@ list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCP
 if(WITH_CUDA)
   list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
 endif()
+if(WITH_MPS)
+  list(APPEND ALLOW_LISTED ${TVCPP}/ops/mps)
+endif()
 FOREACH(DIR ${ALLOW_LISTED})
   file(GLOB ALL_SOURCES ${ALL_SOURCES} ${DIR}/*.*)
...
@@ -137,10 +137,13 @@ def get_extensions():
         + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
         + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
     )
+    source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))
     print("Compiling extensions with following flags:")
     force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
     print(f" FORCE_CUDA: {force_cuda}")
+    force_mps = os.getenv("FORCE_MPS", "0") == "1"
+    print(f" FORCE_MPS: {force_mps}")
     debug_mode = os.getenv("DEBUG", "0") == "1"
     print(f" DEBUG: {debug_mode}")
     use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
@@ -202,6 +205,8 @@ def get_extensions():
         define_macros += [("WITH_HIP", None)]
         nvcc_flags = []
         extra_compile_args["nvcc"] = nvcc_flags
+    elif torch.backends.mps.is_available() or force_mps:
+        sources += source_mps
     if sys.platform == "win32":
         define_macros += [("torchvision_EXPORTS", None)]
...
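As a usage note, here is a minimal sketch (not part of the commit) of the gate added above, assuming the same environment-variable convention: the .mm Objective-C++ sources are compiled only when an MPS device is visible at build time or FORCE_MPS=1 is exported.

import os

import torch

# Mirrors the condition used in get_extensions(): build the MPS kernels when an
# MPS device is available, or when the FORCE_MPS escape hatch is set (analogous
# to FORCE_CUDA for building without a local device).
force_mps = os.getenv("FORCE_MPS", "0") == "1"
build_mps_kernels = torch.backends.mps.is_available() or force_mps
print(f"Building MPS kernels: {build_mps_kernels}")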
@@ -34,6 +34,7 @@ IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS
 IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
 IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
+MPS_NOT_AVAILABLE_MSG = "MPS device not available"
 OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -130,12 +131,22 @@ def cpu_and_cuda():
     return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))
+def cpu_and_cuda_and_mps():
+    return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)
 def needs_cuda(test_func):
     import pytest  # noqa
     return pytest.mark.needs_cuda(test_func)
+def needs_mps(test_func):
+    import pytest  # noqa
+    return pytest.mark.needs_mps(test_func)
 def _create_data(height=3, width=3, channels=3, device="cpu"):
     # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
     tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
...
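For illustration, a minimal sketch (hypothetical test, not from this commit) of how cpu_and_cuda_and_mps() is meant to be used; it mirrors the parametrizations added to test_ops.py further down.

import pytest
import torch

from common_utils import cpu_and_cuda_and_mps


# Runs on CPU everywhere; the "cuda"/"mps" instances carry the needs_cuda/needs_mps
# markers, so conftest.py skips or drops them on machines without that device.
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
def test_add_is_doubling(device):
    x = torch.ones(3, device=device)
    torch.testing.assert_close(x + x, torch.full((3,), 2.0, device=device))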
@@ -8,12 +8,20 @@ import torchvision
 torchvision.disable_beta_transforms_warning()
-from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
+from common_utils import (
+    CUDA_NOT_AVAILABLE_MSG,
+    IN_FBCODE,
+    IN_OSS_CI,
+    IN_RE_WORKER,
+    MPS_NOT_AVAILABLE_MSG,
+    OSS_CI_GPU_NO_CUDA_MSG,
+)
 def pytest_configure(config):
     # register an additional marker (see pytest_collection_modifyitems)
     config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
+    config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
     config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")
@@ -37,12 +45,16 @@ def pytest_collection_modifyitems(items):
         # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
         # and the ones with device == 'cpu' won't have the mark.
         needs_cuda = item.get_closest_marker("needs_cuda") is not None
+        needs_mps = item.get_closest_marker("needs_mps") is not None
         if needs_cuda and not torch.cuda.is_available():
             # In general, we skip cuda tests on machines without a GPU
             # There are special cases though, see below
             item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))
+        if needs_mps and not torch.backends.mps.is_available():
+            item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))
         if IN_FBCODE:
             # fbcode doesn't like skipping tests, so instead we just don't collect the test
             # so that they don't even "exist", hence the continue statements.
@@ -54,6 +66,9 @@ def pytest_collection_modifyitems(items):
                 # TODO: something more robust would be to do that only in a sandcastle instance,
                 # so that we can still see the test being skipped when testing locally from a devvm
                 continue
+            if needs_mps and not torch.backends.mps.is_available():
+                # Same as above, but for MPS
+                continue
         elif IN_OSS_CI:
             # Here we're not in fbcode, so we can safely collect and skip tests.
             if not needs_cuda and torch.cuda.is_available():
...
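A usage sketch for the new marker (hypothetical test, not from this commit): a needs_mps test is skipped on machines without MPS, and under fbcode it is not collected at all, matching the existing needs_cuda behaviour.

import pytest
import torch


@pytest.mark.needs_mps
def test_zeros_on_mps():
    # Only reached when torch.backends.mps.is_available() is True; otherwise the
    # collection hook above skips (or, in fbcode, drops) this test.
    x = torch.zeros(2, 2, device="mps")
    assert x.device.type == "mps"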
@@ -10,7 +10,7 @@ import pytest
 import torch
 import torch.fx
 import torch.nn.functional as F
-from common_utils import assert_equal, cpu_and_cuda, needs_cuda
+from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
 from torch.autograd import gradcheck
@@ -96,12 +96,33 @@ class PoolWrapper(nn.Module):
 class RoIOpTester(ABC):
     dtype = torch.float64
+    mps_dtype = torch.float32
+    mps_backward_atol = 2e-2
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
+    @pytest.mark.parametrize(
+        "x_dtype",
+        (
+            torch.float16,
+            torch.float32,
+            torch.float64,
+        ),
+        ids=str,
+    )
-    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
-        x_dtype = self.dtype if x_dtype is None else x_dtype
-        rois_dtype = self.dtype if rois_dtype is None else rois_dtype
+    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
+        if device == "mps" and x_dtype is torch.float64:
+            pytest.skip("MPS does not support float64")
+        rois_dtype = x_dtype if rois_dtype is None else rois_dtype
+        tol = 1e-5
+        if x_dtype is torch.half:
+            if device == "mps":
+                tol = 5e-3
+            else:
+                tol = 4e-3
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
         n_channels = 2 * (pool_size**2)
@@ -120,10 +141,9 @@ class RoIOpTester(ABC):
         # the following should be true whether we're running an autocast test or not.
         assert y.dtype == x.dtype
         gt_y = self.expected_fn(
-            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs
+            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
         )
-        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
         torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -155,16 +175,19 @@ class RoIOpTester(ABC):
         torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)
     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic=False):
+        atol = self.mps_backward_atol if device == "mps" else 1e-05
+        dtype = self.mps_dtype if device == "mps" else self.dtype
         torch.random.manual_seed(seed)
         pool_size = 2
-        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
         if not contiguous:
             x = x.permute(0, 1, 3, 2)
         rois = torch.tensor(
-            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device  # format is (xyxy)
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
         )
         def func(z):
@@ -173,9 +196,25 @@ class RoIOpTester(ABC):
         script_func = self.get_script_fn(rois, pool_size)
         with DeterministicGuard(deterministic):
-            gradcheck(func, (x,))
+            gradcheck(func, (x,), atol=atol)
-        gradcheck(script_func, (x,))
+        gradcheck(script_func, (x,), atol=atol)
+    @needs_mps
+    def test_mps_error_inputs(self):
+        pool_size = 2
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
+        rois = torch.tensor(
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
+        )
+        def func(z):
+            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
+        with pytest.raises(
+            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
+        ):
+            gradcheck(func, (x,))
     @needs_cuda
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@@ -271,6 +310,8 @@ class TestRoiPool(RoIOpTester):
 class TestPSRoIPool(RoIOpTester):
+    mps_backward_atol = 5e-2
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
@@ -352,6 +393,8 @@ def bilinear_interpolate(data, y, x, snap_border=False):
 class TestRoIAlign(RoIOpTester):
+    mps_backward_atol = 6e-2
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
         return ops.RoIAlign(
             (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
@@ -418,10 +461,11 @@ class TestRoIAlign(RoIOpTester):
         self._helper_boxes_shape(ops.roi_align)
     @pytest.mark.parametrize("aligned", (True, False))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
+    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
-    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
+    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
         super().test_forward(
@@ -450,7 +494,7 @@ class TestRoIAlign(RoIOpTester):
         )
     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic):
@@ -537,6 +581,8 @@ class TestRoIAlign(RoIOpTester):
 class TestPSRoIAlign(RoIOpTester):
+    mps_backward_atol = 5e-2
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
@@ -705,21 +751,28 @@ class TestNMS:
         torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
-    @needs_cuda
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
-    def test_nms_cuda(self, iou, dtype=torch.float64):
+    def test_nms_gpu(self, iou, device, dtype=torch.float64):
+        dtype = torch.float32 if device == "mps" else dtype
        tol = 1e-3 if dtype is torch.half else 1e-5
        err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
        boxes, scores = self._create_tensors_with_iou(1000, iou)
        r_cpu = ops.nms(boxes, scores, iou)
-        r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
+        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)
-        is_eq = torch.allclose(r_cpu, r_cuda.cpu())
+        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
         if not is_eq:
             # if the indices are not the same, ensure that it's because the scores
             # are duplicate
-            is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
+            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)
     @needs_cuda
@@ -727,18 +780,24 @@ class TestNMS:
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, iou, dtype):
         with torch.cuda.amp.autocast():
-            self.test_nms_cuda(iou=iou, dtype=dtype)
+            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
-    @needs_cuda
-    def test_nms_cuda_float16(self):
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
+    def test_nms_float16(self, device):
         boxes = torch.tensor(
             [
                 [285.3538, 185.5758, 1193.5110, 851.4551],
                 [285.1472, 188.7374, 1192.4984, 851.0669],
                 [279.2440, 197.9812, 1189.4746, 849.2019],
             ]
-        ).cuda()
+        ).to(device)
-        scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
+        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)
         iou_thres = 0.2
         keep32 = ops.nms(boxes, scores, iou_thres)
...
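For completeness, a hedged sketch of what these tests ultimately exercise: calling the ops directly on MPS tensors (float32 only, since the tests above skip float64 on MPS). This assumes a Mac with an MPS device and a torchvision build that includes the new kernels; the values below are illustrative, not from the test suite.

import torch
from torchvision import ops

device = "mps"
boxes = torch.tensor(
    [[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [50.0, 50.0, 60.0, 60.0]], device=device
)
scores = torch.tensor([0.9, 0.8, 0.7], device=device)
keep = ops.nms(boxes, scores, iou_threshold=0.5)  # the overlapping, lower-scored box is suppressed

x = torch.rand(1, 8, 16, 16, device=device)
rois = torch.tensor([[0.0, 0.0, 0.0, 15.0, 15.0]], device=device)  # (batch_idx, x1, y1, x2, y2)
y = ops.roi_align(x, rois, output_size=(4, 4), spatial_scale=1.0, sampling_ratio=2)
print(keep, y.shape)  # keep indices on MPS; y has shape (1, 8, 4, 4)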
@@ -11,8 +11,8 @@ at::Tensor nms_kernel_impl(
     const at::Tensor& dets,
     const at::Tensor& scores,
     double iou_threshold) {
-  TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
-  TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
+  TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
+  TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
   TORCH_CHECK(
       dets.scalar_type() == scores.scalar_type(),
       "dets should have the same type as scores");
...
constexpr int threadsPerBlock = 512;

template <typename T>
constexpr inline T ceil_div(T n, T m) {
  return (n + m - 1) / m;
}

#include <ATen/native/mps/OperationUtils.h>
namespace vision {
namespace ops {
namespace mps {
static const char* METAL_VISION = R"VISION_METAL(
#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;
/*----------Macros----------*/
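// MPS_1D_KERNEL_LOOP(_T) implements a grid-stride loop: each thread starts at its
// global linear index and advances by (threads per threadgroup) * (number of
// threadgroups), so any launch configuration covers all n elements.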
#define MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, index_t) \
for (index_t i = (tgid.x * tptg.x) + tid2.x; i < (n); \
i += (tptg.x * n_tgs))
#define MPS_1D_KERNEL_LOOP(i, n, n_tgs) MPS_1D_KERNEL_LOOP_T(i, n, n_tgs, uint)
/*----------Helpers--------*/
template <typename T>
inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
template <typename T>
inline void atomic_add_float( device T* data_ptr, const T val)
{
#if __METAL_VERSION__ >= 300
// atomic_float is supported in Metal 3 (macOS Ventura) onward.
device atomic_fetch_add_explicit((device atomic_float*) data_ptr, val, memory_order_relaxed);
#else
// Custom atomic addition implementation
// https://github.com/ShoYamanishi/AppleNumericalComputing/blob/053f06c1f5a831095c4bcc29aaf11366fce5231e/03_dot/metal/dot.metal#L447-L472
// https://forums.developer.nvidia.com/t/atomicadd-float-float-atomicmul-float-float/14639
// https://on-demand.gputechconf.com/gtc/2013/presentations/S3101-Atomic-Memory-Operations.pdf (See the last slide)
// Create an atomic uint pointer for atomic transaction.
device atomic_uint* atom_var = (device atomic_uint*)data_ptr;
// Create necessary storage.
uint fetched_uint, assigning_uint;
T fetched_float, assigning_float;
// Replace the value in atom_var with 0 and return the previous value in atom_var.
fetched_uint = atomic_exchange_explicit( atom_var, 0 /*desired*/, memory_order_relaxed);
// Read out the previous value as float.
fetched_float = *( (thread T*) &fetched_uint );
// Do addition and represent the addition result in uint for atomic transaction.
assigning_float = fetched_float + val;
assigning_uint = *((thread uint*) &assigning_float);
// atom_var should be 0 now, try to assign the addition result back to the atom_var (data_ptr).
while ((fetched_uint = atomic_exchange_explicit( atom_var, assigning_uint /*desired*/, memory_order_relaxed)) != 0) {
// If atom_var was not 0, i.e. fetched_uint != 0, it means that the data has been modified by other threads.
// Try to assign 0 and get the previously assigned addition result.
uint fetched_uint_again = atomic_exchange_explicit(atom_var, 0 /*desired*/, memory_order_relaxed);
T fetched_float_again = *( (thread T*) &fetched_uint_again );
// Re-add again
fetched_float = *((thread T*) &(fetched_uint));
// Previously assigned addition result + addition result from other threads.
assigning_float = fetched_float_again + fetched_float;
assigning_uint = *( (thread uint*) &assigning_float);
}
#endif
}
template <typename T, typename integer_t>
inline T bilinear_interpolate(
constant T* input,
integer_t height,
integer_t width,
T y,
T x,
uint index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
return 0;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
integer_t y_low = (integer_t)y;
integer_t x_low = (integer_t)x;
integer_t y_high;
integer_t x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T, typename integer_t>
inline void bilinear_interpolate_gradient(
integer_t height,
integer_t width,
T y,
T x,
thread T& w1,
thread T& w2,
thread T& w3,
thread T& w4,
thread integer_t& x_low,
thread integer_t& x_high,
thread integer_t& y_low,
thread integer_t& y_high,
uint index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0)
y = 0;
if (x <= 0)
x = 0;
y_low = (integer_t)y;
x_low = (integer_t)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
}
template <typename T, typename scalar_t>
inline bool IoU(
constant T & a,
threadgroup T & b,
const float threshold) {
auto xx1 = max(a.x, b.x);
auto yy1 = max(a.y, b.y);
auto xx2 = min(a.z, b.z);
auto yy2 = min(a.w, b.w);
auto w = max(static_cast<scalar_t>(0), xx2 - xx1);
auto h = max(static_cast<scalar_t>(0), yy2 - yy1);
// Upcast to float before multiplications to circumvent precision issues in half.
auto inter = static_cast<float>(w) * static_cast<float>(h);
auto area_b = static_cast<float>(b.z - b.x) * static_cast<float>(b.w - b.y);
auto area_a = static_cast<float>(a.z - a.x) * static_cast<float>(a.w - a.y);
return (inter / (area_a + area_b - inter)) > threshold;
}
/*----------Kernels----------*/
// This should be in sync with the one in nms_kernel.mm.
// Since metal does not support dynamic array,
// we need to make it static instead of deriving it from [[threads_per_threadgroup]].
constant int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;
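// Each threadgroup compares one row-block of up to 64 boxes (tgid.y) against one
// column-block (tgid.x) staged in threadgroup memory. For every row box, the set
// bits of the written 64-bit mask mark column boxes whose IoU exceeds the
// threshold; the launcher in nms_kernel.mm is expected to walk these masks to
// build the final keep list, as in the CUDA implementation.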
template<typename T, typename scalar_t>
kernel void nms(constant T * dev_boxes [[buffer(0)]],
device uint64_t * mask [[buffer(1)]],
constant int64_t & n_boxes [[buffer(2)]],
constant float & iou_threshold [[buffer(3)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tid2 [[thread_position_in_threadgroup]]) {
const uint row_start = tgid.y;
const uint col_start = tgid.x;
const uint tid = tid2.x;
const uint row_size =
min(n_boxes - row_start * nmsThreadsPerBlock, nmsThreadsPerBlock);
const uint col_size =
min(n_boxes - col_start * nmsThreadsPerBlock, nmsThreadsPerBlock);
threadgroup T block_boxes[nmsThreadsPerBlock];
block_boxes[tid] = dev_boxes[nmsThreadsPerBlock * col_start + tid];
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid < row_size) {
const uint cur_box_idx = nmsThreadsPerBlock * row_start + tid;
uint64_t t = 0;
uint start = 0;
if (row_start == col_start) {
start = tid + 1;
}
for (uint i = start; i < col_size; i++){
if (IoU<T, scalar_t>(dev_boxes[cur_box_idx], block_boxes[i], iou_threshold)){
t |= static_cast<uint64_t>(1) << i; // discard 1 keep 0
}
}
const uint col_blocks = ceil_div(n_boxes, nmsThreadsPerBlock);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
#define REGISTER_NMS_OP(DTYPE) \
template \
[[host_name("nms_" #DTYPE)]] \
kernel void nms<DTYPE ## 4, DTYPE>( \
constant DTYPE ## 4 * dev_boxes [[buffer(0)]], \
device uint64_t * mask [[buffer(1)]], \
constant int64_t & n_boxes [[buffer(2)]], \
constant float & iou_threshold [[buffer(3)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_align(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
constant int64_t & output_size [[buffer(3)]],
constant int64_t & channels [[buffer(4)]],
constant int64_t & height [[buffer(5)]],
constant int64_t & width [[buffer(6)]],
constant int64_t & pooled_height [[buffer(7)]],
constant int64_t & pooled_width [[buffer(8)]],
constant int64_t & sampling_ratio [[buffer(9)]],
constant bool & aligned [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
constant T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
// When the grid is empty, output zeros.
const T count = max(roi_bin_grid_h * roi_bin_grid_w, static_cast<integer_t>(1)); // e.g. = 4
T output_val = 0.;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
output_val += val;
}
}
output_val /= count;
output[index] = output_val;
}
}
#define REGISTER_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_" #DTYPE)]] \
kernel void roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
constant int64_t & output_size [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_align_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * grad_input [[buffer(2)]],
constant int64_t & output_size [[buffer(3)]],
constant int64_t & channels [[buffer(4)]],
constant int64_t & height [[buffer(5)]],
constant int64_t & width [[buffer(6)]],
constant int64_t & pooled_height [[buffer(7)]],
constant int64_t & pooled_width [[buffer(8)]],
constant int64_t & sampling_ratio [[buffer(9)]],
constant bool & aligned [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
constant int64_t & n_stride [[buffer(12)]],
constant int64_t & c_stride [[buffer(13)]],
constant int64_t & h_stride [[buffer(14)]],
constant int64_t & w_stride [[buffer(15)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T offset = aligned ? (T)0.5 : (T)0.0;
T roi_start_w = offset_rois[1] * spatial_scale - offset;
T roi_start_h = offset_rois[2] * spatial_scale - offset;
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// We need to index the gradient using the tensor strides to access the
// correct values.
const integer_t output_offset = n * n_stride + c * c_stride;
constant T* offset_grad_output = grad_output + output_offset;
const T grad_output_this_bin =
offset_grad_output[ph * h_stride + pw * w_stride];
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
const integer_t input_offset = (roi_batch_ind * channels + c) * height * width;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
{
const T y = roi_start_h + ph * bin_size_h +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = roi_start_w + pw * bin_size_w +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
integer_t x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomic_add_float(grad_input + input_offset + y_low * width + x_low, static_cast<T>(g1));
atomic_add_float(grad_input + input_offset + y_low * width + x_high, static_cast<T>(g2));
atomic_add_float(grad_input + input_offset + y_high * width + x_low, static_cast<T>(g3));
atomic_add_float(grad_input + input_offset + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_align_backward_" #DTYPE)]] \
kernel void roi_align_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * grad_input [[buffer(2)]], \
constant int64_t & output_size [[buffer(3)]], \
constant int64_t & channels [[buffer(4)]], \
constant int64_t & height [[buffer(5)]], \
constant int64_t & width [[buffer(6)]], \
constant int64_t & pooled_height [[buffer(7)]], \
constant int64_t & pooled_width [[buffer(8)]], \
constant int64_t & sampling_ratio [[buffer(9)]], \
constant bool & aligned [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
constant int64_t & n_stride [[buffer(12)]], \
constant int64_t & c_stride [[buffer(13)]], \
constant int64_t & h_stride [[buffer(14)]], \
constant int64_t & w_stride [[buffer(15)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_pool(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * argmax [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant float & spatial_scale [[buffer(10)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force malformed ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w + 1, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h + 1, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
bool is_empty = (hend <= hstart) || (wend <= wstart);
// Define an empty pooling region to be zero
T maxval = is_empty ? 0 : -FLT_MAX;
// If nothing is pooled, argmax = -1 causes nothing to be backprop'd
integer_t maxidx = -1;
constant T* offset_input =
input + (roi_batch_ind * channels + c) * height * width;
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t input_index = h * width + w;
if (offset_input[input_index] > maxval) {
maxval = offset_input[input_index];
maxidx = input_index;
}
}
}
output[index] = maxval;
argmax[index] = maxidx;
}
}
#define REGISTER_ROI_POOL_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_pool_" #DTYPE)]] \
kernel void roi_pool<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * argmax_data [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant float & spatial_scale [[buffer(10)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void roi_pool_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * argmax_data [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant float & spatial_scale [[buffer(10)]],
constant int64_t & n_stride [[buffer(11)]],
constant int64_t & c_stride [[buffer(12)]],
constant int64_t & h_stride [[buffer(13)]],
constant int64_t & w_stride [[buffer(14)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c = (index / pooled_width / pooled_height) % channels;
integer_t n = index / pooled_width / pooled_height / channels;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
const integer_t output_offset = n * n_stride + c * c_stride;
constant integer_t * argmax_data_offset =
argmax_data + (n * channels + c) * pooled_height * pooled_width;
const integer_t argmax = argmax_data_offset[ph * pooled_width + pw];
const integer_t offset = (roi_batch_ind * channels + c) * height * width;
if (argmax != -1) {
atomic_add_float(grad_input + offset + argmax, static_cast<T>(grad_output[output_offset + ph * h_stride + pw * w_stride]));
}
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("roi_pool_backward_" #DTYPE)]] \
kernel void roi_pool_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * argmax_data [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant float & spatial_scale [[buffer(10)]], \
constant int64_t & n_stride [[buffer(11)]], \
constant int64_t & c_stride [[buffer(12)]], \
constant int64_t & h_stride [[buffer(13)]], \
constant int64_t & w_stride [[buffer(14)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_align(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * channel_mapping [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & sampling_ratio [[buffer(10)]],
constant int64_t & channels_out [[buffer(11)]],
constant float & spatial_scale [[buffer(12)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c_out, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c_out = (index / pooled_width / pooled_height) % channels_out;
integer_t n = index / pooled_width / pooled_height / channels_out;
// (n, c_in, ph, pw) is the associated element in the input
integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;
// [start, end) interval for spatial sampling
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
// Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height);
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
constant T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T val = bilinear_interpolate(offset_input, height, width, y, x, index);
out_sum += val;
}
}
out_sum /= count;
output[index] = out_sum;
channel_mapping[index] = c_in;
}
}
#define REGISTER_PS_ROI_ALIGN_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_align_" #DTYPE)]] \
kernel void ps_roi_align<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * channel_mapping [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & sampling_ratio [[buffer(10)]], \
constant int64_t & channels_out [[buffer(11)]], \
constant float & spatial_scale [[buffer(12)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_align_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * channel_mapping [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & sampling_ratio [[buffer(10)]],
constant int64_t & channels_out [[buffer(11)]],
constant float & spatial_scale [[buffer(12)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, *, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t n = index / pooled_width / pooled_height / channels_out;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
// Do not use rounding; this implementation detail is critical
T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
// Force too small ROIs to be 1x1
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
T bin_size_h = roi_height / static_cast<T>(pooled_height);
T bin_size_w = roi_width / static_cast<T>(pooled_width);
integer_t c_in = channel_mapping[index];
// Do not use floor/ceil; this implementation detail is critical
T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
const T grad_output_this_bin = grad_output[index];
// We use roi_bin_grid to sample the grid and mimic integral
integer_t roi_bin_grid_h = (sampling_ratio > 0)
? sampling_ratio
: ceil(roi_height / pooled_height); // e.g., = 2
integer_t roi_bin_grid_w =
(sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
const T count = roi_bin_grid_h * roi_bin_grid_w;
const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;
for (integer_t iy = 0; iy < roi_bin_grid_h; iy++) {
const T y = hstart +
static_cast<T>(iy + .5f) * bin_size_h /
static_cast<T>(roi_bin_grid_h);
for (integer_t ix = 0; ix < roi_bin_grid_w; ix++) {
const T x = wstart +
static_cast<T>(ix + .5f) * bin_size_w /
static_cast<T>(roi_bin_grid_w);
T w1, w2, w3, w4;
integer_t x_low, x_high, y_low, y_high;
bilinear_interpolate_gradient(
height,
width,
y,
x,
w1,
w2,
w3,
w4,
x_low,
x_high,
y_low,
y_high,
index);
T g1 = grad_output_this_bin * w1 / count;
T g2 = grad_output_this_bin * w2 / count;
T g3 = grad_output_this_bin * w3 / count;
T g4 = grad_output_this_bin * w4 / count;
if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomic_add_float(grad_input + offset + y_low * width + x_low, static_cast<T>(g1));
atomic_add_float(grad_input + offset + y_low * width + x_high, static_cast<T>(g2));
atomic_add_float(grad_input + offset + y_high * width + x_low, static_cast<T>(g3));
atomic_add_float(grad_input + offset + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
} // iy
}
}
#define REGISTER_PS_ROI_ALIGN_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_align_backward_" #DTYPE)]] \
kernel void ps_roi_align_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * channel_mapping [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & sampling_ratio [[buffer(10)]], \
constant int64_t & channels_out [[buffer(11)]], \
constant float & spatial_scale [[buffer(12)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_pool(
constant T * input [[buffer(0)]],
constant T * rois [[buffer(1)]],
device T * output [[buffer(2)]],
device int64_t * channel_mapping [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & channels_out [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, c_out, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t c_out = (index / (pooled_width * pooled_height)) % channels_out;
integer_t n = index / pooled_width / pooled_height / channels_out;
// (n, c_in, ph, pw) is the associated element in the input
integer_t c_in = (c_out * pooled_height + ph) * pooled_width + pw;
// [start, end) interval for spatial sampling
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height - 1));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width - 1));
bool is_empty = (hend <= hstart) || (wend <= wstart);
constant T* offset_input =
input + (roi_batch_ind * channels + c_in) * height * width;
T out_sum = 0;
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t input_index = h * width + w;
out_sum += offset_input[input_index];
}
}
T bin_area = (hend - hstart) * (wend - wstart);
output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
channel_mapping[index] = c_in;
}
}
#define REGISTER_PS_ROI_POOL_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_pool_" #DTYPE)]] \
kernel void ps_roi_pool<DTYPE, INT_DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
device DTYPE * output [[buffer(2)]], \
device int64_t * channel_mapping [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
template<typename T, typename integer_t>
kernel void ps_roi_pool_backward(
constant T * grad_output [[buffer(0)]],
constant T * rois [[buffer(1)]],
constant int64_t * channel_mapping [[buffer(2)]],
device T * grad_input [[buffer(3)]],
constant int64_t & output_size [[buffer(4)]],
constant int64_t & channels [[buffer(5)]],
constant int64_t & height [[buffer(6)]],
constant int64_t & width [[buffer(7)]],
constant int64_t & pooled_height [[buffer(8)]],
constant int64_t & pooled_width [[buffer(9)]],
constant int64_t & channels_out [[buffer(10)]],
constant float & spatial_scale [[buffer(11)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint2 tptg [[threads_per_threadgroup]],
uint2 tid2 [[thread_position_in_threadgroup]]){
MPS_1D_KERNEL_LOOP(index, output_size, 1) {
// (n, *, ph, pw) is an element in the pooled output
integer_t pw = index % pooled_width;
integer_t ph = (index / pooled_width) % pooled_height;
integer_t n = index / pooled_width / pooled_height / channels_out;
constant T* offset_rois = rois + n * 5;
integer_t roi_batch_ind = offset_rois[0];
integer_t roi_start_w = round(offset_rois[1] * spatial_scale);
integer_t roi_start_h = round(offset_rois[2] * spatial_scale);
integer_t roi_end_w = round(offset_rois[3] * spatial_scale);
integer_t roi_end_h = round(offset_rois[4] * spatial_scale);
// Force too small ROIs to be 1x1
integer_t roi_width = max(roi_end_w - roi_start_w, static_cast<integer_t>(1));
integer_t roi_height = max(roi_end_h - roi_start_h, static_cast<integer_t>(1));
T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
integer_t hstart = static_cast<integer_t>(floor(static_cast<T>(ph) * bin_size_h));
integer_t wstart = static_cast<integer_t>(floor(static_cast<T>(pw) * bin_size_w));
integer_t hend = static_cast<integer_t>(ceil(static_cast<T>(ph + 1) * bin_size_h));
integer_t wend = static_cast<integer_t>(ceil(static_cast<T>(pw + 1) * bin_size_w));
// Add roi offsets and clip to input boundaries
hstart = min(max(hstart + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
hend = min(max(hend + roi_start_h, static_cast<integer_t>(0)), static_cast<integer_t>(height));
wstart = min(max(wstart + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
wend = min(max(wend + roi_start_w, static_cast<integer_t>(0)), static_cast<integer_t>(width));
bool is_empty = (hend <= hstart) || (wend <= wstart);
integer_t c_in = channel_mapping[index];
T bin_area = (hend - hstart) * (wend - wstart);
T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
const integer_t offset = (roi_batch_ind * channels + c_in) * height * width;
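// Scatter gradients with atomic adds: bins from different (possibly
// overlapping) ROIs may write to the same grad_input element concurrently.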
for (integer_t h = hstart; h < hend; ++h) {
for (integer_t w = wstart; w < wend; ++w) {
integer_t grad_input_index = h * width + w;
atomic_add_float(grad_input + offset + grad_input_index, diff_val);
}
}
} // MPS_1D_KERNEL_LOOP
}
#define REGISTER_PS_ROI_POOL_BACKWARD_OP(DTYPE, INT_DTYPE) \
template \
[[host_name("ps_roi_pool_backward_" #DTYPE)]] \
kernel void ps_roi_pool_backward<DTYPE, INT_DTYPE>( \
constant DTYPE * grad_output [[buffer(0)]], \
constant DTYPE * rois [[buffer(1)]], \
constant int64_t * channel_mapping [[buffer(2)]], \
device DTYPE * grad_input [[buffer(3)]], \
constant int64_t & output_size [[buffer(4)]], \
constant int64_t & channels [[buffer(5)]], \
constant int64_t & height [[buffer(6)]], \
constant int64_t & width [[buffer(7)]], \
constant int64_t & pooled_height [[buffer(8)]], \
constant int64_t & pooled_width [[buffer(9)]], \
constant int64_t & channels_out [[buffer(10)]], \
constant float & spatial_scale [[buffer(11)]], \
uint2 tgid [[threadgroup_position_in_grid]], \
uint2 tptg [[threads_per_threadgroup]], \
uint2 tid2 [[thread_position_in_threadgroup]]);
REGISTER_NMS_OP(float);
REGISTER_NMS_OP(half);
REGISTER_ROI_ALIGN_OP(float, int64_t);
REGISTER_ROI_ALIGN_OP(half, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_ROI_POOL_OP(float, int64_t);
REGISTER_ROI_POOL_OP(half, int64_t);
REGISTER_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_ROI_POOL_BACKWARD_OP(half, int64_t);
REGISTER_PS_ROI_ALIGN_OP(float, int64_t);
REGISTER_PS_ROI_ALIGN_OP(half, int64_t);
REGISTER_PS_ROI_ALIGN_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_ALIGN_BACKWARD_OP(half, int64_t);
REGISTER_PS_ROI_POOL_OP(float, int64_t);
REGISTER_PS_ROI_POOL_OP(half, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(float, int64_t);
REGISTER_PS_ROI_POOL_BACKWARD_OP(half, int64_t);
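// Note: the backward kernels are also instantiated for half, but the host
// wrappers below currently reject float16 gradients.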
)VISION_METAL";
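// The Metal source above is compiled at runtime: compileVisionOpsLibrary()
// builds and caches a single MTLLibrary from the raw string, and
// visionPipelineState() caches one compute pipeline state per kernel name.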
static id<MTLLibrary> compileVisionOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> visionLibrary = nil;
if (visionLibrary) {
return visionLibrary;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
visionLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_VISION encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(visionLibrary, "Failed to create metal vision library, error: ", [[error description] UTF8String]);
return visionLibrary;
}
static id<MTLComputePipelineState> visionPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}
NSError* error = nil;
id<MTLLibrary> visionLib = compileVisionOpsLibrary(device);
id<MTLFunction> visionFunc = [visionLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(visionFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:visionFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
psoCache[kernel] = pso;
return pso;
}
} // namespace mps
} // namespace ops
} // namespace vision
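// Host-side NMS wrapper: sort boxes by score, launch the Metal kernel to build
// the pairwise suppression mask, then decode the kept indices on the CPU.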
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
// This should be in sync with `nmsThreadsPerBlock` in the metal kernel.
constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;
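// With 64-bit mask words, each threadgroup covers a 64x64 tile of box pairs
// (mirroring the CUDA NMS kernel), so N boxes need col_blocks = ceil(N / 64)
// mask words per box.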
at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) {
using namespace at::native::mps;
TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor");
TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor");
TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1));
TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
TORCH_CHECK(dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0))
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t).contiguous();
int64_t dets_num = dets.size(0);
float iou_threshold_f = static_cast<float>(iou_threshold);
const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock;
at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
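// Bit k of mask[i * col_blocks + j] is set when box i suppresses box
// j * 64 + k, i.e. their IoU exceeds the threshold.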
id<MTLBuffer> inputBuffer = getMTLBufferStorage(dets_sorted);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(mask);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1);
const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores});
[computeEncoder setComputePipelineState:visionPSO];
[computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0];
[computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1];
[computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2];
[computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > nmsThreadsPerBlock) {
tgSize = nmsThreadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
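// Greedy decode on the CPU: visit boxes in descending score order, keep a box
// only if no previously kept box suppresses it, and OR the kept box's
// suppression mask into remv so later boxes can be skipped cheaply.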
int64_t num_to_keep = 0;
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
for (int64_t i = 0; i < dets_num; i++) {
int64_t nblock = i / nmsThreadsPerBlock;
int64_t inblock = i % nmsThreadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int64_t j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())});
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}
} // namespace ops
} // namespace vision
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ps_roi_align_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
TORCH_CHECK(channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int64_t channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong));
int64_t output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
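// One thread per output element: ceil_div(output_size, 512) threadgroups,
// capped at 4096; the kernel's MPS_1D_KERNEL_LOOP grid-stride loop covers any
// remainder. The same sizing pattern is used by the other RoI wrappers below.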
const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs.");
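// The backward kernels accumulate into grad_input with float atomics
// (atomic_add_float), which is presumably why float16 gradients are rejected
// here instead of being dispatched to the half instantiation.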
TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor");
at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "ps_roi_align_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t output_size = grad.numel();
int64_t channels_out = channels / (pooled_height * pooled_width);
at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel));
}
} // namespace ops
} // namespace vision
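// The PS RoI pooling wrappers below follow the same structure: validate the
// MPS tensors, size the outputs, encode buffers and scalar arguments in the
// order expected by the kernel signature, and dispatch on the current MPS
// stream.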
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ps_roi_pool_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
TORCH_CHECK(channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int64_t channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong));
auto output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs.");
TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "ps_roi_pool_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t channels_out = channels / (pooled_height * pooled_width);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel");
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel));
}
} // namespace ops
} // namespace vision
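// RoI align wrappers: the forward pass is deterministic, while the backward
// pass accumulates gradients non-deterministically (hence the
// alertNotDeterministic() call and the Python-side fallback to _roi_align when
// deterministic algorithms are requested).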
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
at::Tensor roi_align_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_align_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
int64_t output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0) {
return output;
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return output;
}
at::Tensor roi_align_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs.");
at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_align_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t n_stride = grad.stride(0);
int64_t c_stride = grad.stride(1);
int64_t h_stride = grad.stride(2);
int64_t w_stride = grad.stride(3);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
[computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
[computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
[computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel));
}
} // namespace ops
} // namespace vision
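// RoI max-pooling wrappers: the forward kernel also records the argmax index
// of every pooled element so the backward pass can route each gradient back to
// the input location that produced the maximum.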
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_pool_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong));
int64_t output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0) {
return std::make_tuple(output, argmax);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, argmax);
}
at::Tensor roi_pool_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs.");
TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor");
at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3};
at::CheckedFrom c = "roi_pool_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t n_stride = grad.stride(0);
int64_t c_stride = grad.stride(1);
int64_t h_stride = grad.stride(2);
int64_t w_stride = grad.stride(3);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");
auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
[computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12];
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13];
[computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14];
// A threadgroup is equivalent to a CUDA block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel));
}
} // namespace ops
} // namespace vision
@@ -158,12 +158,12 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
     y = (
         from_K(roi_start_h)
         + ph[None, :, None] * from_K(bin_size_h)
-        + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
+        + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
     )  # [K, PH, IY]
     x = (
         from_K(roi_start_w)
         + pw[None, :, None] * from_K(bin_size_w)
-        + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
+        + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
     )  # [K, PW, IX]
     val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]
@@ -232,7 +232,7 @@ def roi_align(
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     if not torch.jit.is_scripting():
-        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda):
+        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
             return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
    _assert_has_ops()
    return torch.ops.torchvision.roi_align(
...
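For context, here is a minimal smoke-test sketch (not part of this commit) showing how the new MPS code paths can be exercised from Python. It assumes an MPS-enabled PyTorch/torchvision build on Apple silicon and otherwise falls back to CPU.

import torch
import torchvision

# Assumption: an Apple-silicon machine with an MPS-enabled PyTorch build.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# NMS: boxes are (x1, y1, x2, y2); the first two overlap heavily.
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0],
                      [20.0, 20.0, 30.0, 30.0]], device=device)
scores = torch.tensor([0.9, 0.8, 0.7], device=device)
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)  # box 1 is suppressed by box 0

# RoI align: rois are (batch_index, x1, y1, x2, y2).
feat = torch.rand(1, 8, 16, 16, device=device)
rois = torch.tensor([[0.0, 0.0, 0.0, 8.0, 8.0]], device=device)
pooled = torchvision.ops.roi_align(feat, rois, output_size=(4, 4),
                                   spatial_scale=1.0, sampling_ratio=2, aligned=True)
print(keep.tolist(), pooled.shape)  # [0, 2] torch.Size([1, 8, 4, 4])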