Commit e2744a50 authored by Gabe Schwartz's avatar Gabe Schwartz Committed by Facebook GitHub Bot
Browse files

Decorate custom C++ ops w/compiler disable guard.

Summary: In order to work properly with `torch.compile()`, we need to decorate any function that calls a custom C++/CUDA extension with `torch.compiler.disable` so that it knows to insert a graph break.

Reviewed By: podgorskiy

Differential Revision: D59776177

fbshipit-source-id: d80eb43858836f8b8647d2a35b30d0b863989e94
parent b26a607c
...@@ -16,6 +16,7 @@ from drtk.utils import index ...@@ -16,6 +16,7 @@ from drtk.utils import index
th.ops.load_library(edge_grad_ext.__file__) th.ops.load_library(edge_grad_ext.__file__)
@th.compiler.disable
def edge_grad_estimator( def edge_grad_estimator(
v_pix: th.Tensor, v_pix: th.Tensor,
vi: th.Tensor, vi: th.Tensor,
...@@ -126,7 +127,6 @@ def edge_grad_estimator_ref( ...@@ -126,7 +127,6 @@ def edge_grad_estimator_ref(
class EdgeGradEstimatorFunction(th.autograd.Function): class EdgeGradEstimatorFunction(th.autograd.Function):
@staticmethod @staticmethod
@th.cuda.amp.custom_fwd(cast_inputs=th.float32)
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward( def forward(
ctx, ctx,
...@@ -140,7 +140,6 @@ class EdgeGradEstimatorFunction(th.autograd.Function): ...@@ -140,7 +140,6 @@ class EdgeGradEstimatorFunction(th.autograd.Function):
return img return img
@staticmethod @staticmethod
@th.cuda.amp.custom_bwd
# pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently. # pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently.
def backward(ctx, grad_output: th.Tensor) -> Tuple[ def backward(ctx, grad_output: th.Tensor) -> Tuple[
Optional[th.Tensor], Optional[th.Tensor],
......
...@@ -10,6 +10,7 @@ from drtk import interpolate_ext ...@@ -10,6 +10,7 @@ from drtk import interpolate_ext
th.ops.load_library(interpolate_ext.__file__) th.ops.load_library(interpolate_ext.__file__)
@th.compiler.disable
def interpolate( def interpolate(
vert_attributes: th.Tensor, vert_attributes: th.Tensor,
vi: th.Tensor, vi: th.Tensor,
......
...@@ -13,6 +13,7 @@ from drtk import mipmap_grid_sampler_ext ...@@ -13,6 +13,7 @@ from drtk import mipmap_grid_sampler_ext
th.ops.load_library(mipmap_grid_sampler_ext.__file__) th.ops.load_library(mipmap_grid_sampler_ext.__file__)
@th.compiler.disable
def mipmap_grid_sample( def mipmap_grid_sample(
input: List[th.Tensor], input: List[th.Tensor],
grid: th.Tensor, grid: th.Tensor,
......
...@@ -10,6 +10,7 @@ from drtk import msi_ext ...@@ -10,6 +10,7 @@ from drtk import msi_ext
th.ops.load_library(msi_ext.__file__) th.ops.load_library(msi_ext.__file__)
@th.compiler.disable
def msi( def msi(
ray_o: th.Tensor, ray_o: th.Tensor,
ray_d: th.Tensor, ray_d: th.Tensor,
......
...@@ -13,6 +13,7 @@ from drtk import rasterize_ext ...@@ -13,6 +13,7 @@ from drtk import rasterize_ext
th.ops.load_library(rasterize_ext.__file__) th.ops.load_library(rasterize_ext.__file__)
@th.compiler.disable
def rasterize( def rasterize(
v: th.Tensor, v: th.Tensor,
vi: th.Tensor, vi: th.Tensor,
...@@ -59,6 +60,7 @@ def rasterize( ...@@ -59,6 +60,7 @@ def rasterize(
return index_img return index_img
@th.compiler.disable
def rasterize_with_depth( def rasterize_with_depth(
v: th.Tensor, v: th.Tensor,
vi: th.Tensor, vi: th.Tensor,
......
...@@ -14,6 +14,7 @@ from drtk import render_ext ...@@ -14,6 +14,7 @@ from drtk import render_ext
th.ops.load_library(render_ext.__file__) th.ops.load_library(render_ext.__file__)
@th.compiler.disable
def render( def render(
v: th.Tensor, v: th.Tensor,
vi: th.Tensor, vi: th.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment