Unverified Commit a5f531af authored by Edward Z. Yang, committed by GitHub

Force use of torch.compile on deterministic roi_align implementation (#8436)


Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
parent 775dd2d8
@@ -14,6 +14,7 @@ import torch.testing._internal.optests as optests
 from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
 from torchvision import models, ops
@@ -529,6 +530,10 @@ class TestRoIAlign(RoIOpTester):
     def test_backward(self, seed, device, contiguous, deterministic):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
+        if deterministic and device == "mps":
+            pytest.skip("no deterministic implementation for mps")
+        if deterministic and not is_compile_supported(device):
+            pytest.skip("deterministic implementation only if torch.compile supported")
         super().test_backward(seed, device, contiguous, deterministic)

     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
...
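For context, here is a minimal sketch (not part of the diff) of the behavior these new skips gate; it assumes a CUDA device and a PyTorch build where torch.compile is supported, and the tensor shapes and box values are made up for illustration. With deterministic algorithms enabled, roi_align should dispatch to the compiled pure-Python implementation, whose backward is deterministic:

# Illustrative only: exercises the deterministic path that the new skips gate.
import torch
from torchvision.ops import roi_align

torch.use_deterministic_algorithms(True)
x = torch.rand(1, 3, 32, 32, device="cuda", requires_grad=True)
# One box per row: (batch_index, x1, y1, x2, y2)
rois = torch.tensor([[0.0, 0.0, 0.0, 16.0, 16.0]], device="cuda")
out = roi_align(x, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
out.sum().backward()  # gradients flow through the compiled _roi_align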
+import functools
 from typing import List, Union

 import torch
 import torch._dynamo
 import torch.fx
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops, _has_ops
@@ -12,6 +14,24 @@ from ..utils import _log_api_usage_once
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
+def lazy_compile(**compile_kwargs):
+    """Lazily wrap a function with torch.compile on the first call
+
+    This avoids eagerly importing dynamo.
+    """
+
+    def decorate_fn(fn):
+        @functools.wraps(fn)
+        def compile_hook(*args, **kwargs):
+            compiled_fn = torch.compile(fn, **compile_kwargs)
+            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
+            return compiled_fn(*args, **kwargs)
+
+        return compile_hook
+
+    return decorate_fn
+
+
 # NB: all inputs are tensors
 def _bilinear_interpolate(
     input,  # [N, C, H, W]
@@ -86,15 +106,13 @@ def maybe_cast(tensor):
     return tensor

-# This is a slow but pure Python and differentiable implementation of
-# roi_align. It potentially is a good basis for Inductor compilation
-# (but I have not benchmarked it) but today it is solely used for the
-# fact that its backwards can be implemented deterministically,
-# which is needed for the PT2 benchmark suite.
-#
+# This is a pure Python and differentiable implementation of roi_align. When
+# run in eager mode, it uses a lot of memory, but when compiled it has
+# acceptable memory usage. The main point of this implementation is that
+# its backwards is deterministic.
 # It is transcribed directly off of the roi_align CUDA kernel, see
 # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
-@torch._dynamo.allow_in_graph
+@lazy_compile(dynamic=True)
 def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
     orig_dtype = input.dtype
@@ -232,7 +250,9 @@ def roi_align(
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     if not torch.jit.is_scripting():
-        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
+        if (
+            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
+        ) and is_compile_supported(input.device.type):
             return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
     _assert_has_ops()
     return torch.ops.torchvision.roi_align(
...
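As a standalone illustration of the lazy_compile pattern added above (the lazy_compile body is taken from the diff; the scale function and call sites are made-up examples): the first call pays the torch.compile cost and rebinds the module-level name, so subsequent calls go straight to the compiled function, and dynamo is only imported if the function is actually used.

# Illustrative sketch of the lazy_compile pattern; `scale` is hypothetical.
import functools
import torch

def lazy_compile(**compile_kwargs):
    """Lazily wrap a function with torch.compile on the first call."""
    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            compiled_fn = torch.compile(fn, **compile_kwargs)
            # Rebind the module-level name: later lookups skip this hook entirely.
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)
        return compile_hook
    return decorate_fn

@lazy_compile(dynamic=True)
def scale(x):
    return x * 2

print(scale(torch.ones(3)))  # first call: compiles scale, then runs it
print(scale(torch.ones(5)))  # later calls hit the compiled function directly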