"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "89b23d986958bd9597a5c397bde0bdaca6b9134f"
Unverified Commit 4b7819b4 authored by arterms's avatar arterms Committed by GitHub
Browse files

Fixed grid traversal with near and far planes (#204)

* Fixed traversal with far_plane

* test added for traversal with near and far planes

* More correct test

* black formatting for test
parent d858e9f3
...@@ -140,6 +140,7 @@ __global__ void traverse_grids_kernel( ...@@ -140,6 +140,7 @@ __global__ void traverse_grids_kernel(
const int3 overflow_index = final_index + step_index; const int3 overflow_index = final_index + step_index;
while (true) { while (true) {
float t_traverse = min(tdist.x, min(tdist.y, tdist.z)); float t_traverse = min(tdist.x, min(tdist.y, tdist.z));
t_traverse = fminf(t_traverse, this_tmax);
int64_t cell_id = ( int64_t cell_id = (
current_index.x * resolution.y * resolution.z current_index.x * resolution.y * resolution.z
+ current_index.y * resolution.z + current_index.y * resolution.z
......
...@@ -68,6 +68,35 @@ def test_traverse_grids(): ...@@ -68,6 +68,35 @@ def test_traverse_grids():
assert selector.all(), selector.float().mean() assert selector.all(), selector.float().mean()
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_traverse_grids_with_near_far_planes():
from nerfacc.grid import traverse_grids
rays_o = torch.tensor([[-1.0, 0.0, 0.0]], device=device)
rays_d = torch.tensor([[1.0, 0.01, 0.01]], device=device)
rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
binaries = torch.ones((1, 1, 1, 1), dtype=torch.bool, device=device)
aabbs = torch.tensor([[0.0, 0.0, 0.0, 1.0, 1.0, 1.0]], device=device)
near_planes = torch.tensor([1.2], device=device)
far_planes = torch.tensor([1.5], device=device)
step_size = 0.05
intervals, samples = traverse_grids(
rays_o=rays_o,
rays_d=rays_d,
binaries=binaries,
aabbs=aabbs,
step_size=step_size,
near_planes=near_planes,
far_planes=far_planes,
)
assert (intervals.vals >= (near_planes - step_size / 2)).all()
assert (intervals.vals <= (far_planes + step_size / 2)).all()
if __name__ == "__main__": if __name__ == "__main__":
test_ray_aabb_intersect() test_ray_aabb_intersect()
test_traverse_grids() test_traverse_grids()
test_traverse_grids_with_near_far_planes()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment