Commit dfe3fdfb authored by Stanislav Pidhorskyi's avatar Stanislav Pidhorskyi Committed by Facebook GitHub Bot
Browse files

Wrap loading torch ops into a separate function. Do not raise error if running from sphinx

Summary: This is necessary if we want to run building docs in github workflow

Differential Revision: D63497765

fbshipit-source-id: 018157c205a66584dd040882124588e499439893
parent b3060c7a
......@@ -7,12 +7,11 @@ from typing import Callable, Optional, Tuple
import torch as th
import torch.nn.functional as thf
from drtk import edge_grad_ext
from drtk.interpolate import interpolate
from drtk.utils import index
from drtk.utils import index, load_torch_ops
th.ops.load_library(edge_grad_ext.__file__)
load_torch_ops("drtk.edge_grad_ext")
@th.compiler.disable
......
......@@ -7,9 +7,9 @@ from typing import Optional
import torch as th
import torch.nn.functional as thf
from drtk import grid_scatter_ext
from drtk.utils import load_torch_ops
th.ops.load_library(grid_scatter_ext.__file__)
load_torch_ops("drtk.grid_scatter_ext")
@th.compiler.disable
......
......@@ -9,9 +9,9 @@ attributes across the fragments, e.i. pixels covered by the primitive.
"""
import torch as th
from drtk import interpolate_ext
from drtk.utils import load_torch_ops
th.ops.load_library(interpolate_ext.__file__)
load_torch_ops("drtk.interpolate_ext")
@th.compiler.disable
......
......@@ -7,9 +7,9 @@ from typing import List, Optional, Tuple
import torch as th
import torch.nn.functional as thf
from drtk import mipmap_grid_sampler_ext
from drtk.utils import load_torch_ops
th.ops.load_library(mipmap_grid_sampler_ext.__file__)
load_torch_ops("drtk.mipmap_grid_sampler_ext")
@th.compiler.disable
......
......@@ -4,9 +4,9 @@
# LICENSE file in the root directory of this source tree.
import torch as th
from drtk import msi_ext
from drtk.utils import load_torch_ops
th.ops.load_library(msi_ext.__file__)
load_torch_ops("drtk.msi_ext")
@th.compiler.disable
......
......@@ -6,10 +6,9 @@
from typing import Tuple
import torch as th
from drtk.utils import load_torch_ops
from drtk import rasterize_ext
th.ops.load_library(rasterize_ext.__file__)
load_torch_ops("drtk.rasterize_ext")
@th.compiler.disable
......
......@@ -7,10 +7,9 @@ from functools import lru_cache
from typing import Tuple
import torch as th
from drtk.utils import load_torch_ops
from drtk import render_ext
th.ops.load_library(render_ext.__file__)
load_torch_ops("drtk.render_ext")
@th.compiler.disable
......
......@@ -10,6 +10,7 @@ from drtk.utils.geometry import ( # noqa
vert_normals, # noqa
)
from drtk.utils.indexing import index # noqa
from drtk.utils.load_torch_ops import load_torch_ops # noqa
from drtk.utils.projection import ( # noqa
DISTORTION_MODES, # noqa
project_points, # noqa
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import torch as th
def load_torch_ops(extension: str) -> None:
try:
module = importlib.import_module(extension)
th.ops.load_library(module.__file__)
except ImportError as e:
import sys
# If running in sphinx, don't raise an error. That way we can build documentation without
# building extensions
if "sphinx" in sys.modules:
return
raise e
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment