Commit b4286720 authored by Ruilong Li's avatar Ruilong Li
Browse files

ray_pdf_query

parent 488bca66
......@@ -2,6 +2,7 @@
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import math
from typing import Callable, List, Union
import torch
......@@ -73,8 +74,11 @@ class NGPradianceField(torch.nn.Module):
use_viewdirs: bool = True,
density_activation: Callable = lambda x: trunc_exp(x - 1),
unbounded: bool = False,
hidden_dim: int = 64,
geo_feat_dim: int = 15,
n_levels: int = 16,
max_res: int = 1024,
base_res: int = 16,
log2_hashmap_size: int = 19,
) -> None:
super().__init__()
......@@ -87,7 +91,9 @@ class NGPradianceField(torch.nn.Module):
self.unbounded = unbounded
self.geo_feat_dim = geo_feat_dim
per_level_scale = 1.4472692012786865
per_level_scale = math.exp(
(math.log(max_res) - math.log(base_res)) / (n_levels - 1)
)
if self.use_viewdirs:
self.direction_encoding = tcnn.Encoding(
......@@ -113,14 +119,14 @@ class NGPradianceField(torch.nn.Module):
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": 16,
"base_resolution": base_res,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": 64,
"n_neurons": hidden_dim,
"n_hidden_layers": 1,
},
)
......
......@@ -24,6 +24,7 @@ grid_query = _make_lazy_cuda_func("grid_query")
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
ray_marching = _make_lazy_cuda_func("ray_marching")
ray_resampling = _make_lazy_cuda_func("ray_resampling")
ray_pdf_query = _make_lazy_cuda_func("ray_pdf_query")
is_cub_available = _make_lazy_cuda_func("is_cub_available")
transmittance_from_sigma_forward_cub = _make_lazy_cuda_func(
......
......@@ -4,6 +4,90 @@
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void pdf_query_kernel(
    const uint32_t n_rays,
    // query: source piecewise-constant pdf, defined on sorted intervals
    const int *packed_info,   // per-ray {first index, count} into starts/ends/pdfs
    const scalar_t *starts,   // source interval start t (sorted per ray)
    const scalar_t *ends,     // source interval end t (sorted per ray)
    const scalar_t *pdfs,     // per-interval probability mass to be queried
    // resample: target intervals at which the pdf is queried
    const int *resample_packed_info, // per-ray {first index, count}
    const scalar_t *resample_starts, // target interval start t (sorted per ray)
    const scalar_t *resample_ends,   // target interval end t (sorted per ray)
    // output
    scalar_t *resample_pdfs) // must be zero-initialized by the caller
{
    CUDA_GET_THREAD_ID(i, n_rays);

    // locate this ray's slice inside the packed (flattened) tensors
    const int base = packed_info[i * 2 + 0];           // point idx start.
    const int steps = packed_info[i * 2 + 1];          // point idx shift.
    const int resample_base = resample_packed_info[i * 2 + 0];
    const int resample_steps = resample_packed_info[i * 2 + 1];
    if (resample_steps == 0) // nothing to query
        return;
    if (steps == 0) // nothing to be queried: outputs stay zero-initialized
        return;

    starts += base;
    ends += base;
    pdfs += base;
    resample_starts += resample_base;
    resample_ends += resample_base;
    resample_pdfs += resample_base;

    // Two forward-only cursors over the source intervals: one for the lower
    // boundary t0 of each query interval, one for the upper boundary t1.
    // Because both resample_starts and resample_ends are sorted, the cursors
    // never move backwards and the whole query is a single
    // O(steps + resample_steps) sweep per ray.
    int t0_id = -1;
    scalar_t t0_start = 0.0f, t0_end = starts[0];
    scalar_t cdf0_start = 0.0f, cdf0_end = 0.0f;
    int t1_id = -1;
    scalar_t t1_start = 0.0f, t1_end = starts[0];
    scalar_t cdf1_start = 0.0f, cdf1_end = 0.0f;

    for (int j = 0; j < resample_steps; ++j)
    {
        // advance the lower cursor until t0 is inside [t0_start, t0_end]
        // (fix: logical && instead of bitwise & between comparisons)
        scalar_t t0 = resample_starts[j];
        while (t0 > t0_end && t0_id < steps - 1)
        {
            t0_id++;
            t0_start = starts[t0_id];
            t0_end = ends[t0_id];
            cdf0_start = cdf0_end;
            cdf0_end += pdfs[t0_id];
        }
        if (t0 > t0_end)
        {
            // t0 lies beyond the last source interval: no mass remains
            resample_pdfs[j] = 0.0f;
            continue;
        }
        // NOTE(review): pct0 is pinned to 0 (snap to the enclosing interval's
        // lower cdf) rather than the linear interpolation kept in the comment
        // below — this yields an "outer" (upper-bound) measure; confirm intended.
        scalar_t pct0 = 0.0f; // max(t0 - t0_start, 0.0f) / max(t0_end - t0_start, 1e-10f);
        scalar_t resample_cdf_start = cdf0_start + pct0 * (cdf0_end - cdf0_start);

        // advance the upper cursor until t1 is inside [t1_start, t1_end]
        scalar_t t1 = resample_ends[j];
        while (t1 > t1_end && t1_id < steps - 1)
        {
            t1_id++;
            t1_start = starts[t1_id];
            t1_end = ends[t1_id];
            cdf1_start = cdf1_end;
            cdf1_end += pdfs[t1_id];
        }
        if (t1 > t1_end)
        {
            // t1 extends past the last source interval: clamp to the total mass
            resample_pdfs[j] = cdf1_end - resample_cdf_start;
            continue;
        }
        // pct1 pinned to 1 (snap to the enclosing interval's upper cdf)
        scalar_t pct1 = 1.0f; // max(t1 - t1_start, 0.0f) / max(t1_end - t1_start, 1e-10f);
        scalar_t resample_cdf_end = cdf1_start + pct1 * (cdf1_end - cdf1_start);

        // mass of the source pdf covered by [t0, t1]
        resample_pdfs[j] = resample_cdf_end - resample_cdf_start;
    }
    return;
}
template <typename scalar_t>
__global__ void cdf_resampling_kernel(
const uint32_t n_rays,
......@@ -201,3 +285,52 @@ std::vector<torch::Tensor> ray_resampling(
return {resample_packed_info, resample_starts, resample_ends};
}
// Query the piecewise-constant pdf defined on (starts, ends, pdfs) at the
// (resample_starts, resample_ends) intervals. Returns a 1-D tensor with one
// pdf value per resample interval; see pdf_query_kernel for the semantics.
torch::Tensor ray_pdf_query(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor pdfs,
    torch::Tensor resample_packed_info,
    torch::Tensor resample_starts,
    torch::Tensor resample_ends)
{
    DEVICE_GUARD(packed_info);

    CHECK_INPUT(packed_info);
    CHECK_INPUT(starts);
    CHECK_INPUT(ends);
    CHECK_INPUT(pdfs);
    // fix: the resample tensors are dereferenced by the kernel as well,
    // so they need the same device/contiguity validation.
    CHECK_INPUT(resample_packed_info);
    CHECK_INPUT(resample_starts);
    CHECK_INPUT(resample_ends);
    // fix: logical && instead of bitwise & between boolean conditions.
    TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
    TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
    TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
    TORCH_CHECK(pdfs.ndimension() == 1);
    TORCH_CHECK(resample_packed_info.ndimension() == 2 && resample_packed_info.size(1) == 2);
    // one thread per ray: both packed_info tensors must describe the same rays
    TORCH_CHECK(resample_packed_info.size(0) == packed_info.size(0));

    const uint32_t n_rays = packed_info.size(0);
    const uint32_t n_resamples = resample_starts.size(0);

    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

    // zero-initialized on purpose: rays the kernel skips keep pdf == 0
    torch::Tensor resample_pdfs = torch::zeros({n_resamples}, pdfs.options());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        pdfs.scalar_type(),
        "pdf_query",
        ([&]
         { pdf_query_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
               n_rays,
               // inputs
               packed_info.data_ptr<int>(),
               starts.data_ptr<scalar_t>(),
               ends.data_ptr<scalar_t>(),
               pdfs.data_ptr<scalar_t>(),
               resample_packed_info.data_ptr<int>(),
               resample_starts.data_ptr<scalar_t>(),
               resample_ends.data_ptr<scalar_t>(),
               // outputs
               resample_pdfs.data_ptr<scalar_t>()); }));

    return resample_pdfs;
}
......@@ -58,6 +58,15 @@ std::vector<torch::Tensor> ray_resampling(
torch::Tensor weights,
const int steps);
// Query the piecewise-constant pdf defined on (starts, ends, pdfs) at the
// (resample_starts, resample_ends) intervals; returns one pdf value per
// resample interval. Implemented in the CUDA source (pdf_query_kernel).
torch::Tensor ray_pdf_query(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor pdfs,
    torch::Tensor resample_packed_info,
    torch::Tensor resample_starts,
    torch::Tensor resample_ends);
torch::Tensor unpack_data(
torch::Tensor packed_info,
torch::Tensor data,
......@@ -145,6 +154,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("ray_aabb_intersect", &ray_aabb_intersect);
m.def("ray_marching", &ray_marching);
m.def("ray_resampling", &ray_resampling);
m.def("ray_pdf_query", &ray_pdf_query);
// rendering
m.def("is_cub_available", is_cub_available);
......
......@@ -126,7 +126,7 @@ def rendering(
if render_bkgd is not None:
colors = colors + render_bkgd * (1.0 - opacities)
return colors, opacities, depths
return colors, opacities, depths, weights
def accumulate_along_rays(
......
from struct import pack
import torch
from torch import Tensor
from nerfacc import ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query
device = "cuda:0"
def outer(
    t0_starts: Tensor,
    t0_ends: Tensor,
    t1_starts: Tensor,
    t1_ends: Tensor,
    y1: Tensor,
) -> Tensor:
    """Query the step function ``y1`` (defined on the t1 intervals) over the
    t0 intervals.

    For each query interval ``[t0_starts, t0_ends]``, returns the total mass
    of every t1 interval it touches (an outer, i.e. upper-bound, measure).
    All tensors are batched along the leading dimensions; intervals are
    assumed sorted along the last dimension.
    """
    # prefix sums with a leading zero: cy1[..., k] == sum(y1[..., :k])
    pad = torch.zeros_like(y1[..., :1])
    cy1 = torch.cat([pad, y1.cumsum(dim=-1)], dim=-1)

    # index of the last t1 interval starting at or before each query start
    lo = torch.searchsorted(
        t1_starts.contiguous(), t0_starts.contiguous(), side="right"
    )
    lo = torch.clamp(lo - 1, min=0, max=y1.shape[-1] - 1)
    # index of the t1 interval whose end first exceeds each query end
    hi = torch.searchsorted(
        t1_ends.contiguous(), t0_ends.contiguous(), side="right"
    )
    hi = torch.clamp(hi, min=0, max=y1.shape[-1] - 1)

    # cdf just below the query start, cdf just above the query end
    cdf_lo = torch.take_along_dim(cy1[..., :-1], lo, dim=-1)
    cdf_hi = torch.take_along_dim(cy1[..., 1:], hi, dim=-1)
    return cdf_hi - cdf_lo
def test_pdf_query():
    """Check the CUDA ray_pdf_query kernel against the pure-torch `outer`
    reference implementation on a single marched ray."""
    n_rays = 1
    rays_o = torch.rand((n_rays, 3), device=device)
    rays_d = torch.randn((n_rays, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    # march one ray to obtain a packed set of sorted sample intervals
    packed_info, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=0.2,
    )
    weights = torch.rand((t_starts.shape[0],), device=device)

    # query the pdf at the same intervals shifted towards the camera
    packed_info_new = packed_info
    t_starts_new = t_starts - 0.3
    t_ends_new = t_ends - 0.3

    weights_new_ref = outer(
        t_starts_new.reshape(n_rays, -1),
        t_ends_new.reshape(n_rays, -1),
        t_starts.reshape(n_rays, -1),
        t_ends.reshape(n_rays, -1),
        weights.reshape(n_rays, -1),
    ).flatten()

    weights_new = ray_pdf_query(
        packed_info,
        t_starts,
        t_ends,
        weights,
        packed_info_new,
        t_starts_new,
        t_ends_new,
    ).flatten()

    # fix: assert instead of printing — the CUDA kernel must match the
    # torch reference implementation.
    assert torch.allclose(weights_new, weights_new_ref, atol=1e-4)


if __name__ == "__main__":
    test_pdf_query()
......@@ -2,6 +2,7 @@ import pytest
import torch
from nerfacc import pack_info, ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query
device = "cuda:0"
batch_size = 128
......@@ -28,5 +29,29 @@ def test_resampling():
assert t_starts.shape == t_ends.shape == (batch_size * 32, 1)
def test_pdf_query():
    """Smoke-test ray_pdf_query: query a ray's pdf at shifted intervals."""
    rays_o = torch.rand((1, 3), device=device)
    rays_d = torch.randn((1, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
    packed_info, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=0.2,
    )
    weights = torch.rand((t_starts.shape[0],), device=device)
    weights_new = ray_pdf_query(
        packed_info,
        t_starts,
        t_ends,
        weights,
        packed_info,
        t_starts + 0.3,
        t_ends + 0.3,
    )
    # fix: the original computed weights_new but asserted nothing.
    # One queried pdf per resample interval; the queried mass can never
    # exceed the total mass of the source pdf.
    assert weights_new.shape == (t_starts.shape[0],)
    assert weights_new.sum() <= weights.sum() + 1e-4


if __name__ == "__main__":
    test_resampling()
    test_pdf_query()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment