Unverified Commit 16d62e30 authored by Li-Huai (Allan) Lin, committed by GitHub

Add MPS kernels for nms and roi ops (#7643)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent f524cd3a
@@ -48,7 +48,7 @@ except ImportError:
 DEVNULL = open(os.devnull, "wb")
-DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu"
+DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu,mm"
 class ExitStatus:
@@ -4,6 +4,7 @@ set(CMAKE_CXX_STANDARD 17)
 file(STRINGS version.txt TORCHVISION_VERSION)
 option(WITH_CUDA "Enable CUDA support" OFF)
+option(WITH_MPS "Enable MPS support" OFF)
 option(WITH_PNG "Enable features requiring LibPNG." ON)
 option(WITH_JPEG "Enable features requiring LibJPEG." ON)
 option(USE_PYTHON "Link to Python when building" OFF)
@@ -15,6 +16,11 @@ if(WITH_CUDA)
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 endif()
+if(WITH_MPS)
+  enable_language(OBJC OBJCXX)
+  add_definitions(-DWITH_MPS)
+endif()
 find_package(Torch REQUIRED)
 if (WITH_PNG)
@@ -79,6 +85,9 @@ list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCP
 if(WITH_CUDA)
   list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
 endif()
+if(WITH_MPS)
+  list(APPEND ALLOW_LISTED ${TVCPP}/ops/mps)
+endif()
 FOREACH(DIR ${ALLOW_LISTED})
   file(GLOB ALL_SOURCES ${ALL_SOURCES} ${DIR}/*.*)
@@ -137,10 +137,13 @@ def get_extensions():
         + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp"))
         + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp"))
     )
+    source_mps = glob.glob(os.path.join(extensions_dir, "ops", "mps", "*.mm"))

     print("Compiling extensions with following flags:")
     force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
     print(f" FORCE_CUDA: {force_cuda}")
+    force_mps = os.getenv("FORCE_MPS", "0") == "1"
+    print(f" FORCE_MPS: {force_mps}")
     debug_mode = os.getenv("DEBUG", "0") == "1"
     print(f" DEBUG: {debug_mode}")
     use_png = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
@@ -202,6 +205,8 @@ def get_extensions():
         define_macros += [("WITH_HIP", None)]
         nvcc_flags = []
         extra_compile_args["nvcc"] = nvcc_flags
+    elif torch.backends.mps.is_available() or force_mps:
+        sources += source_mps
     if sys.platform == "win32":
         define_macros += [("torchvision_EXPORTS", None)]
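Note: the hunk above gates compilation of the ops/mps/*.mm sources on MPS availability or the FORCE_MPS environment variable. A minimal sketch of that gating logic, assuming a PyTorch build with the MPS backend; the helper name should_build_mps_sources is hypothetical and only for illustration.

import os

import torch


def should_build_mps_sources() -> bool:
    # Hypothetical helper mirroring the setup.py logic above: compile the
    # ops/mps/*.mm sources when an MPS device is available, or when
    # FORCE_MPS=1 explicitly requests it.
    force_mps = os.getenv("FORCE_MPS", "0") == "1"
    return torch.backends.mps.is_available() or force_mps


if __name__ == "__main__":
    print(f"build MPS sources: {should_build_mps_sources()}")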
@@ -34,6 +34,7 @@ IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS
 IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
 IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
 CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
+MPS_NOT_AVAILABLE_MSG = "MPS device not available"
 OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
@@ -130,12 +131,22 @@ def cpu_and_cuda():
     return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))


+def cpu_and_cuda_and_mps():
+    return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)
+
+
 def needs_cuda(test_func):
     import pytest  # noqa

     return pytest.mark.needs_cuda(test_func)


+def needs_mps(test_func):
+    import pytest  # noqa
+
+    return pytest.mark.needs_mps(test_func)
+
+
 def _create_data(height=3, width=3, channels=3, device="cpu"):
     # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture
     tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
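For illustration, the new helpers are meant to be consumed from tests such as test_ops.py together with the marker registered in the conftest.py hunk that follows. The test bodies below are a hypothetical sketch, not part of this commit.

import pytest
import torch

from common_utils import cpu_and_cuda_and_mps, needs_mps


@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
def test_runs_on_all_devices(device):
    # The "mps" instance carries the needs_mps mark, so the conftest hook
    # skips it (or drops it from collection in fbcode) when MPS is unavailable.
    x = torch.ones(2, 2, device=device)
    assert x.sum().item() == 4


@needs_mps
def test_mps_only():
    assert torch.backends.mps.is_available()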
@@ -8,12 +8,20 @@ import torchvision
 torchvision.disable_beta_transforms_warning()

-from common_utils import CUDA_NOT_AVAILABLE_MSG, IN_FBCODE, IN_OSS_CI, IN_RE_WORKER, OSS_CI_GPU_NO_CUDA_MSG
+from common_utils import (
+    CUDA_NOT_AVAILABLE_MSG,
+    IN_FBCODE,
+    IN_OSS_CI,
+    IN_RE_WORKER,
+    MPS_NOT_AVAILABLE_MSG,
+    OSS_CI_GPU_NO_CUDA_MSG,
+)


 def pytest_configure(config):
     # register an additional marker (see pytest_collection_modifyitems)
     config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device")
+    config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device")
     config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected")
@@ -37,12 +45,16 @@ def pytest_collection_modifyitems(items):
         # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark,
         # and the ones with device == 'cpu' won't have the mark.
         needs_cuda = item.get_closest_marker("needs_cuda") is not None
+        needs_mps = item.get_closest_marker("needs_mps") is not None

         if needs_cuda and not torch.cuda.is_available():
             # In general, we skip cuda tests on machines without a GPU
             # There are special cases though, see below
             item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG))
+        if needs_mps and not torch.backends.mps.is_available():
+            item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG))

         if IN_FBCODE:
             # fbcode doesn't like skipping tests, so instead we just don't collect the test
             # so that they don't even "exist", hence the continue statements.
@@ -54,6 +66,9 @@ def pytest_collection_modifyitems(items):
                 # TODO: something more robust would be to do that only in a sandcastle instance,
                 # so that we can still see the test being skipped when testing locally from a devvm
                 continue
+            if needs_mps and not torch.backends.mps.is_available():
+                # Same as above, but for MPS
+                continue
         elif IN_OSS_CI:
             # Here we're not in fbcode, so we can safely collect and skip tests.
             if not needs_cuda and torch.cuda.is_available():
@@ -10,7 +10,7 @@ import pytest
 import torch
 import torch.fx
 import torch.nn.functional as F
-from common_utils import assert_equal, cpu_and_cuda, needs_cuda
+from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
 from torch.autograd import gradcheck
@@ -96,12 +96,33 @@ class PoolWrapper(nn.Module):
 class RoIOpTester(ABC):
     dtype = torch.float64
+    mps_dtype = torch.float32
+    mps_backward_atol = 2e-2

-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
-        x_dtype = self.dtype if x_dtype is None else x_dtype
-        rois_dtype = self.dtype if rois_dtype is None else rois_dtype
+    @pytest.mark.parametrize(
+        "x_dtype",
+        (
+            torch.float16,
+            torch.float32,
+            torch.float64,
+        ),
+        ids=str,
+    )
+    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
+        if device == "mps" and x_dtype is torch.float64:
+            pytest.skip("MPS does not support float64")
+
+        rois_dtype = x_dtype if rois_dtype is None else rois_dtype
+
+        tol = 1e-5
+        if x_dtype is torch.half:
+            if device == "mps":
+                tol = 5e-3
+            else:
+                tol = 4e-3
+
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
         n_channels = 2 * (pool_size**2)
@@ -120,10 +141,9 @@ class RoIOpTester(ABC):
         # the following should be true whether we're running an autocast test or not.
         assert y.dtype == x.dtype
         gt_y = self.expected_fn(
-            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs
+            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
         )
-        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
         torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)

     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -155,16 +175,19 @@ class RoIOpTester(ABC):
         torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)

     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic=False):
+        atol = self.mps_backward_atol if device == "mps" else 1e-05
+        dtype = self.mps_dtype if device == "mps" else self.dtype
+
         torch.random.manual_seed(seed)
         pool_size = 2
-        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
         if not contiguous:
             x = x.permute(0, 1, 3, 2)
         rois = torch.tensor(
-            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device  # format is (xyxy)
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
         )

         def func(z):
@@ -173,9 +196,25 @@ class RoIOpTester(ABC):
         script_func = self.get_script_fn(rois, pool_size)

         with DeterministicGuard(deterministic):
-            gradcheck(func, (x,))
+            gradcheck(func, (x,), atol=atol)

-        gradcheck(script_func, (x,))
+        gradcheck(script_func, (x,), atol=atol)
+
+    @needs_mps
+    def test_mps_error_inputs(self):
+        pool_size = 2
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
+        rois = torch.tensor(
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
+        )
+
+        def func(z):
+            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
+
+        with pytest.raises(
+            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
+        ):
+            gradcheck(func, (x,))

     @needs_cuda
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@@ -271,6 +310,8 @@ class TestRoiPool(RoIOpTester):
 class TestPSRoIPool(RoIOpTester):
+    mps_backward_atol = 5e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
@@ -352,6 +393,8 @@ def bilinear_interpolate(data, y, x, snap_border=False):
 class TestRoIAlign(RoIOpTester):
+    mps_backward_atol = 6e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
         return ops.RoIAlign(
             (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
@@ -418,10 +461,11 @@ class TestRoIAlign(RoIOpTester):
         self._helper_boxes_shape(ops.roi_align)

     @pytest.mark.parametrize("aligned", (True, False))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
+    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
-    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
+    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
         super().test_forward(
@@ -450,7 +494,7 @@ class TestRoIAlign(RoIOpTester):
         )

     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic):
@@ -537,6 +581,8 @@ class TestRoIAlign(RoIOpTester):
 class TestPSRoIAlign(RoIOpTester):
+    mps_backward_atol = 5e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
@@ -705,21 +751,28 @@ class TestNMS:
         torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))

-    @needs_cuda
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
-    def test_nms_cuda(self, iou, dtype=torch.float64):
+    def test_nms_gpu(self, iou, device, dtype=torch.float64):
+        dtype = torch.float32 if device == "mps" else dtype
         tol = 1e-3 if dtype is torch.half else 1e-5
         err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
         boxes, scores = self._create_tensors_with_iou(1000, iou)
         r_cpu = ops.nms(boxes, scores, iou)
-        r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
+        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)

-        is_eq = torch.allclose(r_cpu, r_cuda.cpu())
+        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
         if not is_eq:
             # if the indices are not the same, ensure that it's because the scores
             # are duplicate
-            is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
+            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)

     @needs_cuda
@@ -727,18 +780,24 @@ class TestNMS:
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, iou, dtype):
         with torch.cuda.amp.autocast():
-            self.test_nms_cuda(iou=iou, dtype=dtype)
+            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")

-    @needs_cuda
-    def test_nms_cuda_float16(self):
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
+    def test_nms_float16(self, device):
         boxes = torch.tensor(
             [
                 [285.3538, 185.5758, 1193.5110, 851.4551],
                 [285.1472, 188.7374, 1192.4984, 851.0669],
                 [279.2440, 197.9812, 1189.4746, 849.2019],
             ]
-        ).cuda()
-        scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
+        ).to(device)
+        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)
         iou_thres = 0.2
         keep32 = ops.nms(boxes, scores, iou_thres)
@@ -11,8 +11,8 @@ at::Tensor nms_kernel_impl(
     const at::Tensor& dets,
     const at::Tensor& scores,
     double iou_threshold) {
-  TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor");
-  TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor");
+  TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor");
+  TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor");
   TORCH_CHECK(
       dets.scalar_type() == scores.scalar_type(),
       "dets should have the same type as scores");
constexpr int threadsPerBlock = 512;
template <typename T>
constexpr inline T ceil_div(T n, T m) {
return (n + m - 1) / m;
}
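These helpers feed the dispatch sizing used by the Metal host code below: launches use at most 512 threads per threadgroup and cap the grid at 4096 threadgroups, with the kernels assumed to loop over any remaining elements. A small Python sketch of that sizing rule; the function names here are illustrative only.

def ceil_div(n: int, m: int) -> int:
    # Same rounding-up division as the C++ ceil_div helper above.
    return (n + m - 1) // m


def mps_launch_dims(output_size: int, threads_per_block: int = 512, max_groups: int = 4096):
    # Mirrors the MTLSizeMake(...) computation in the ROI kernels below:
    # one threadgroup per threads_per_block output elements, capped at max_groups.
    n_groups = min(ceil_div(output_size, threads_per_block), max_groups)
    return n_groups, threads_per_block


print(mps_launch_dims(10_000))  # (20, 512)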
(The diff of one added file is collapsed and not shown here.)
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
// This should be in sync with `nmsThreadsPerBlock` in the metal kernel.
constexpr int64_t nmsThreadsPerBlock = sizeof(uint64_t) * 8;
at::Tensor nms_kernel(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) {
using namespace at::native::mps;
TORCH_CHECK(dets.is_mps(), "dets must be a MPS tensor");
TORCH_CHECK(scores.is_mps(), "scores must be a MPS tensor");
TORCH_CHECK(dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(dets.size(1) == 4, "boxes should have 4 elements in dimension 1, got ", dets.size(1));
TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", scores.dim(), "D");
TORCH_CHECK(dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0))
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t).contiguous();
int64_t dets_num = dets.size(0);
float iou_threshold_f = static_cast<float>(iou_threshold);
const int col_blocks = (dets_num + nmsThreadsPerBlock - 1) / nmsThreadsPerBlock;
at::Tensor mask = at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
id<MTLBuffer> inputBuffer = getMTLBufferStorage(dets_sorted);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(mask);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(col_blocks, col_blocks, 1);
const std::string kernel = "nms_" + scalarToMetalTypeString(dets_sorted.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {dets, scores});
[computeEncoder setComputePipelineState:visionPSO];
[computeEncoder setBuffer:inputBuffer offset:dets_sorted.storage_offset() * dets_sorted.element_size() atIndex:0];
[computeEncoder setBuffer:outputBuffer offset:mask.storage_offset() * mask.element_size() atIndex:1];
[computeEncoder setBytes:&dets_num length:sizeof(int64_t) atIndex:2];
[computeEncoder setBytes:&iou_threshold_f length:sizeof(float) atIndex:3];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > nmsThreadsPerBlock) {
tgSize = nmsThreadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
int64_t num_to_keep = 0;
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep = at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
for (int64_t i = 0; i < dets_num; i++) {
int64_t nblock = i / nmsThreadsPerBlock;
int64_t inblock = i % nmsThreadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int64_t j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(order_t.device(), keep.scalar_type())});
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
}
} // namespace ops
} // namespace vision
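With this registration, torchvision.ops.nms dispatches to the Metal kernel above for MPS tensors. A minimal usage sketch, assuming a machine where torch.backends.mps.is_available() is true and a torchvision build that includes these kernels:

import torch
import torchvision.ops as ops

device = "mps"
boxes = torch.tensor(
    [[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [20.0, 20.0, 30.0, 30.0]],
    device=device,
)
scores = torch.tensor([0.9, 0.8, 0.7], device=device)

# Indices of the kept boxes, sorted by decreasing score, as an int64 tensor on MPS.
keep = ops.nms(boxes, scores, iou_threshold=0.5)
print(keep.cpu())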
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ps_roi_align_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
TORCH_CHECK(channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int64_t channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong));
int64_t output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor ps_roi_align_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_align backward with float16 inputs.");
TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor");
at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "ps_roi_align_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t output_size = grad.numel();
int64_t channels_out = channels / (pooled_height * pooled_width);
at::globalContext().alertNotDeterministic("ps_roi_align_backward_kernel");
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:12];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_align"), TORCH_FN(ps_roi_align_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_align_backward"), TORCH_FN(ps_roi_align_backward_kernel));
}
} // namespace ops
} // namespace vision
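A usage sketch for the forward kernel above, under the same MPS-build assumption. The channel check is the important constraint: channels must be a multiple of pooled_height * pooled_width, and the output has channels_out = channels / (pooled_height * pooled_width).

import torch
import torchvision.ops as ops

device = "mps"
# 8 channels with a 2x2 output satisfy channels % (pooled_h * pooled_w) == 0,
# giving channels_out = 2.
x = torch.rand(1, 8, 16, 16, device=device)
rois = torch.tensor([[0.0, 0.0, 0.0, 15.0, 15.0]], device=device)  # (batch_idx, x1, y1, x2, y2)

out = ops.ps_roi_align(x, rois, output_size=(2, 2), spatial_scale=1.0, sampling_ratio=2)
print(out.shape)  # torch.Size([1, 2, 2, 2])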
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "ps_roi_pool_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
TORCH_CHECK(channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width");
int64_t channels_out = channels / (pooled_height * pooled_width);
auto output = at::zeros({num_rois, channels_out, pooled_height, pooled_width}, input.options());
auto channel_mapping = at::zeros(output.sizes(), input.options().dtype(at::kLong));
auto output_size = output.numel();
if (output_size == 0) {
return std::make_tuple(output, channel_mapping);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_pool_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, channel_mapping);
}
at::Tensor ps_roi_pool_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support ps_roi_pool backward with float16 inputs.");
TORCH_CHECK(channel_mapping.is_mps(), "channel_mapping must be a MPS tensor");
at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
channel_mapping_t{channel_mapping, "channel_mapping", 3};
at::CheckedFrom c = "ps_roi_pool_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
auto num_rois = rois.size(0);
auto grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t channels_out = channels / (pooled_height * pooled_width);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("ps_roi_pool_backward_kernel");
auto grad_ = grad.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> channelMappingBuffer = getMTLBufferStorage(channel_mapping);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "ps_roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad_, rois_, channel_mapping});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad_.storage_offset() * grad_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:channelMappingBuffer
offset:channel_mapping.storage_offset() * channel_mapping.element_size()
atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&channels_out length:sizeof(int64_t) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::ps_roi_pool"), TORCH_FN(ps_roi_pool_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_ps_roi_pool_backward"), TORCH_FN(ps_roi_pool_backward_kernel));
}
} // namespace ops
} // namespace vision
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
at::Tensor roi_align_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_align_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
int64_t output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0) {
return output;
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_align_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return output;
}
at::Tensor roi_align_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width,
int64_t sampling_ratio,
bool aligned) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_align backward with float16 inputs.");
at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_align_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t n_stride = grad.stride(0);
int64_t c_stride = grad.stride(1);
int64_t h_stride = grad.stride(2);
int64_t w_stride = grad.stride(3);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("roi_align_backward_kernel");
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_align_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:2];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:3];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&sampling_ratio length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&aligned length:sizeof(bool) atIndex:10];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:11];
[computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:12];
[computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:13];
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:14];
[computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:15];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_align"), TORCH_FN(roi_align_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_align_backward"), TORCH_FN(roi_align_backward_kernel));
}
} // namespace ops
} // namespace vision
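A usage sketch for roi_align on MPS, again assuming an MPS-enabled build. float64 is not supported by the MPS backend and the backward kernel above rejects float16, so float32 is the practical dtype when gradients are needed:

import torch
import torchvision.ops as ops

device = "mps"
x = torch.rand(1, 3, 32, 32, device=device, dtype=torch.float32, requires_grad=True)
boxes = torch.tensor([[0.0, 4.0, 4.0, 20.0, 20.0]], device=device)  # (batch_idx, x1, y1, x2, y2)

y = ops.roi_align(x, boxes, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2, aligned=True)
print(y.shape)  # torch.Size([1, 3, 7, 7])

# The backward pass runs the (non-deterministic) roi_align_backward_kernel above.
y.sum().backward()
print(x.grad.shape)  # torch.Size([1, 3, 32, 32])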
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_helpers.h"
#include "mps_kernels.h"
namespace vision {
namespace ops {
namespace {
std::tuple<at::Tensor, at::Tensor> roi_pool_forward_kernel(const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width) {
using namespace at::native::mps;
TORCH_CHECK(input.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(rois.size(1) == 5, "rois must have shape as Tensor[K, 5]");
at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
at::CheckedFrom c = "roi_pool_forward_kernel";
at::checkAllSameGPU(c, {input_t, rois_t});
at::checkAllSameType(c, {input_t, rois_t});
int64_t num_rois = rois.size(0);
int64_t channels = input.size(1);
int64_t height = input.size(2);
int64_t width = input.size(3);
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor output = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options());
at::Tensor argmax = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kLong));
int64_t output_size = num_rois * pooled_height * pooled_width * channels;
if (output.numel() == 0) {
return std::make_tuple(output, argmax);
}
auto input_ = input.contiguous();
auto rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(output_size), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_pool_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {input_, rois_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:input_.storage_offset() * input_.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:output.storage_offset() * output.element_size() atIndex:2];
[computeEncoder setBuffer:argmaxBuffer offset:argmax.storage_offset() * argmax.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return std::make_tuple(output, argmax);
}
at::Tensor roi_pool_backward_kernel(const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& argmax,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
using namespace at::native::mps;
TORCH_CHECK(grad.is_mps(), "grad must be a MPS tensor");
TORCH_CHECK(rois.is_mps(), "rois must be a MPS tensor");
TORCH_CHECK(grad.scalar_type() != at::kHalf, "MPS does not support roi_pool backward with float16 inputs.");
TORCH_CHECK(argmax.is_mps(), "argmax must be a MPS tensor");
at::TensorArg grad_t{grad, "input", 1}, rois_t{rois, "rois", 2}, argmax_t{argmax, "argmax", 3};
at::CheckedFrom c = "roi_pool_backward_kernel";
at::checkAllSameGPU(c, {grad_t, rois_t, argmax_t});
at::checkAllSameType(c, {grad_t, rois_t});
float spatial_scale_f = static_cast<float>(spatial_scale);
at::Tensor grad_input = at::zeros({batch_size, channels, height, width}, grad.options());
if (grad.numel() == 0) {
return grad_input;
}
int64_t n_stride = grad.stride(0);
int64_t c_stride = grad.stride(1);
int64_t h_stride = grad.stride(2);
int64_t w_stride = grad.stride(3);
int64_t output_size = grad.numel();
at::globalContext().alertNotDeterministic("roi_pool_backward_kernel");
auto argmax_ = argmax.contiguous(), rois_ = rois.contiguous();
id<MTLBuffer> inputBuffer = getMTLBufferStorage(grad);
id<MTLBuffer> roisBuffer = getMTLBufferStorage(rois_);
id<MTLBuffer> argmaxBuffer = getMTLBufferStorage(argmax_);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(grad_input);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize threadgroupsPerGrid = MTLSizeMake(
std::min(ceil_div(static_cast<int64_t>(grad.numel()), static_cast<int64_t>(512)), static_cast<int64_t>(4096)),
1,
1);
const std::string kernel = "roi_pool_backward_" + scalarToMetalTypeString(grad.scalar_type());
id<MTLComputePipelineState> visionPSO = mps::visionPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(visionPSO, kernel, {grad, rois_, argmax_});
[computeEncoder setComputePipelineState:visionPSO];
// [N, C, H, W]
[computeEncoder setBuffer:inputBuffer offset:grad.storage_offset() * grad.element_size() atIndex:0];
[computeEncoder setBuffer:roisBuffer offset:rois_.storage_offset() * rois_.element_size() atIndex:1];
[computeEncoder setBuffer:argmaxBuffer offset:argmax_.storage_offset() * argmax_.element_size() atIndex:2];
[computeEncoder setBuffer:outputBuffer offset:grad_input.storage_offset() * grad_input.element_size() atIndex:3];
[computeEncoder setBytes:&output_size length:sizeof(int64_t) atIndex:4];
[computeEncoder setBytes:&channels length:sizeof(int64_t) atIndex:5];
[computeEncoder setBytes:&height length:sizeof(int64_t) atIndex:6];
[computeEncoder setBytes:&width length:sizeof(int64_t) atIndex:7];
[computeEncoder setBytes:&pooled_height length:sizeof(int64_t) atIndex:8];
[computeEncoder setBytes:&pooled_width length:sizeof(int64_t) atIndex:9];
[computeEncoder setBytes:&spatial_scale_f length:sizeof(float) atIndex:10];
[computeEncoder setBytes:&n_stride length:sizeof(int64_t) atIndex:11];
[computeEncoder setBytes:&c_stride length:sizeof(int64_t) atIndex:12];
[computeEncoder setBytes:&h_stride length:sizeof(int64_t) atIndex:13];
[computeEncoder setBytes:&w_stride length:sizeof(int64_t) atIndex:14];
// A threadGroup is equivalent to a cuda's block.
NSUInteger tgSize = visionPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > threadsPerBlock) {
tgSize = threadsPerBlock;
}
MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
getMPSProfiler().endProfileKernel(visionPSO);
}
});
return grad_input;
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::roi_pool"), TORCH_FN(roi_pool_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("torchvision::_roi_pool_backward"), TORCH_FN(roi_pool_backward_kernel));
}
} // namespace ops
} // namespace vision
@@ -158,12 +158,12 @@ def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling
     y = (
         from_K(roi_start_h)
         + ph[None, :, None] * from_K(bin_size_h)
-        + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
+        + (iy[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_h / roi_bin_grid_h)
     )  # [K, PH, IY]
     x = (
         from_K(roi_start_w)
         + pw[None, :, None] * from_K(bin_size_w)
-        + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
+        + (ix[None, None, :] + 0.5).to(input.dtype) * from_K(bin_size_w / roi_bin_grid_w)
     )  # [K, PW, IX]

     val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]
@@ -232,7 +232,7 @@ def roi_align(
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     if not torch.jit.is_scripting():
-        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda):
+        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
            return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
     _assert_has_ops()
     return torch.ops.torchvision.roi_align(
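The dispatch change above means that when deterministic algorithms are requested, MPS inputs now take the pure-PyTorch _roi_align path (whose intermediate casts are fixed in the first hunk) instead of the non-deterministic Metal backward kernel. A hedged sketch of the two paths, assuming an MPS device:

import torch
import torchvision.ops as ops

x = torch.rand(1, 3, 32, 32, device="mps")
boxes = torch.tensor([[0.0, 0.0, 0.0, 16.0, 16.0]], device="mps")

# Deterministic mode: the Python _roi_align implementation is used for MPS
# tensors, mirroring the existing CUDA behavior.
torch.use_deterministic_algorithms(True)
y_det = ops.roi_align(x, boxes, output_size=(7, 7))

# Default mode: the Metal roi_align kernel registered above is used.
torch.use_deterministic_algorithms(False)
y_kernel = ops.roi_align(x, boxes, output_size=(7, 7))

# The two paths should agree closely (the tests above use tolerances up to ~5e-3).
torch.testing.assert_close(y_det, y_kernel, rtol=1e-4, atol=1e-4)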