Unverified Commit fc838add authored by Edward Z. Yang's avatar Edward Z. Yang Committed by GitHub
Browse files

Add deterministic, pure-Python roi_align implementation (#7587)


Signed-off-by: default avatarEdward Z. Yang <ezyang@meta.com>
parent a5579189
...@@ -19,6 +19,22 @@ from torchvision import models, ops ...@@ -19,6 +19,22 @@ from torchvision import models, ops
from torchvision.models.feature_extraction import get_graph_node_names from torchvision.models.feature_extraction import get_graph_node_names
# Context manager for setting deterministic flag and automatically
# resetting it to its original value
class DeterministicGuard:
def __init__(self, deterministic, *, warn_only=False):
self.deterministic = deterministic
self.warn_only = warn_only
def __enter__(self):
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
torch.use_deterministic_algorithms(self.deterministic, warn_only=self.warn_only)
def __exit__(self, exception_type, exception_value, traceback):
torch.use_deterministic_algorithms(self.deterministic_restore, warn_only=self.warn_only_restore)
class RoIOpTesterModuleWrapper(nn.Module): class RoIOpTesterModuleWrapper(nn.Module):
def __init__(self, obj): def __init__(self, obj):
super().__init__() super().__init__()
...@@ -83,7 +99,7 @@ class RoIOpTester(ABC): ...@@ -83,7 +99,7 @@ class RoIOpTester(ABC):
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs): def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
x_dtype = self.dtype if x_dtype is None else x_dtype x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype rois_dtype = self.dtype if rois_dtype is None else rois_dtype
pool_size = 5 pool_size = 5
...@@ -99,7 +115,8 @@ class RoIOpTester(ABC): ...@@ -99,7 +115,8 @@ class RoIOpTester(ABC):
) )
pool_h, pool_w = pool_size, pool_size pool_h, pool_w = pool_size, pool_size
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) with DeterministicGuard(deterministic):
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
# the following should be true whether we're running an autocast test or not. # the following should be true whether we're running an autocast test or not.
assert y.dtype == x.dtype assert y.dtype == x.dtype
gt_y = self.expected_fn( gt_y = self.expected_fn(
...@@ -140,7 +157,7 @@ class RoIOpTester(ABC): ...@@ -140,7 +157,7 @@ class RoIOpTester(ABC):
@pytest.mark.parametrize("seed", range(10)) @pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, seed, device, contiguous): def test_backward(self, seed, device, contiguous, deterministic=False):
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
pool_size = 2 pool_size = 2
x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
...@@ -155,7 +172,9 @@ class RoIOpTester(ABC): ...@@ -155,7 +172,9 @@ class RoIOpTester(ABC):
script_func = self.get_script_fn(rois, pool_size) script_func = self.get_script_fn(rois, pool_size)
gradcheck(func, (x,)) with DeterministicGuard(deterministic):
gradcheck(func, (x,))
gradcheck(script_func, (x,)) gradcheck(script_func, (x,))
@needs_cuda @needs_cuda
...@@ -384,7 +403,6 @@ class TestRoIAlign(RoIOpTester): ...@@ -384,7 +403,6 @@ class TestRoIAlign(RoIOpTester):
grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w)) grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(bin_w))
for channel in range(0, n_channels): for channel in range(0, n_channels):
val = 0 val = 0
for iy in range(0, grid_h): for iy in range(0, grid_h):
y = start_h + (iy + 0.5) * bin_h / grid_h y = start_h + (iy + 0.5) * bin_h / grid_h
...@@ -402,21 +420,44 @@ class TestRoIAlign(RoIOpTester): ...@@ -402,21 +420,44 @@ class TestRoIAlign(RoIOpTester):
@pytest.mark.parametrize("aligned", (True, False)) @pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("contiguous", (True, False))
def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None): @pytest.mark.parametrize("deterministic", (True, False))
def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
if deterministic and device == "cpu":
pytest.skip("cpu is always deterministic, don't retest")
super().test_forward( super().test_forward(
device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned device=device,
contiguous=contiguous,
deterministic=deterministic,
x_dtype=x_dtype,
rois_dtype=rois_dtype,
aligned=aligned,
) )
@needs_cuda @needs_cuda
@pytest.mark.parametrize("aligned", (True, False)) @pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
def test_autocast(self, aligned, x_dtype, rois_dtype): def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
self.test_forward( self.test_forward(
torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype torch.device("cuda"),
contiguous=False,
deterministic=deterministic,
aligned=aligned,
x_dtype=x_dtype,
rois_dtype=rois_dtype,
) )
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("deterministic", (True, False))
def test_backward(self, seed, device, contiguous, deterministic):
if deterministic and device == "cpu":
pytest.skip("cpu is always deterministic, don't retest")
super().test_backward(seed, device, contiguous, deterministic)
def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype) rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) # set batch index
...@@ -978,7 +1019,6 @@ class TestDeformConv: ...@@ -978,7 +1019,6 @@ class TestDeformConv:
weight = init_weight weight = init_weight
for d in ["cpu", "cuda"]: for d in ["cpu", "cuda"]:
out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d)) out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1, mask=mask.to(d))
out.mean().backward() out.mean().backward()
if true_cpu_grads is None: if true_cpu_grads is None:
...@@ -1374,7 +1414,6 @@ class TestGeneralizedBoxIouLoss: ...@@ -1374,7 +1414,6 @@ class TestGeneralizedBoxIouLoss:
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_giou_loss(self, dtype, device): def test_giou_loss(self, dtype, device):
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
# Identical boxes should have loss of 0 # Identical boxes should have loss of 0
......
from typing import List, Union from typing import List, Union
import torch import torch
import torch._dynamo
import torch.fx import torch.fx
from torch import nn, Tensor from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2 from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops from torchvision.extension import _assert_has_ops, _has_ops
from ..utils import _log_api_usage_once from ..utils import _log_api_usage_once
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
# NB: all inputs are tensors
def _bilinear_interpolate(
input, # [N, C, H, W]
roi_batch_ind, # [K]
y, # [K, PH, IY]
x, # [K, PW, IX]
ymask, # [K, IY]
xmask, # [K, IX]
):
_, channels, height, width = input.size()
# deal with inverse element out of feature map boundary
y = y.clamp(min=0)
x = x.clamp(min=0)
y_low = y.int()
x_low = x.int()
y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
y_low = torch.where(y_low >= height - 1, height - 1, y_low)
y = torch.where(y_low >= height - 1, y.to(input.dtype), y)
x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
x_low = torch.where(x_low >= width - 1, width - 1, x_low)
x = torch.where(x_low >= width - 1, x.to(input.dtype), x)
ly = y - y_low
lx = x - x_low
hy = 1.0 - ly
hx = 1.0 - lx
# do bilinear interpolation, but respect the masking!
# TODO: It's possible the masking here is unnecessary if y and
# x were clamped appropriately; hard to tell
def masked_index(
y, # [K, PH, IY]
x, # [K, PW, IX]
):
if ymask is not None:
assert xmask is not None
y = torch.where(ymask[:, None, :], y, 0)
x = torch.where(xmask[:, None, :], x, 0)
return input[
roi_batch_ind[:, None, None, None, None, None],
torch.arange(channels, device=input.device)[None, :, None, None, None, None],
y[:, None, :, None, :, None], # prev [K, PH, IY]
x[:, None, None, :, None, :], # prev [K, PW, IX]
] # [K, C, PH, PW, IY, IX]
v1 = masked_index(y_low, x_low)
v2 = masked_index(y_low, x_high)
v3 = masked_index(y_high, x_low)
v4 = masked_index(y_high, x_high)
# all ws preemptively [K, C, PH, PW, IY, IX]
def outer_prod(y, x):
return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]
w1 = outer_prod(hy, hx)
w2 = outer_prod(hy, lx)
w3 = outer_prod(ly, hx)
w4 = outer_prod(ly, lx)
val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
return val
# TODO: this doesn't actually cache
# TODO: main library should make this easier to do
def maybe_cast(tensor):
if torch.is_autocast_enabled() and tensor.is_cuda and tensor.dtype != torch.double:
return tensor.float()
else:
return tensor
# This is a slow but pure Python and differentiable implementation of
# roi_align. It potentially is a good basis for Inductor compilation
# (but I have not benchmarked it) but today it is solely used for the
# fact that its backwards can be implemented deterministically,
# which is needed for the PT2 benchmark suite.
#
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
orig_dtype = input.dtype
input = maybe_cast(input)
rois = maybe_cast(rois)
_, _, height, width = input.size()
ph = torch.arange(pooled_height, device=input.device) # [PH]
pw = torch.arange(pooled_width, device=input.device) # [PW]
# input: [N, C, H, W]
# rois: [K, 5]
roi_batch_ind = rois[:, 0].int() # [K]
offset = 0.5 if aligned else 0.0
roi_start_w = rois[:, 1] * spatial_scale - offset # [K]
roi_start_h = rois[:, 2] * spatial_scale - offset # [K]
roi_end_w = rois[:, 3] * spatial_scale - offset # [K]
roi_end_h = rois[:, 4] * spatial_scale - offset # [K]
roi_width = roi_end_w - roi_start_w # [K]
roi_height = roi_end_h - roi_start_h # [K]
if not aligned:
roi_width = torch.clamp(roi_width, min=1.0) # [K]
roi_height = torch.clamp(roi_height, min=1.0) # [K]
bin_size_h = roi_height / pooled_height # [K]
bin_size_w = roi_width / pooled_width # [K]
exact_sampling = sampling_ratio > 0
roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height) # scalar or [K]
roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width) # scalar or [K]
"""
iy, ix = dims(2)
"""
if exact_sampling:
count = max(roi_bin_grid_h * roi_bin_grid_w, 1) # scalar
iy = torch.arange(roi_bin_grid_h, device=input.device) # [IY]
ix = torch.arange(roi_bin_grid_w, device=input.device) # [IX]
ymask = None
xmask = None
else:
count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1) # [K]
# When doing adaptive sampling, the number of samples we need to do
# is data-dependent based on how big the ROIs are. This is a bit
# awkward because first-class dims can't actually handle this.
# So instead, we inefficiently suppose that we needed to sample ALL
# the points and mask out things that turned out to be unnecessary
iy = torch.arange(height, device=input.device) # [IY]
ix = torch.arange(width, device=input.device) # [IX]
ymask = iy[None, :] < roi_bin_grid_h[:, None] # [K, IY]
xmask = ix[None, :] < roi_bin_grid_w[:, None] # [K, IX]
def from_K(t):
return t[:, None, None]
y = (
from_K(roi_start_h)
+ ph[None, :, None] * from_K(bin_size_h)
+ (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
) # [K, PH, IY]
x = (
from_K(roi_start_w)
+ pw[None, :, None] * from_K(bin_size_w)
+ (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
) # [K, PW, IX]
val = _bilinear_interpolate(input, roi_batch_ind, y, x, ymask, xmask) # [K, C, PH, PW, IY, IX]
# Mask out samples that weren't actually adaptively needed
if not exact_sampling:
val = torch.where(ymask[:, None, None, None, :, None], val, 0)
val = torch.where(xmask[:, None, None, None, None, :], val, 0)
output = val.sum((-1, -2)) # remove IY, IX ~> [K, C, PH, PW]
if isinstance(count, torch.Tensor):
output /= count[:, None, None, None]
else:
output /= count
output = output.to(orig_dtype)
return output
@torch.fx.wrap @torch.fx.wrap
def roi_align( def roi_align(
input: Tensor, input: Tensor,
...@@ -54,12 +226,15 @@ def roi_align( ...@@ -54,12 +226,15 @@ def roi_align(
""" """
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(roi_align) _log_api_usage_once(roi_align)
_assert_has_ops()
check_roi_boxes_shape(boxes) check_roi_boxes_shape(boxes)
rois = boxes rois = boxes
output_size = _pair(output_size) output_size = _pair(output_size)
if not isinstance(rois, torch.Tensor): if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois) rois = convert_boxes_to_roi_format(rois)
if not torch.jit.is_scripting():
if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and input.is_cuda):
return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
_assert_has_ops()
return torch.ops.torchvision.roi_align( return torch.ops.torchvision.roi_align(
input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment