Commit bc62c5b6 authored by Ruilong Li's avatar Ruilong Li
Browse files

merge and fix tests

parent 52512811
......@@ -256,7 +256,7 @@ std::vector<torch::Tensor> ray_resampling(
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(weights.ndimension() == 1);
TORCH_CHECK(weights.ndimension() == 2 & weights.size(1) == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = weights.size(0);
......@@ -312,7 +312,7 @@ torch::Tensor ray_pdf_query(
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(pdfs.ndimension() == 1);
TORCH_CHECK(pdfs.ndimension() == 2 & pdfs.size(1) == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_resamples = resample_starts.size(0);
......@@ -320,7 +320,7 @@ torch::Tensor ray_pdf_query(
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor resample_pdfs = torch::zeros({n_resamples}, pdfs.options());
torch::Tensor resample_pdfs = torch::zeros({n_resamples, 1}, pdfs.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
pdfs.scalar_type(),
......
......@@ -196,11 +196,10 @@ def ray_marching(
proposal_sample_list = []
# resample with proposal nets
for net, num_samples in zip(proposal_nets, [32]):
ray_indices = unpack_info(packed_info)
with torch.enable_grad():
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long(), net=net)
weights = render_weight_from_density(
packed_info, t_starts, t_ends, sigmas, early_stop_eps=0
t_starts, t_ends, sigmas, ray_indices=ray_indices
)
proposal_sample_list.append(
(packed_info, t_starts, t_ends, weights)
......@@ -208,6 +207,7 @@ def ray_marching(
packed_info, t_starts, t_ends = ray_resampling(
packed_info, t_starts, t_ends, weights, n_samples=num_samples
)
ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])
# skip invisible space
if sigma_fn is not None or alpha_fn is not None:
......@@ -239,6 +239,6 @@ def ray_marching(
)
if proposal_nets is not None:
return packed_info, t_starts, t_ends, proposal_sample_list
return ray_indices, t_starts, t_ends, proposal_sample_list
else:
return packed_info, t_starts, t_ends
return ray_indices, t_starts, t_ends
......@@ -23,7 +23,7 @@ def rendering(
rgb_alpha_fn: Optional[Callable] = None,
# rendering options
render_bkgd: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple:
"""Render the rays through the radience field defined by `rgb_sigma_fn`.
This function is differentiable to the outputs of `rgb_sigma_fn` so it can
......
from struct import pack
import torch
from torch import Tensor
from nerfacc import ray_marching, ray_resampling
from nerfacc import pack_info, ray_marching, ray_resampling
from nerfacc.cuda import ray_pdf_query
device = "cuda:0"
......@@ -44,13 +42,14 @@ def test_pdf_query():
rays_d = torch.randn((n_rays, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
packed_info, t_starts, t_ends = ray_marching(
ray_indices, t_starts, t_ends = ray_marching(
rays_o,
rays_d,
near_plane=0.1,
far_plane=1.0,
render_step_size=0.2,
)
packed_info = pack_info(ray_indices, n_rays)
weights = torch.rand((t_starts.shape[0],), device=device)
packed_info_new = packed_info
......@@ -76,10 +75,10 @@ def test_pdf_query():
t_ends_new,
)
weights_new = weights_new.flatten()
print(weights)
# print(weights)
print(weights_new_ref)
print(weights_new)
# print(weights_new_ref)
# print(weights_new)
if __name__ == "__main__":
......
......@@ -124,7 +124,7 @@ def test_rendering():
t_starts = torch.rand_like(sigmas)
t_ends = torch.rand_like(sigmas) + 1.0
_, _, _ = rendering(
_, _, _, _ = rendering(
t_starts,
t_ends,
ray_indices=ray_indices,
......
......@@ -134,13 +134,15 @@ def test_pdf_query():
rays_d = torch.randn((1, 3), device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
packed_info, t_starts, t_ends = ray_marching(
ray_indices, t_starts, t_ends = ray_marching(
rays_o,
rays_d,
near_plane=0.1,
far_plane=1.0,
render_step_size=0.2,
)
packed_info = pack_info(ray_indices, rays_o.shape[0])
weights = torch.rand((t_starts.shape[0],), device=device)
weights_new = ray_pdf_query(
packed_info,
......@@ -152,6 +154,7 @@ def test_pdf_query():
t_ends + 0.3,
)
if __name__ == "__main__":
test_resampling()
test_pdf_query()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment