Unverified commit 8a0b491e, authored by Chunyuan WU and committed by GitHub

Add autocast for nms, roi_align on CPU (#8049)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent b80bdb75
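In short: torchvision::nms and torchvision::roi_align gain AutocastCPU registrations, so they can be called under torch.cpu.amp.autocast() with bfloat16 inputs and the wrappers cast to float32 internally, mirroring the existing CUDA autocast behavior. A minimal usage sketch (tensor values are illustrative, not from this commit):

import torch
import torchvision.ops as ops

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])

# With this commit, bfloat16 inputs under CPU autocast are routed through
# the new AutocastCPU wrapper, which casts them to float32 before dispatch.
with torch.cpu.amp.autocast():
    keep = ops.nms(boxes.bfloat16(), scores.bfloat16(), iou_threshold=0.5)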
setup.py
@@ -129,8 +129,10 @@ def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
     extensions_dir = os.path.join(this_dir, "torchvision", "csrc")
 
-    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
-        os.path.join(extensions_dir, "ops", "*.cpp")
+    main_file = (
+        glob.glob(os.path.join(extensions_dir, "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
+        + glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
     )
     source_cpu = (
         glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
@@ -184,8 +186,6 @@ def get_extensions():
     else:
         source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))
 
-        source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
-
     sources = main_file + source_cpu
     extension = CppExtension
...
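The two setup.py hunks above move the ops/autocast/*.cpp sources out of the CUDA-only source_cuda list and into main_file, so the autocast wrappers are now compiled unconditionally. This is what makes the AutocastCPU registrations further down available in CPU-only builds.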
test/test_ops.py
@@ -131,6 +131,8 @@ class RoIOpTester(ABC):
                 tol = 5e-3
             else:
                 tol = 4e-3
+        elif x_dtype == torch.bfloat16:
+            tol = 5e-3
 
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
@@ -504,6 +506,21 @@ class TestRoIAlign(RoIOpTester):
                 rois_dtype=rois_dtype,
             )
 
+    @pytest.mark.parametrize("aligned", (True, False))
+    @pytest.mark.parametrize("deterministic", (True, False))
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
+    def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
+        with torch.cpu.amp.autocast():
+            self.test_forward(
+                torch.device("cpu"),
+                contiguous=False,
+                deterministic=deterministic,
+                aligned=aligned,
+                x_dtype=x_dtype,
+                rois_dtype=rois_dtype,
+            )
+
     @pytest.mark.parametrize("seed", range(10))
     @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
@@ -808,6 +825,15 @@ class TestNMS:
         with torch.cuda.amp.autocast():
            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
 
+    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
+    @pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
+    def test_autocast_cpu(self, iou, dtype):
+        boxes, scores = self._create_tensors_with_iou(1000, iou)
+        with torch.cpu.amp.autocast():
+            keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
+            keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
+        torch.testing.assert_close(keep_ref_float, keep_dtype)
+
     @pytest.mark.parametrize(
         "device",
         (
...
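A note on the NMS test's reference value: keep_ref_float rounds the inputs through the test dtype before casting back to float32, so both nms calls see identically quantized boxes and scores. That isolates the autocast cast-to-float path from bfloat16 rounding error; comparing against nms on the original float32 tensors would not be deterministic for boxes whose IoU sits near the threshold.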
torchvision/csrc/ops/autocast/nms_kernel.cpp
@@ -9,21 +9,33 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor nms_autocast(
     const at::Tensor& dets,
     const at::Tensor& scores,
     double iou_threshold) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   return nms(
-      at::autocast::cached_cast(at::kFloat, dets),
-      at::autocast::cached_cast(at::kFloat, scores),
+      at::autocast::cached_cast(at::kFloat, dets, device_type),
+      at::autocast::cached_cast(at::kFloat, scores, device_type),
       iou_threshold);
 }
 
 } // namespace
 
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
-  m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::nms"),
+      TORCH_FN(
+          (nms_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::nms"),
+      TORCH_FN(
+          (nms_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
 }
 
 } // namespace ops
...
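One way to eyeball the new registration from Python is a dispatch-table dump. This sketch uses torch._C._dispatch_dump, a private PyTorch introspection hook whose output format may change between releases, and assumes torchvision was built from this commit:

import torch
import torchvision  # noqa: F401  -- loads the compiled torchvision ops library

# With this commit, the dump should list both Autocast (CUDA) and
# AutocastCPU entries for torchvision::nms.
print(torch._C._dispatch_dump("torchvision::nms"))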
torchvision/csrc/ops/autocast/roi_align_kernel.cpp
@@ -9,6 +9,7 @@ namespace ops {
 
 namespace {
 
+template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
 at::Tensor roi_align_autocast(
     const at::Tensor& input,
     const at::Tensor& rois,
@@ -17,10 +18,10 @@ at::Tensor roi_align_autocast(
     int64_t pooled_width,
     int64_t sampling_ratio,
     bool aligned) {
-  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
+  c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
   return roi_align(
-      at::autocast::cached_cast(at::kFloat, input),
-      at::autocast::cached_cast(at::kFloat, rois),
+      at::autocast::cached_cast(at::kFloat, input, device_type),
+      at::autocast::cached_cast(at::kFloat, rois, device_type),
       spatial_scale,
       pooled_height,
       pooled_width,
@@ -34,7 +35,17 @@ at::Tensor roi_align_autocast(
 TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::roi_align"),
-      TORCH_FN(roi_align_autocast));
+      TORCH_FN((roi_align_autocast<
+          c10::DispatchKey::Autocast,
+          c10::DeviceType::CUDA>)));
+}
+
+TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::roi_align"),
+      TORCH_FN((roi_align_autocast<
+          c10::DispatchKey::AutocastCPU,
+          c10::DeviceType::CPU>)));
 }
 
 } // namespace ops
...
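And the analogous roi_align call on CPU, again as a hedged sketch (shapes and values are illustrative). Because the wrapper casts both inputs to float32 before dispatching to the real kernel, the output comes back as float32 even for bfloat16 inputs:

import torch
import torchvision.ops as ops

x = torch.rand(2, 3, 32, 32)
rois = torch.tensor([[0.0, 0.0, 0.0, 16.0, 16.0],
                     [1.0, 4.0, 4.0, 24.0, 24.0]])  # (batch_idx, x1, y1, x2, y2)

with torch.cpu.amp.autocast():
    out = ops.roi_align(x.bfloat16(), rois.bfloat16(), output_size=(5, 5), spatial_scale=1.0)
assert out.dtype == torch.float32  # cached_cast(at::kFloat, ...) ran inside the wrapper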