Commit dfe3fdfb authored by Stanislav Pidhorskyi's avatar Stanislav Pidhorskyi Committed by Facebook GitHub Bot
Browse files

Wrap loading torch ops into a separate function. Do not raise error if running from sphinx

Summary: This is necessary if we want to run building docs in github workflow

Differential Revision: D63497765

fbshipit-source-id: 018157c205a66584dd040882124588e499439893
parent b3060c7a
...@@ -7,12 +7,11 @@ from typing import Callable, Optional, Tuple ...@@ -7,12 +7,11 @@ from typing import Callable, Optional, Tuple
import torch as th import torch as th
import torch.nn.functional as thf import torch.nn.functional as thf
from drtk import edge_grad_ext
from drtk.interpolate import interpolate from drtk.interpolate import interpolate
from drtk.utils import index from drtk.utils import index, load_torch_ops
th.ops.load_library(edge_grad_ext.__file__) load_torch_ops("drtk.edge_grad_ext")
@th.compiler.disable @th.compiler.disable
......
...@@ -7,9 +7,9 @@ from typing import Optional ...@@ -7,9 +7,9 @@ from typing import Optional
import torch as th import torch as th
import torch.nn.functional as thf import torch.nn.functional as thf
from drtk import grid_scatter_ext from drtk.utils import load_torch_ops
th.ops.load_library(grid_scatter_ext.__file__) load_torch_ops("drtk.grid_scatter_ext")
@th.compiler.disable @th.compiler.disable
......
...@@ -9,9 +9,9 @@ attributes across the fragments, e.i. pixels covered by the primitive. ...@@ -9,9 +9,9 @@ attributes across the fragments, e.i. pixels covered by the primitive.
""" """
import torch as th import torch as th
from drtk import interpolate_ext from drtk.utils import load_torch_ops
th.ops.load_library(interpolate_ext.__file__) load_torch_ops("drtk.interpolate_ext")
@th.compiler.disable @th.compiler.disable
......
...@@ -7,9 +7,9 @@ from typing import List, Optional, Tuple ...@@ -7,9 +7,9 @@ from typing import List, Optional, Tuple
import torch as th import torch as th
import torch.nn.functional as thf import torch.nn.functional as thf
from drtk import mipmap_grid_sampler_ext from drtk.utils import load_torch_ops
th.ops.load_library(mipmap_grid_sampler_ext.__file__) load_torch_ops("drtk.mipmap_grid_sampler_ext")
@th.compiler.disable @th.compiler.disable
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch as th import torch as th
from drtk import msi_ext from drtk.utils import load_torch_ops
th.ops.load_library(msi_ext.__file__) load_torch_ops("drtk.msi_ext")
@th.compiler.disable @th.compiler.disable
......
...@@ -6,10 +6,9 @@ ...@@ -6,10 +6,9 @@
from typing import Tuple from typing import Tuple
import torch as th import torch as th
from drtk.utils import load_torch_ops
from drtk import rasterize_ext load_torch_ops("drtk.rasterize_ext")
th.ops.load_library(rasterize_ext.__file__)
@th.compiler.disable @th.compiler.disable
......
...@@ -7,10 +7,9 @@ from functools import lru_cache ...@@ -7,10 +7,9 @@ from functools import lru_cache
from typing import Tuple from typing import Tuple
import torch as th import torch as th
from drtk.utils import load_torch_ops
from drtk import render_ext load_torch_ops("drtk.render_ext")
th.ops.load_library(render_ext.__file__)
@th.compiler.disable @th.compiler.disable
......
...@@ -10,6 +10,7 @@ from drtk.utils.geometry import ( # noqa ...@@ -10,6 +10,7 @@ from drtk.utils.geometry import ( # noqa
vert_normals, # noqa vert_normals, # noqa
) )
from drtk.utils.indexing import index # noqa from drtk.utils.indexing import index # noqa
from drtk.utils.load_torch_ops import load_torch_ops # noqa
from drtk.utils.projection import ( # noqa from drtk.utils.projection import ( # noqa
DISTORTION_MODES, # noqa DISTORTION_MODES, # noqa
project_points, # noqa project_points, # noqa
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import torch as th
def load_torch_ops(extension: str) -> None:
try:
module = importlib.import_module(extension)
th.ops.load_library(module.__file__)
except ImportError as e:
import sys
# If running in sphinx, don't raise an error. That way we can build documentation without
# building extensions
if "sphinx" in sys.modules:
return
raise e
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment