Unverified commit 674424ec, authored by Ruilong Li (李瑞龙), committed by GitHub.

Cub (#103)

- Faster rendering function via NVIDIA CUB, shipped with CUDA >= 11.0 (requires >= 11.6 for our use). ~10% speedup.
- Expose transmittance computation.
parent bca2d4dc
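
A minimal sketch of the reworked interface, assembled from the tests in this diff (the CUB-backed kernels sit behind these calls; render_transmittance_from_density is the newly exposed piece, and render_weight_from_density is what the benchmark below times):

import torch
import nerfacc

# Five samples packed across three rays: ray 0 owns one sample, ray 2 owns four.
ray_indices = torch.tensor([0, 2, 2, 2, 2], dtype=torch.int32, device="cuda")
sigmas = torch.rand((5, 1), device="cuda")  # per-sample density
t_starts = torch.rand_like(sigmas)
t_ends = t_starts + 1.0

# Newly exposed: per-sample transmittance along each ray.
trans = nerfacc.render_transmittance_from_density(
    t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=3
)
# Rendering weights from the same densities; equals trans * alpha,
# as asserted in test_grads below.
weights = nerfacc.render_weight_from_density(
    t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=3
)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
assert torch.allclose(weights, trans * alphas, atol=1e-4)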
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "nerfacc"
-version = "0.2.4"
+version = "0.3.0"
 description = "A General NeRF Acceleration Toolbox."
 readme = "README.md"
 authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}]
...
 from typing import Callable

 import torch
+import tqdm

 import nerfacc

+# timing
+# https://github.com/pytorch/pytorch/commit/d2784c233bfc57a1d836d961694bcc8ec4ed45e4
 class Profiler:
     def __init__(self, warmup=10, repeat=1000):
@@ -30,6 +34,7 @@ class Profiler:
         # return
         events = prof.key_averages()
+        # print(events.table(sort_by="self_cpu_time_total", row_limit=10))
         self_cpu_time_total = (
             sum([event.self_cpu_time_total for event in events]) / self.repeat
         )
@@ -49,15 +54,62 @@ class Profiler:
 def main():
     device = "cuda:0"
     torch.manual_seed(42)
-    profiler = Profiler(warmup=10, repeat=1000)
+    profiler = Profiler(warmup=10, repeat=100)

-    # contract
-    print("* contract")
-    x = torch.rand([1024, 3], device=device)
-    roi = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.float32, device=device)
-    fn = lambda: nerfacc.contract(
-        x, roi=roi, type=nerfacc.ContractionType.UN_BOUNDED_TANH
-    )
+    # # contract
+    # print("* contract")
+    # x = torch.rand([1024, 3], device=device)
+    # roi = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.float32, device=device)
+    # fn = lambda: nerfacc.contract(
+    #     x, roi=roi, type=nerfacc.ContractionType.UN_BOUNDED_TANH
+    # )
+    # cpu_t, cuda_t, cuda_bytes = profiler(fn)
+    # print(f"{cpu_t:.2f} us, {cuda_t:.2f} us, {cuda_bytes / 1024 / 1024:.2f} MB")

+    # rendering
+    print("* rendering")
+    batch_size = 81920
+    rays_o = torch.rand((batch_size, 3), device=device)
+    rays_d = torch.randn((batch_size, 3), device=device)
+    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
+    ray_indices, t_starts, t_ends = nerfacc.ray_marching(
+        rays_o,
+        rays_d,
+        near_plane=0.1,
+        far_plane=1.0,
+        render_step_size=1e-1,
+    )
+    sigmas = torch.randn_like(t_starts, requires_grad=True)
+    fn = (
+        lambda: nerfacc.render_weight_from_density(
+            ray_indices, t_starts, t_ends, sigmas
+        )
+        .sum()
+        .backward()
+    )
+    fn()
+    torch.cuda.synchronize()
+    for _ in tqdm.tqdm(range(100)):
+        fn()
+    torch.cuda.synchronize()
+    cpu_t, cuda_t, cuda_bytes = profiler(fn)
+    print(f"{cpu_t:.2f} us, {cuda_t:.2f} us, {cuda_bytes / 1024 / 1024:.2f} MB")

+    packed_info = nerfacc.pack_info(ray_indices, n_rays=batch_size).int()
+    fn = (
+        lambda: nerfacc.vol_rendering._RenderingDensity.apply(
+            packed_info, t_starts, t_ends, sigmas, 0
+        )
+        .sum()
+        .backward()
+    )
+    fn()
+    torch.cuda.synchronize()
+    for _ in tqdm.tqdm(range(100)):
+        fn()
+    torch.cuda.synchronize()
     cpu_t, cuda_t, cuda_bytes = profiler(fn)
     print(f"{cpu_t:.2f} us, {cuda_t:.2f} us, {cuda_bytes / 1024 / 1024:.2f} MB")
...
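
The Profiler's body sits outside the visible hunks; per the linked PyTorch commit, it averages torch.profiler events over warmup/repeat runs. A generic warmup-then-average CUDA timing sketch in the same spirit (an illustrative pattern, not nerfacc's actual Profiler):

import torch

def time_cuda(fn, warmup=10, repeat=100):
    # Warm up to exclude one-time compilation and allocator effects.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(repeat):
        fn()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() reports milliseconds; return microseconds per call.
    return start.elapsed_time(end) * 1000.0 / repeat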
 import pytest
 import torch

-from nerfacc import ray_marching
+from nerfacc import pack_info, ray_marching
 from nerfacc.losses import distortion

 device = "cuda:0"
@@ -15,13 +15,14 @@ def test_distortion():
     rays_d = torch.randn((batch_size, 3), device=device)
     rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
-    packed_info, t_starts, t_ends = ray_marching(
+    ray_indices, t_starts, t_ends = ray_marching(
         rays_o,
         rays_d,
         near_plane=0.1,
         far_plane=1.0,
         render_step_size=1e-3,
     )
+    packed_info = pack_info(ray_indices, n_rays=batch_size)
     weights = torch.rand((t_starts.shape[0],), device=device)
     loss = distortion(packed_info, weights, t_starts, t_ends)
     assert loss.shape == (batch_size,)
...
 import pytest
 import torch

-from nerfacc import pack_data, unpack_data, unpack_info
+from nerfacc import pack_data, pack_info, unpack_data, unpack_info

 device = "cuda:0"
 batch_size = 32
@@ -31,7 +31,9 @@ def test_unpack_info():
     ray_indices_tgt = torch.tensor(
         [0, 2, 2, 2, 2], dtype=torch.int64, device=device
     )
-    ray_indices = unpack_info(packed_info)
+    ray_indices = unpack_info(packed_info, n_samples=5)
+    packed_info_2 = pack_info(ray_indices, n_rays=packed_info.shape[0])
+    assert torch.allclose(packed_info.int(), packed_info_2.int())
     assert torch.allclose(ray_indices, ray_indices_tgt)
...
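
The round trip asserted above, spelled out. Rows of packed_info are (start, n_samples) per ray; the value below is implied by the [0, 2, 2, 2, 2] target (its actual definition sits above the visible context):

import torch
from nerfacc import pack_info, unpack_info

# Ray 0 owns sample 0, ray 1 owns no samples, ray 2 owns samples 1-4.
packed_info = torch.tensor(
    [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device="cuda"
)
ray_indices = unpack_info(packed_info, n_samples=5)  # -> [0, 2, 2, 2, 2]
packed_info_2 = pack_info(ray_indices, n_rays=3)  # back to (start, count) rows
assert torch.allclose(packed_info.int(), packed_info_2.int())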
@@ -13,7 +13,7 @@ def test_marching_with_near_far():
     rays_d = torch.randn((batch_size, 3), device=device)
     rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
-    packed_info, t_starts, t_ends = ray_marching(
+    ray_indices, t_starts, t_ends = ray_marching(
         rays_o,
         rays_d,
         near_plane=0.1,
@@ -31,7 +31,7 @@ def test_marching_with_grid():
     grid = OccupancyGrid(roi_aabb=[0, 0, 0, 1, 1, 1]).to(device)
     grid._binary[:] = True
-    packed_info, t_starts, t_ends = ray_marching(
+    ray_indices, t_starts, t_ends = ray_marching(
         rays_o,
         rays_d,
         grid=grid,
@@ -39,7 +39,7 @@ def test_marching_with_grid():
         far_plane=1.0,
         render_step_size=1e-2,
     )
-    ray_indices = unpack_info(packed_info).long()
+    ray_indices = ray_indices.long()
     samples = (
         rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0
     )
...
@@ -3,6 +3,7 @@ import torch

 from nerfacc import (
     accumulate_along_rays,
+    render_transmittance_from_density,
     render_visibility,
     render_weight_from_alpha,
     render_weight_from_density,
@@ -16,9 +17,9 @@ eps = 1e-6

 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_render_visibility():
-    packed_info = torch.tensor(
-        [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device
-    )  # (n_rays, 2)
+    ray_indices = torch.tensor(
+        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+    )  # (samples,)
     alphas = torch.tensor(
         [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
     ).unsqueeze(
@@ -26,37 +27,29 @@ def test_render_visibility():
     )  # (n_samples, 1)

     # transmittance: [1.0, 1.0, 0.7, 0.14, 0.028]
-    vis, packed_info_vis = render_visibility(
-        packed_info, alphas, early_stop_eps=0.03, alpha_thre=0.0
+    vis = render_visibility(
+        alphas, ray_indices=ray_indices, early_stop_eps=0.03, alpha_thre=0.0
     )
     vis_tgt = torch.tensor(
         [True, True, True, True, False], dtype=torch.bool, device=device
     )
-    packed_info_vis_tgt = torch.tensor(
-        [[0, 1], [1, 0], [1, 3]], dtype=torch.int32, device=device
-    )  # (n_rays, 2)
     assert torch.allclose(vis, vis_tgt)
-    assert torch.allclose(packed_info_vis, packed_info_vis_tgt)

     # transmittance: [1.0, 1.0, 1.0, 0.2, 0.04]
-    vis, packed_info_vis = render_visibility(
-        packed_info, alphas, early_stop_eps=0.05, alpha_thre=0.35
+    vis = render_visibility(
+        alphas, ray_indices=ray_indices, early_stop_eps=0.05, alpha_thre=0.35
     )
     vis_tgt = torch.tensor(
         [True, False, True, True, False], dtype=torch.bool, device=device
     )
-    packed_info_vis_tgt = torch.tensor(
-        [[0, 1], [1, 0], [1, 2]], dtype=torch.int32, device=device
-    )  # (n_rays, 2)
     assert torch.allclose(vis, vis_tgt)
-    assert torch.allclose(packed_info_vis, packed_info_vis_tgt)

 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_render_weight_from_alpha():
-    packed_info = torch.tensor(
-        [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device
-    )  # (n_rays, 2)
+    ray_indices = torch.tensor(
+        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+    )  # (samples,)
     alphas = torch.tensor(
         [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
     ).unsqueeze(
@@ -65,64 +58,160 @@ def test_render_weight_from_alpha():

     # transmittance: [1.0, 1.0, 0.7, 0.14, 0.028]
     weights = render_weight_from_alpha(
-        packed_info, alphas, early_stop_eps=0.03, alpha_thre=0.0
+        alphas, ray_indices=ray_indices, n_rays=3
     )
     weights_tgt = torch.tensor(
-        [1.0 * 0.4, 1.0 * 0.3, 0.7 * 0.8, 0.14 * 0.8, 0.0 * 0.0],
+        [1.0 * 0.4, 1.0 * 0.3, 0.7 * 0.8, 0.14 * 0.8, 0.028 * 0.5],
         dtype=torch.float32,
         device=device,
-    )
+    ).unsqueeze(-1)
     assert torch.allclose(weights, weights_tgt)

+@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_render_weight_from_density():
-    packed_info = torch.tensor(
-        [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device
-    )  # (n_rays, 2)
-    sigmas = torch.rand((batch_size, 1), device=device)  # (n_samples, 1)
+    ray_indices = torch.tensor(
+        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+    )  # (samples,)
+    sigmas = torch.rand(
+        (ray_indices.shape[0], 1), device=device
+    )  # (n_samples, 1)
     t_starts = torch.rand_like(sigmas)
     t_ends = torch.rand_like(sigmas) + 1.0
     alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
-    weights = render_weight_from_density(packed_info, t_starts, t_ends, sigmas)
-    weights_tgt = render_weight_from_alpha(packed_info, alphas)
+    weights = render_weight_from_density(
+        t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=3
+    )
+    weights_tgt = render_weight_from_alpha(
+        alphas, ray_indices=ray_indices, n_rays=3
+    )
     assert torch.allclose(weights, weights_tgt)

+@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_accumulate_along_rays():
     ray_indices = torch.tensor(
         [0, 2, 2, 2, 2], dtype=torch.int32, device=device
-    )  # (n_rays, 2)
+    )  # (n_rays,)
     weights = torch.tensor(
         [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
-    )
+    ).unsqueeze(-1)
     values = torch.rand((5, 2), device=device)  # (n_samples, 1)
     ray_values = accumulate_along_rays(
         weights, ray_indices, values=values, n_rays=3
     )
     assert ray_values.shape == (3, 2)
-    assert torch.allclose(ray_values[0, :], weights[0, None] * values[0, :])
+    assert torch.allclose(ray_values[0, :], weights[0, :] * values[0, :])
     assert (ray_values[1, :] == 0).all()
     assert torch.allclose(
-        ray_values[2, :], (weights[1:, None] * values[1:]).sum(dim=0)
+        ray_values[2, :], (weights[1:, :] * values[1:]).sum(dim=0)
     )

+@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_rendering():
     def rgb_sigma_fn(t_starts, t_ends, ray_indices):
         return torch.hstack([t_starts] * 3), t_starts

-    packed_info = torch.tensor(
-        [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device
-    )  # (n_rays, 2)
-    sigmas = torch.rand((5, 1), device=device)  # (n_samples, 1)
+    ray_indices = torch.tensor(
+        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+    )  # (samples,)
+    sigmas = torch.rand(
+        (ray_indices.shape[0], 1), device=device
+    )  # (n_samples, 1)
     t_starts = torch.rand_like(sigmas)
     t_ends = torch.rand_like(sigmas) + 1.0
     _, _, _ = rendering(
-        packed_info, t_starts, t_ends, rgb_sigma_fn=rgb_sigma_fn
+        t_starts,
+        t_ends,
+        ray_indices=ray_indices,
+        n_rays=3,
+        rgb_sigma_fn=rgb_sigma_fn,
     )

+@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
+def test_grads():
+    ray_indices = torch.tensor(
+        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+    )  # (samples,)
+    packed_info = torch.tensor(
+        [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device
+    )
+    sigmas = torch.tensor([[0.4], [0.8], [0.1], [0.8], [0.1]], device="cuda")
+    sigmas.requires_grad = True
+    t_starts = torch.rand_like(sigmas)
+    t_ends = t_starts + 1.0
+    weights_ref = torch.tensor(
+        [[0.3297], [0.5507], [0.0428], [0.2239], [0.0174]], device="cuda"
+    )
+    sigmas_grad_ref = torch.tensor(
+        [[0.6703], [0.1653], [0.1653], [0.1653], [0.1653]], device="cuda"
+    )

+    # naive impl. trans from sigma
+    trans = render_transmittance_from_density(
+        t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=3
+    )
+    weights = trans * (1.0 - torch.exp(-sigmas * (t_ends - t_starts)))
+    weights.sum().backward()
+    sigmas_grad = sigmas.grad.clone()
+    sigmas.grad.zero_()
+    assert torch.allclose(weights_ref, weights, atol=1e-4)
+    assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4)

+    # naive impl. trans from alpha
+    trans = render_transmittance_from_density(
+        t_starts, t_ends, sigmas, packed_info=packed_info, n_rays=3
+    )
+    weights = trans * (1.0 - torch.exp(-sigmas * (t_ends - t_starts)))
+    weights.sum().backward()
+    sigmas_grad = sigmas.grad.clone()
+    sigmas.grad.zero_()
+    assert torch.allclose(weights_ref, weights, atol=1e-4)
+    assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4)

+    weights = render_weight_from_density(
+        t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=3
+    )
+    weights.sum().backward()
+    sigmas_grad = sigmas.grad.clone()
+    sigmas.grad.zero_()
+    assert torch.allclose(weights_ref, weights, atol=1e-4)
+    assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4)

+    weights = render_weight_from_density(
+        t_starts, t_ends, sigmas, packed_info=packed_info, n_rays=3
+    )
+    weights.sum().backward()
+    sigmas_grad = sigmas.grad.clone()
+    sigmas.grad.zero_()
+    assert torch.allclose(weights_ref, weights, atol=1e-4)
+    assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4)

+    alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
+    weights = render_weight_from_alpha(
+        alphas, ray_indices=ray_indices, n_rays=3
+    )
+    weights.sum().backward()
+    sigmas_grad = sigmas.grad.clone()
+    sigmas.grad.zero_()
+    assert torch.allclose(weights_ref, weights, atol=1e-4)
+    assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4)

+    alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
+    weights = render_weight_from_alpha(
+        alphas, packed_info=packed_info, n_rays=3
+    )
+    weights.sum().backward()
+    sigmas_grad = sigmas.grad.clone()
+    sigmas.grad.zero_()
+    assert torch.allclose(weights_ref, weights, atol=1e-4)
+    assert torch.allclose(sigmas_grad_ref, sigmas_grad, atol=1e-4)

 if __name__ == "__main__":
@@ -131,3 +220,4 @@ if __name__ == "__main__":
     test_render_weight_from_density()
     test_accumulate_along_rays()
     test_rendering()
+    test_grads()
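
The transmittance values quoted in the test comments follow from T_i = prod_{j < i, same ray} (1 - alpha_j). A plain-PyTorch check of that arithmetic (reference math only, not the CUB kernel):

import torch

alphas = torch.tensor([0.4, 0.3, 0.8, 0.8, 0.5])
ray_indices = torch.tensor([0, 2, 2, 2, 2])

trans = torch.ones_like(alphas)
for ray in ray_indices.unique():
    mask = ray_indices == ray
    a = alphas[mask]
    # Exclusive cumulative product of (1 - alpha) along the ray.
    trans[mask] = torch.cumprod(
        torch.cat([torch.ones(1), 1.0 - a[:-1]]), dim=0
    )

# Matches the comment in test_render_visibility: [1.0, 1.0, 0.7, 0.14, 0.028]
assert torch.allclose(trans, torch.tensor([1.0, 1.0, 0.7, 0.14, 0.028]))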
 import pytest
 import torch

-from nerfacc import ray_marching, ray_resampling
+from nerfacc import pack_info, ray_marching, ray_resampling

 device = "cuda:0"
 batch_size = 128
@@ -13,13 +13,14 @@ def test_resampling():
     rays_d = torch.randn((batch_size, 3), device=device)
     rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
-    packed_info, t_starts, t_ends = ray_marching(
+    ray_indices, t_starts, t_ends = ray_marching(
         rays_o,
         rays_d,
         near_plane=0.1,
         far_plane=1.0,
         render_step_size=1e-3,
     )
+    packed_info = pack_info(ray_indices, n_rays=batch_size)
     weights = torch.rand((t_starts.shape[0],), device=device)
     packed_info, t_starts, t_ends = ray_resampling(
         packed_info, t_starts, t_ends, weights, n_samples=32
...