"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9003d75f20c020adf1adeb0ab4a2e39e352ce891"
Unverified Commit b73def22 authored by Matthew Tancik's avatar Matthew Tancik Committed by GitHub
Browse files

Lazy load cuda (#4)

parent 497d404e
version: 2 version: 2
build:
os: ubuntu-20.04
tools:
python: "3.9"
sphinx: sphinx:
fail_on_warning: true fail_on_warning: true
configuration: docs/source/conf.py
python: python:
version: 3.8
install: install:
# Equivalent to 'pip install .' # Equivalent to 'pip install .'
- method: pip - method: pip
......
from ._backend import _C from typing import Callable
def _make_lazy_cuda(name: str) -> Callable:
def call_cuda(*args, **kwargs):
# pylint: disable=import-outside-toplevel
from ._backend import _C
return getattr(_C, name)(*args, **kwargs)
return call_cuda
ray_aabb_intersect = _make_lazy_cuda("ray_aabb_intersect")
volumetric_marching = _make_lazy_cuda("volumetric_marching")
volumetric_rendering_steps = _make_lazy_cuda("volumetric_rendering_steps")
volumetric_rendering_weights_forward = _make_lazy_cuda("volumetric_rendering_weights_forward")
volumetric_rendering_weights_backward = _make_lazy_cuda("volumetric_rendering_weights_backward")
\ No newline at end of file
...@@ -2,7 +2,7 @@ from typing import Tuple ...@@ -2,7 +2,7 @@ from typing import Tuple
import torch import torch
from .cuda import _C import nerfacc.cuda as nerfacc_cuda
@torch.no_grad() @torch.no_grad()
...@@ -27,7 +27,7 @@ def ray_aabb_intersect( ...@@ -27,7 +27,7 @@ def ray_aabb_intersect(
rays_o = rays_o.contiguous() rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous() rays_d = rays_d.contiguous()
aabb = aabb.contiguous() aabb = aabb.contiguous()
t_min, t_max = _C.ray_aabb_intersect(rays_o, rays_d, aabb) t_min, t_max = nerfacc_cuda.ray_aabb_intersect(rays_o, rays_d, aabb)
else: else:
raise NotImplementedError("Only support cuda inputs.") raise NotImplementedError("Only support cuda inputs.")
return t_min, t_max return t_min, t_max
...@@ -87,7 +87,7 @@ def volumetric_marching( ...@@ -87,7 +87,7 @@ def volumetric_marching(
frustum_dirs, frustum_dirs,
frustum_starts, frustum_starts,
frustum_ends, frustum_ends,
) = _C.volumetric_marching( ) = nerfacc_cuda.volumetric_marching(
# rays # rays
rays_o.contiguous(), rays_o.contiguous(),
rays_d.contiguous(), rays_d.contiguous(),
...@@ -152,7 +152,7 @@ def volumetric_rendering_steps( ...@@ -152,7 +152,7 @@ def volumetric_rendering_steps(
frustum_starts = frustum_starts.contiguous() frustum_starts = frustum_starts.contiguous()
frustum_ends = frustum_ends.contiguous() frustum_ends = frustum_ends.contiguous()
sigmas = sigmas.contiguous() sigmas = sigmas.contiguous()
compact_packed_info, compact_selector = _C.volumetric_rendering_steps( compact_packed_info, compact_selector = nerfacc_cuda.volumetric_rendering_steps(
packed_info, frustum_starts, frustum_ends, sigmas packed_info, frustum_starts, frustum_ends, sigmas
) )
compact_frustum_starts = frustum_starts[compact_selector] compact_frustum_starts = frustum_starts[compact_selector]
...@@ -261,7 +261,7 @@ def volumetric_rendering_accumulate( ...@@ -261,7 +261,7 @@ def volumetric_rendering_accumulate(
class _volumetric_rendering_weights(torch.autograd.Function): class _volumetric_rendering_weights(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, packed_info, frustum_starts, frustum_ends, sigmas): def forward(ctx, packed_info, frustum_starts, frustum_ends, sigmas):
weights, ray_indices = _C.volumetric_rendering_weights_forward( weights, ray_indices = nerfacc_cuda.volumetric_rendering_weights_forward(
packed_info, frustum_starts, frustum_ends, sigmas packed_info, frustum_starts, frustum_ends, sigmas
) )
ctx.save_for_backward( ctx.save_for_backward(
...@@ -282,7 +282,7 @@ class _volumetric_rendering_weights(torch.autograd.Function): ...@@ -282,7 +282,7 @@ class _volumetric_rendering_weights(torch.autograd.Function):
sigmas, sigmas,
weights, weights,
) = ctx.saved_tensors ) = ctx.saved_tensors
grad_sigmas = _C.volumetric_rendering_weights_backward( grad_sigmas = nerfacc_cuda.volumetric_rendering_weights_backward(
weights, weights,
grad_weights, grad_weights,
packed_info, packed_info,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment