Unverified Commit 00b2bbd1 authored by Matthew Tancik's avatar Matthew Tancik Committed by GitHub
Browse files

Lazy load ContractionType (#39)

parent 237b85e4
......@@ -48,6 +48,15 @@ class ContractionType(Enum):
UN_BOUNDED_TANH = 1
UN_BOUNDED_SPHERE = 2
def to_cpp_version(self):
"""Convert to the C++ version of the enum class.
Returns:
The C++ version of the enum class.
"""
return _C.ContractionTypeGetter(self.value)
@torch.no_grad()
def contract(
......@@ -65,7 +74,7 @@ def contract(
Returns:
torch.Tensor: Contracted points ([0, 1]^3).
"""
ctype = _C.ContractionType(type.value)
ctype = type.to_cpp_version()
return _C.contract(x.contiguous(), roi.contiguous(), ctype)
......@@ -85,5 +94,5 @@ def contract_inv(
Returns:
torch.Tensor: Un-contracted points.
"""
ctype = _C.ContractionType(type.value)
ctype = type.to_cpp_version()
return _C.contract_inv(x.contiguous(), roi.contiguous(), ctype)
......@@ -11,17 +11,7 @@ def _make_lazy_cuda_func(name: str) -> Callable:
return call_cuda
def _make_lazy_cuda_attribute(name: str) -> Any:
# pylint: disable=import-outside-toplevel
from ._backend import _C
if _C is None:
return None
else:
return getattr(_C, name)
ContractionType = _make_lazy_cuda_attribute("ContractionType")
ContractionTypeGetter = _make_lazy_cuda_func("ContractionType")
contract = _make_lazy_cuda_func("contract")
contract_inv = _make_lazy_cuda_func("contract_inv")
......
......@@ -4,6 +4,7 @@ import torch
from torch import Tensor
import nerfacc.cuda as _C
from nerfacc.contraction import ContractionType
from .grid import Grid
from .vol_rendering import render_visibility
......@@ -231,7 +232,7 @@ def ray_marching(
if grid is not None:
grid_roi_aabb = grid.roi_aabb
grid_binary = grid.binary
contraction_type = _C.ContractionType(grid.contraction_type.value)
contraction_type = grid.contraction_type.to_cpp_version()
else:
grid_roi_aabb = torch.tensor(
[-1e10, -1e10, -1e10, 1e10, 1e10, 1e10],
......@@ -241,7 +242,7 @@ def ray_marching(
grid_binary = torch.ones(
[1, 1, 1], dtype=torch.bool, device=rays_o.device
)
contraction_type = _C.ContractionType.AABB
contraction_type = ContractionType.AABB.to_cpp_version()
# marching with grid-based skipping
packed_info, t_starts, t_ends = _C.ray_marching(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment