Unverified Commit 8a0b491e authored by Chunyuan WU's avatar Chunyuan WU Committed by GitHub
Browse files

Add autocast for nms, roi_align on CPU (#8049)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
Co-authored-by: default avatarNicolas Hug <nh.nicolas.hug@gmail.com>
parent b80bdb75
......@@ -129,8 +129,10 @@ def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "torchvision", "csrc")
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob(
os.path.join(extensions_dir, "ops", "*.cpp")
main_file = (
glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "*.cpp"))
+ glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
)
source_cpu = (
glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp"))
......@@ -184,8 +186,6 @@ def get_extensions():
else:
source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu"))
source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp"))
sources = main_file + source_cpu
extension = CppExtension
......
......@@ -131,6 +131,8 @@ class RoIOpTester(ABC):
tol = 5e-3
else:
tol = 4e-3
elif x_dtype == torch.bfloat16:
tol = 5e-3
pool_size = 5
# n_channels % (pool_size ** 2) == 0 required for PS operations.
......@@ -504,6 +506,21 @@ class TestRoIAlign(RoIOpTester):
rois_dtype=rois_dtype,
)
@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
@pytest.mark.parametrize("x_dtype", (torch.float, torch.bfloat16))
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.bfloat16))
def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
with torch.cpu.amp.autocast():
self.test_forward(
torch.device("cpu"),
contiguous=False,
deterministic=deterministic,
aligned=aligned,
x_dtype=x_dtype,
rois_dtype=rois_dtype,
)
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
......@@ -808,6 +825,15 @@ class TestNMS:
with torch.cuda.amp.autocast():
self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
@pytest.mark.parametrize("dtype", (torch.float, torch.bfloat16))
def test_autocast_cpu(self, iou, dtype):
boxes, scores = self._create_tensors_with_iou(1000, iou)
with torch.cpu.amp.autocast():
keep_ref_float = ops.nms(boxes.to(dtype).float(), scores.to(dtype).float(), iou)
keep_dtype = ops.nms(boxes.to(dtype), scores.to(dtype), iou)
torch.testing.assert_close(keep_ref_float, keep_dtype)
@pytest.mark.parametrize(
"device",
(
......
......@@ -9,21 +9,33 @@ namespace ops {
namespace {
template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
at::Tensor nms_autocast(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
return nms(
at::autocast::cached_cast(at::kFloat, dets),
at::autocast::cached_cast(at::kFloat, scores),
at::autocast::cached_cast(at::kFloat, dets, device_type),
at::autocast::cached_cast(at::kFloat, scores, device_type),
iou_threshold);
}
} // namespace
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_autocast));
m.impl(
TORCH_SELECTIVE_NAME("torchvision::nms"),
TORCH_FN(
(nms_autocast<c10::DispatchKey::Autocast, c10::DeviceType::CUDA>)));
}
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::nms"),
TORCH_FN(
(nms_autocast<c10::DispatchKey::AutocastCPU, c10::DeviceType::CPU>)));
}
} // namespace ops
......
......@@ -9,6 +9,7 @@ namespace ops {
namespace {
template <c10::DispatchKey autocast_key, c10::DeviceType device_type>
at::Tensor roi_align_autocast(
const at::Tensor& input,
const at::Tensor& rois,
......@@ -17,10 +18,10 @@ at::Tensor roi_align_autocast(
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
c10::impl::ExcludeDispatchKeyGuard no_autocast(autocast_key);
return roi_align(
at::autocast::cached_cast(at::kFloat, input),
at::autocast::cached_cast(at::kFloat, rois),
at::autocast::cached_cast(at::kFloat, input, device_type),
at::autocast::cached_cast(at::kFloat, rois, device_type),
spatial_scale,
pooled_height,
pooled_width,
......@@ -34,7 +35,17 @@ at::Tensor roi_align_autocast(
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN(roi_align_autocast));
TORCH_FN((roi_align_autocast<
c10::DispatchKey::Autocast,
c10::DeviceType::CUDA>)));
}
TORCH_LIBRARY_IMPL(torchvision, AutocastCPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::roi_align"),
TORCH_FN((roi_align_autocast<
c10::DispatchKey::AutocastCPU,
c10::DeviceType::CPU>)));
}
} // namespace ops
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment