"server/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "369de832cdca7680c8f50ba196d39172a895fcad"
Commit b4286720 authored by Ruilong Li

ray_pdf_query

parent 488bca66
@@ -2,6 +2,7 @@
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import math
from typing import Callable, List, Union

import torch
@@ -73,8 +74,11 @@ class NGPradianceField(torch.nn.Module):
        use_viewdirs: bool = True,
        density_activation: Callable = lambda x: trunc_exp(x - 1),
        unbounded: bool = False,
        hidden_dim: int = 64,
        geo_feat_dim: int = 15,
        n_levels: int = 16,
        max_res: int = 1024,
        base_res: int = 16,
        log2_hashmap_size: int = 19,
    ) -> None:
        super().__init__()
@@ -87,7 +91,9 @@ class NGPradianceField(torch.nn.Module):
        self.unbounded = unbounded
        self.geo_feat_dim = geo_feat_dim
        per_level_scale = math.exp(
            (math.log(max_res) - math.log(base_res)) / (n_levels - 1)
        )
        if self.use_viewdirs:
            self.direction_encoding = tcnn.Encoding(
@@ -113,14 +119,14 @@ class NGPradianceField(torch.nn.Module):
                "n_levels": n_levels,
                "n_features_per_level": 2,
                "log2_hashmap_size": log2_hashmap_size,
                "base_resolution": base_res,
                "per_level_scale": per_level_scale,
            },
            network_config={
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden_dim,
                "n_hidden_layers": 1,
            },
        )
...
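The hard-coded per_level_scale is replaced by the standard per-level growth factor exp((ln(max_res) - ln(base_res)) / (n_levels - 1)), i.e. the geometric ratio that takes the hash-grid resolution from base_res up to max_res across the n_levels levels. A small sanity-check sketch (plain Python; the helper name is illustrative, not part of the library):

import math

def per_level_scale(max_res: int, base_res: int, n_levels: int) -> float:
    # Geometric growth factor: the resolution of level k is base_res * scale**k,
    # so the finest level (k = n_levels - 1) lands on max_res.
    return math.exp((math.log(max_res) - math.log(base_res)) / (n_levels - 1))

scale = per_level_scale(max_res=1024, base_res=16, n_levels=16)
print(scale)                   # ~1.3195 with the new defaults
print(16 * scale ** (16 - 1))  # ~1024.0: the finest level recovers max_res

For reference, the previously hard-coded 1.4472692012786865 is what this formula gives for max_res = 4096 (with base_res = 16 and n_levels = 16), so the new default of max_res = 1024 yields a coarser finest level unless max_res is raised.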
@@ -24,6 +24,7 @@ grid_query = _make_lazy_cuda_func("grid_query")
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
ray_marching = _make_lazy_cuda_func("ray_marching")
ray_resampling = _make_lazy_cuda_func("ray_resampling")
ray_pdf_query = _make_lazy_cuda_func("ray_pdf_query")
is_cub_available = _make_lazy_cuda_func("is_cub_available")
transmittance_from_sigma_forward_cub = _make_lazy_cuda_func(
...
@@ -4,6 +4,90 @@
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void pdf_query_kernel(
    const uint32_t n_rays,
    // query
    const int *packed_info,   // input ray & point indices.
    const scalar_t *starts,   // input start t
    const scalar_t *ends,     // input end t
    const scalar_t *pdfs,     // pdf to be queried
    // resample
    const int *resample_packed_info, // input ray & point indices.
    const scalar_t *resample_starts, // input start t, sorted
    const scalar_t *resample_ends,   // input end t, sorted
    // output
    scalar_t *resample_pdfs) // should be zero-initialized
{
    CUDA_GET_THREAD_ID(i, n_rays);

    // locate
    const int base = packed_info[i * 2 + 0];  // point idx start.
    const int steps = packed_info[i * 2 + 1]; // point idx shift.
    const int resample_base = resample_packed_info[i * 2 + 0];  // point idx start.
    const int resample_steps = resample_packed_info[i * 2 + 1]; // point idx shift.
    if (resample_steps == 0) // nothing to query
        return;
    if (steps == 0) // nothing to be queried: the zero-initialized output stays 0
        return;

    starts += base;
    ends += base;
    pdfs += base;
    resample_starts += resample_base;
    resample_ends += resample_base;
    resample_pdfs += resample_base;

    // which interval is resample_start (t0) located in
    int t0_id = -1;
    scalar_t t0_start = 0.0f, t0_end = starts[0];
    scalar_t cdf0_start = 0.0f, cdf0_end = 0.0f;
    // which interval is resample_end (t1) located in
    int t1_id = -1;
    scalar_t t1_start = 0.0f, t1_end = starts[0];
    scalar_t cdf1_start = 0.0f, cdf1_end = 0.0f;

    // go!
    for (int j = 0; j < resample_steps; ++j)
    {
        // advance to the interval containing t0, accumulating the cdf on the way
        scalar_t t0 = resample_starts[j];
        while (t0 > t0_end && t0_id < steps - 1) {
            t0_id++;
            t0_start = starts[t0_id];
            t0_end = ends[t0_id];
            cdf0_start = cdf0_end;
            cdf0_end += pdfs[t0_id];
        }
        if (t0 > t0_end) { // t0 lies beyond the last interval: nothing to accumulate
            resample_pdfs[j] = 0.0f;
            continue;
        }
        // pct0 = 0 counts the interval containing t0 in full
        scalar_t pct0 = 0.0f; // max(t0 - t0_start, 0.0f) / max(t0_end - t0_start, 1e-10f);
        scalar_t resample_cdf_start = cdf0_start + pct0 * (cdf0_end - cdf0_start);

        // advance to the interval containing t1
        scalar_t t1 = resample_ends[j];
        while (t1 > t1_end && t1_id < steps - 1) {
            t1_id++;
            t1_start = starts[t1_id];
            t1_end = ends[t1_id];
            cdf1_start = cdf1_end;
            cdf1_end += pdfs[t1_id];
        }
        if (t1 > t1_end) { // t1 lies beyond the last interval: clip to the accumulated mass
            resample_pdfs[j] = cdf1_end - resample_cdf_start;
            continue;
        }
        // pct1 = 1 counts the interval containing t1 in full
        scalar_t pct1 = 1.0f; // max(t1 - t1_start, 0.0f) / max(t1_end - t1_start, 1e-10f);
        scalar_t resample_cdf_end = cdf1_start + pct1 * (cdf1_end - cdf1_start);

        // pdf mass covered by [t0, t1]
        resample_pdfs[j] = resample_cdf_end - resample_cdf_start;
    }
    return;
}
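For intuition, here is a hedged, single-ray Python sketch of what pdf_query_kernel computes: each query interval [t0, t1] receives the summed pdf of the original bins from the one containing t0 through the one containing t1, with both boundary bins counted in full (the commented-out pct terms would interpolate fractionally instead). The helper below is illustrative, not part of nerfacc, and assumes sorted bins with queries that fall within their span:

from typing import List, Optional

def pdf_query_single_ray(
    starts: List[float], ends: List[float], pdfs: List[float],
    q_starts: List[float], q_ends: List[float],
) -> List[float]:
    # cdf[k] = total pdf of bins 0..k-1, so cdf[0] = 0 and cdf[-1] is the full mass.
    cdf = [0.0]
    for p in pdfs:
        cdf.append(cdf[-1] + p)

    def first_bin_ending_at_or_after(t: float) -> Optional[int]:
        for k, e in enumerate(ends):
            if e >= t:
                return k
        return None  # t lies past the last bin

    out = []
    for t0, t1 in zip(q_starts, q_ends):
        lo = first_bin_ending_at_or_after(t0)
        if lo is None:  # query starts beyond all bins: zero mass
            out.append(0.0)
            continue
        hi = first_bin_ending_at_or_after(t1)
        hi_cdf = cdf[-1] if hi is None else cdf[hi + 1]
        out.append(hi_cdf - cdf[lo])
    return out

# Three bins carrying masses 0.2 / 0.5 / 0.3; the query [0.5, 1.5] starts in the
# first bin and ends in the second, so it is assigned 0.2 + 0.5 = 0.7.
print(pdf_query_single_ray(
    starts=[0.0, 1.0, 2.0], ends=[1.0, 2.0, 3.0], pdfs=[0.2, 0.5, 0.3],
    q_starts=[0.5], q_ends=[1.5],
))  # -> [0.7] (up to float rounding)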
template <typename scalar_t>
__global__ void cdf_resampling_kernel(
    const uint32_t n_rays,
@@ -201,3 +285,52 @@ std::vector<torch::Tensor> ray_resampling(
    return {resample_packed_info, resample_starts, resample_ends};
}
torch::Tensor ray_pdf_query(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor pdfs,
    torch::Tensor resample_packed_info,
    torch::Tensor resample_starts,
    torch::Tensor resample_ends)
{
    DEVICE_GUARD(packed_info);

    CHECK_INPUT(packed_info);
    CHECK_INPUT(starts);
    CHECK_INPUT(ends);
    CHECK_INPUT(pdfs);

    TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 2);
    TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
    TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
    TORCH_CHECK(pdfs.ndimension() == 1);

    const uint32_t n_rays = packed_info.size(0);
    const uint32_t n_resamples = resample_starts.size(0);

    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

    torch::Tensor resample_pdfs = torch::zeros({n_resamples}, pdfs.options());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        pdfs.scalar_type(),
        "pdf_query",
        ([&]
         { pdf_query_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
               n_rays,
               // inputs
               packed_info.data_ptr<int>(),
               starts.data_ptr<scalar_t>(),
               ends.data_ptr<scalar_t>(),
               pdfs.data_ptr<scalar_t>(),
               resample_packed_info.data_ptr<int>(),
               resample_starts.data_ptr<scalar_t>(),
               resample_ends.data_ptr<scalar_t>(),
               // outputs
               resample_pdfs.data_ptr<scalar_t>()); }));

    return resample_pdfs;
}
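A rough Python-side usage sketch of the binding above (values are illustrative): packed_info and resample_packed_info are (n_rays, 2) int32 tensors of [first sample index, number of samples] per ray, starts/ends are (n_samples, 1) columns, pdfs is a flat (n_samples,) tensor, and the output holds one pdf value per resampled interval, zero wherever nothing is covered:

import torch
from nerfacc.cuda import ray_pdf_query  # the lazy binding added in nerfacc.cuda

device = "cuda:0"
# One ray with three samples: packed_info row = [start index, num samples].
packed_info = torch.tensor([[0, 3]], dtype=torch.int32, device=device)
t_starts = torch.tensor([[0.0], [1.0], [2.0]], device=device)
t_ends = torch.tensor([[1.0], [2.0], [3.0]], device=device)
pdfs = torch.tensor([0.2, 0.5, 0.3], device=device)

# Query the same ray on shifted intervals; the result is a (3,) tensor with
# one pdf value per query interval.
new_pdfs = ray_pdf_query(
    packed_info, t_starts, t_ends, pdfs,
    packed_info, t_starts + 0.5, t_ends + 0.5,
)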
@@ -58,6 +58,15 @@ std::vector<torch::Tensor> ray_resampling(
    torch::Tensor weights,
    const int steps);

torch::Tensor ray_pdf_query(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor pdfs,
    torch::Tensor resample_packed_info,
    torch::Tensor resample_starts,
    torch::Tensor resample_ends);

torch::Tensor unpack_data(
    torch::Tensor packed_info,
    torch::Tensor data,
@@ -145,6 +154,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    m.def("ray_aabb_intersect", &ray_aabb_intersect);
    m.def("ray_marching", &ray_marching);
    m.def("ray_resampling", &ray_resampling);
    m.def("ray_pdf_query", &ray_pdf_query);

    // rendering
    m.def("is_cub_available", is_cub_available);
...
@@ -126,7 +126,7 @@ def rendering(
    if render_bkgd is not None:
        colors = colors + render_bkgd * (1.0 - opacities)

    return colors, opacities, depths, weights


def accumulate_along_rays(
...
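Every caller of rendering() now unpacks a fourth return value. A minimal sketch of the updated call site, with the positional arguments hidden behind the placeholder render_args since the full signature is outside this diff:

# weights is the new fourth output: the per-sample compositing weights, which
# can serve as the pdf input to ray_pdf_query / ray_resampling.
colors, opacities, depths, weights = rendering(*render_args)  # render_args: placeholder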
import torch
from torch import Tensor

from nerfacc import ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query

device = "cuda:0"


def outer(
    t0_starts: Tensor,
    t0_ends: Tensor,
    t1_starts: Tensor,
    t1_ends: Tensor,
    y1: Tensor,
) -> Tensor:
    """Vectorized reference: for each interval [t0_starts, t0_ends], sum the y1
    mass of the source intervals [t1_starts, t1_ends] it spans, counting the
    bins containing the endpoints in full."""
    cy1 = torch.cat(
        [torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1
    )
    idx_lo = (
        torch.searchsorted(
            t1_starts.contiguous(), t0_starts.contiguous(), side="right"
        )
        - 1
    )
    idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1)
    idx_hi = torch.searchsorted(
        t1_ends.contiguous(), t0_ends.contiguous(), side="right"
    )
    idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1)
    cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1)
    cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1)
    y0_outer = cy1_hi - cy1_lo
    return y0_outer


def test_pdf_query():
    n_rays = 1
    rays_o = torch.rand((n_rays, 3), device=device)
    rays_d = torch.randn((n_rays, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    packed_info, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=0.2,
    )
    weights = torch.rand((t_starts.shape[0],), device=device)

    packed_info_new = packed_info
    t_starts_new = t_starts - 0.3
    t_ends_new = t_ends - 0.3

    weights_new_ref = outer(
        t_starts_new.reshape(n_rays, -1),
        t_ends_new.reshape(n_rays, -1),
        t_starts.reshape(n_rays, -1),
        t_ends.reshape(n_rays, -1),
        weights.reshape(n_rays, -1),
    )
    weights_new_ref = weights_new_ref.flatten()

    weights_new = ray_pdf_query(
        packed_info,
        t_starts,
        t_ends,
        weights,
        packed_info_new,
        t_starts_new,
        t_ends_new,
    )
    weights_new = weights_new.flatten()

    print(weights)
    print(weights_new_ref)
    print(weights_new)


if __name__ == "__main__":
    test_pdf_query()
@@ -2,6 +2,7 @@ import pytest
import torch

from nerfacc import pack_info, ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query

device = "cuda:0"
batch_size = 128
@@ -28,5 +29,29 @@ def test_resampling():
    assert t_starts.shape == t_ends.shape == (batch_size * 32, 1)


def test_pdf_query():
    rays_o = torch.rand((1, 3), device=device)
    rays_d = torch.randn((1, 3), device=device)
    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)

    packed_info, t_starts, t_ends = ray_marching(
        rays_o,
        rays_d,
        near_plane=0.1,
        far_plane=1.0,
        render_step_size=0.2,
    )
    weights = torch.rand((t_starts.shape[0],), device=device)

    weights_new = ray_pdf_query(
        packed_info,
        t_starts,
        t_ends,
        weights,
        packed_info,
        t_starts + 0.3,
        t_ends + 0.3,
    )


if __name__ == "__main__":
    test_resampling()
    test_pdf_query()