Commit bc62c5b6 authored by Ruilong Li's avatar Ruilong Li
Browse files

merge and fix tests

parent 52512811
...@@ -256,7 +256,7 @@ std::vector<torch::Tensor> ray_resampling( ...@@ -256,7 +256,7 @@ std::vector<torch::Tensor> ray_resampling(
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2); TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 1); TORCH_CHECK(weights.ndimension() == 2 & weights.size(1) == 1);
const uint32_t n_rays = packed_info.size(0); const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = weights.size(0); const uint32_t n_samples = weights.size(0);
...@@ -312,7 +312,7 @@ torch::Tensor ray_pdf_query( ...@@ -312,7 +312,7 @@ torch::Tensor ray_pdf_query(
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2); TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(pdfs.ndimension() == 1); TORCH_CHECK(pdfs.ndimension() == 2 & pdfs.size(1) == 1);
const uint32_t n_rays = packed_info.size(0); const uint32_t n_rays = packed_info.size(0);
const uint32_t n_resamples = resample_starts.size(0); const uint32_t n_resamples = resample_starts.size(0);
...@@ -320,7 +320,7 @@ torch::Tensor ray_pdf_query( ...@@ -320,7 +320,7 @@ torch::Tensor ray_pdf_query(
const int threads = 256; const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor resample_pdfs = torch::zeros({n_resamples}, pdfs.options()); torch::Tensor resample_pdfs = torch::zeros({n_resamples, 1}, pdfs.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pdfs.scalar_type(), pdfs.scalar_type(),
......
...@@ -196,11 +196,10 @@ def ray_marching( ...@@ -196,11 +196,10 @@ def ray_marching(
proposal_sample_list = [] proposal_sample_list = []
# resample with proposal nets # resample with proposal nets
for net, num_samples in zip(proposal_nets, [32]): for net, num_samples in zip(proposal_nets, [32]):
ray_indices = unpack_info(packed_info)
with torch.enable_grad(): with torch.enable_grad():
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long(), net=net) sigmas = sigma_fn(t_starts, t_ends, ray_indices.long(), net=net)
weights = render_weight_from_density( weights = render_weight_from_density(
packed_info, t_starts, t_ends, sigmas, early_stop_eps=0 t_starts, t_ends, sigmas, ray_indices=ray_indices
) )
proposal_sample_list.append( proposal_sample_list.append(
(packed_info, t_starts, t_ends, weights) (packed_info, t_starts, t_ends, weights)
...@@ -208,6 +207,7 @@ def ray_marching( ...@@ -208,6 +207,7 @@ def ray_marching(
packed_info, t_starts, t_ends = ray_resampling( packed_info, t_starts, t_ends = ray_resampling(
packed_info, t_starts, t_ends, weights, n_samples=num_samples packed_info, t_starts, t_ends, weights, n_samples=num_samples
) )
ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])
# skip invisible space # skip invisible space
if sigma_fn is not None or alpha_fn is not None: if sigma_fn is not None or alpha_fn is not None:
...@@ -239,6 +239,6 @@ def ray_marching( ...@@ -239,6 +239,6 @@ def ray_marching(
) )
if proposal_nets is not None: if proposal_nets is not None:
return packed_info, t_starts, t_ends, proposal_sample_list return ray_indices, t_starts, t_ends, proposal_sample_list
else: else:
return packed_info, t_starts, t_ends return ray_indices, t_starts, t_ends
...@@ -23,7 +23,7 @@ def rendering( ...@@ -23,7 +23,7 @@ def rendering(
rgb_alpha_fn: Optional[Callable] = None, rgb_alpha_fn: Optional[Callable] = None,
# rendering options # rendering options
render_bkgd: Optional[torch.Tensor] = None, render_bkgd: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple:
"""Render the rays through the radience field defined by `rgb_sigma_fn`. """Render the rays through the radience field defined by `rgb_sigma_fn`.
This function is differentiable to the outputs of `rgb_sigma_fn` so it can This function is differentiable to the outputs of `rgb_sigma_fn` so it can
......
from struct import pack
import torch import torch
from torch import Tensor from torch import Tensor
from nerfacc import ray_marching, ray_resampling from nerfacc import pack_info, ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query from nerfacc.cuda import ray_pdf_query
device = "cuda:0" device = "cuda:0"
...@@ -44,13 +42,14 @@ def test_pdf_query(): ...@@ -44,13 +42,14 @@ def test_pdf_query():
rays_d = torch.randn((n_rays, 3), device=device) rays_d = torch.randn((n_rays, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
packed_info, t_starts, t_ends = ray_marching( ray_indices, t_starts, t_ends = ray_marching(
rays_o, rays_o,
rays_d, rays_d,
near_plane=0.1, near_plane=0.1,
far_plane=1.0, far_plane=1.0,
render_step_size=0.2, render_step_size=0.2,
) )
packed_info = pack_info(ray_indices, n_rays)
weights = torch.rand((t_starts.shape[0],), device=device) weights = torch.rand((t_starts.shape[0],), device=device)
packed_info_new = packed_info packed_info_new = packed_info
...@@ -76,10 +75,10 @@ def test_pdf_query(): ...@@ -76,10 +75,10 @@ def test_pdf_query():
t_ends_new, t_ends_new,
) )
weights_new = weights_new.flatten() weights_new = weights_new.flatten()
print(weights) # print(weights)
print(weights_new_ref) # print(weights_new_ref)
print(weights_new) # print(weights_new)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -124,7 +124,7 @@ def test_rendering(): ...@@ -124,7 +124,7 @@ def test_rendering():
t_starts = torch.rand_like(sigmas) t_starts = torch.rand_like(sigmas)
t_ends = torch.rand_like(sigmas) + 1.0 t_ends = torch.rand_like(sigmas) + 1.0
_, _, _ = rendering( _, _, _, _ = rendering(
t_starts, t_starts,
t_ends, t_ends,
ray_indices=ray_indices, ray_indices=ray_indices,
......
...@@ -134,13 +134,15 @@ def test_pdf_query(): ...@@ -134,13 +134,15 @@ def test_pdf_query():
rays_d = torch.randn((1, 3), device=device) rays_d = torch.randn((1, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True) rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
packed_info, t_starts, t_ends = ray_marching( ray_indices, t_starts, t_ends = ray_marching(
rays_o, rays_o,
rays_d, rays_d,
near_plane=0.1, near_plane=0.1,
far_plane=1.0, far_plane=1.0,
render_step_size=0.2, render_step_size=0.2,
) )
packed_info = pack_info(ray_indices, rays_o.shape[0])
weights = torch.rand((t_starts.shape[0],), device=device) weights = torch.rand((t_starts.shape[0],), device=device)
weights_new = ray_pdf_query( weights_new = ray_pdf_query(
packed_info, packed_info,
...@@ -152,6 +154,7 @@ def test_pdf_query(): ...@@ -152,6 +154,7 @@ def test_pdf_query():
t_ends + 0.3, t_ends + 0.3,
) )
if __name__ == "__main__": if __name__ == "__main__":
test_resampling() test_resampling()
test_pdf_query() test_pdf_query()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment