Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
96211bba
Commit
96211bba
authored
Sep 10, 2022
by
Ruilong Li
Browse files
wtf
parent
65bebd64
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
289 additions
and
249 deletions
+289
-249
examples/datasets/nerf_synthetic.py
examples/datasets/nerf_synthetic.py
+7
-3
examples/trainval.py
examples/trainval.py
+71
-43
nerfacc/cuda/csrc/intersection.cu
nerfacc/cuda/csrc/intersection.cu
+5
-5
nerfacc/cuda/csrc/pybind.cu
nerfacc/cuda/csrc/pybind.cu
+1
-1
nerfacc/cuda/csrc/ray_gen.cu
nerfacc/cuda/csrc/ray_gen.cu
+0
-0
nerfacc/cuda/csrc/ray_marching.cu
nerfacc/cuda/csrc/ray_marching.cu
+160
-152
nerfacc/volumetric_rendering.py
nerfacc/volumetric_rendering.py
+45
-45
No files found.
examples/datasets/nerf_synthetic.py
View file @
96211bba
...
...
@@ -65,6 +65,7 @@ class SubjectLoader(torch.utils.data.Dataset):
WIDTH
,
HEIGHT
=
800
,
800
NEAR
,
FAR
=
2.0
,
6.0
OPENGL_CAMERA
=
True
def
__init__
(
self
,
...
...
@@ -186,15 +187,18 @@ class SubjectLoader(torch.utils.data.Dataset):
camera_dirs
=
F
.
pad
(
torch
.
stack
(
[
(
x
-
self
.
K
[
0
,
2
]
+
0.5
)
/
self
.
K
[
0
,
0
],
(
y
-
self
.
K
[
1
,
2
]
+
0.5
)
/
self
.
K
[
1
,
1
],
(
x
-
self
.
K
[
0
,
2
]
+
0.5
)
/
self
.
K
[
0
,
0
]
*
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
),
(
y
-
self
.
K
[
1
,
2
]
+
0.5
)
/
self
.
K
[
1
,
1
]
*
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
),
],
dim
=-
1
,
),
(
0
,
1
),
value
=
1
,
)
# [num_rays, 3]
camera_dirs
[...,
[
1
,
2
]]
*=
-
1
# opengl format
# [n_cams, height, width, 3]
directions
=
(
camera_dirs
[:,
None
,
:]
*
c2w
[:,
:
3
,
:
3
]).
sum
(
dim
=-
1
)
...
...
examples/trainval.py
View file @
96211bba
...
...
@@ -5,7 +5,7 @@ import numpy as np
import
torch
import
torch.nn.functional
as
F
import
tqdm
from
datasets.nerf_synthetic
import
SubjectLoader
,
namedtuple_map
from
datasets.nerf_synthetic
import
Rays
,
SubjectLoader
,
namedtuple_map
from
radiance_fields.ngp
import
NGPradianceField
from
nerfacc
import
OccupancyField
,
volumetric_rendering
...
...
@@ -67,10 +67,10 @@ if __name__ == "__main__":
# setup dataset
train_dataset
=
SubjectLoader
(
subject_id
=
"
lego
"
,
subject_id
=
"
mic
"
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
split
=
"train"
,
num_rays
=
4096
,
split
=
"train
val
"
,
num_rays
=
4096
00
,
)
train_dataset
.
images
=
train_dataset
.
images
.
to
(
device
)
...
...
@@ -85,7 +85,7 @@ if __name__ == "__main__":
)
test_dataset
=
SubjectLoader
(
subject_id
=
"
lego
"
,
subject_id
=
"
mic
"
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
split
=
"test"
,
num_rays
=
None
,
...
...
@@ -144,12 +144,27 @@ if __name__ == "__main__":
occ_eval_fn
=
occ_eval_fn
,
aabb
=
scene_aabb
,
resolution
=
128
).
to
(
device
)
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
# training
step
=
0
tic
=
time
.
time
()
data_time
=
0
tic_data
=
time
.
time
()
for
epoch
in
range
(
400
):
weights_image_ids
=
torch
.
ones
((
len
(
train_dataset
.
images
),),
device
=
device
)
weights_xs
=
torch
.
ones
(
(
train_dataset
.
WIDTH
,),
device
=
device
,
)
weights_ys
=
torch
.
ones
(
(
train_dataset
.
HEIGHT
,),
device
=
device
,
)
for
epoch
in
range
(
40000000
):
data
=
train_dataset
[
0
]
for
i
in
range
(
len
(
train_dataset
)):
data
=
train_dataset
[
i
]
data_time
+=
time
.
time
()
-
tic_data
...
...
@@ -162,53 +177,66 @@ if __name__ == "__main__":
pixels
=
data
[
"pixels"
].
to
(
device
)
render_bkgd
=
data
[
"color_bkgd"
].
to
(
device
)
# update occupancy grid
occ_field
.
every_n_step
(
step
)
#
#
update occupancy grid
#
occ_field.every_n_step(step)
rgb
,
depth
,
acc
,
alive_ray_mask
,
counter
,
compact_counter
=
render_image
(
radiance_field
,
rays
,
render_bkgd
render_est_n_samples
=
2
**
16
*
16
if
radiance_field
.
training
else
None
volumetric_rendering
(
query_fn
=
radiance_field
.
forward
,
# {x, dir} -> {rgb, density}
rays_o
=
rays
.
origins
,
rays_d
=
rays
.
viewdirs
,
scene_aabb
=
occ_field
.
aabb
,
scene_occ_binary
=
occ_field
.
occ_grid_binary
,
scene_resolution
=
occ_field
.
resolution
,
render_bkgd
=
render_bkgd
,
render_n_samples
=
render_n_samples
,
render_est_n_samples
=
render_est_n_samples
,
# memory control: wrost case
)
num_rays
=
len
(
pixels
)
num_rays
=
int
(
num_rays
*
(
2
**
16
/
float
(
compact_counter
)))
num_rays
=
int
(
math
.
ceil
(
num_rays
/
128.0
)
*
128
)
train_dataset
.
update_num_rays
(
num_rays
)
# compute loss
loss
=
F
.
mse_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
])
# rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
# radiance_field, rays, render_bkgd
# )
# num_rays = len(pixels)
# num_rays = int(num_rays * (2**16 / float(compact_counter)))
# num_rays = int(math.ceil(num_rays / 128.0) * 128)
# train_dataset.update_num_rays(num_rays)
# # compute loss
# loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
optimizer
.
zero_grad
()
(
loss
*
128.0
).
backward
()
optimizer
.
step
()
scheduler
.
step
()
#
optimizer.zero_grad()
#
(loss * 128.0).backward()
#
optimizer.step()
#
scheduler.step()
if
step
%
50
==
0
:
elapsed_time
=
time
.
time
()
-
tic
print
(
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s (data=
{
data_time
:.
2
f
}
s) |
{
step
=
}
| "
f
"loss=
{
loss
:.
5
f
}
| "
f
"alive_ray_mask=
{
alive_ray_mask
.
long
().
sum
():
d
}
| "
f
"counter=
{
counter
:
d
}
| compact_counter=
{
compact_counter
:
d
}
| num_rays=
{
len
(
pixels
):
d
}
"
#
f"loss={loss:.5f} | "
#
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
#
f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
)
if
step
%
35_000
==
0
and
step
>
0
:
# evaluation
radiance_field
.
eval
()
psnrs
=
[]
with
torch
.
no_grad
():
for
data
in
tqdm
.
tqdm
(
test_dataloader
):
# generate rays from data and the gt pixel color
rays
=
namedtuple_map
(
lambda
x
:
x
.
to
(
device
),
data
[
"rays"
])
pixels
=
data
[
"pixels"
].
to
(
device
)
render_bkgd
=
data
[
"color_bkgd"
].
to
(
device
)
# rendering
rgb
,
depth
,
acc
,
alive_ray_mask
,
_
,
_
=
render_image
(
radiance_field
,
rays
,
render_bkgd
)
mse
=
F
.
mse_loss
(
rgb
,
pixels
)
psnr
=
-
10.0
*
torch
.
log
(
mse
)
/
np
.
log
(
10.0
)
psnrs
.
append
(
psnr
.
item
())
psnr_avg
=
sum
(
psnrs
)
/
len
(
psnrs
)
print
(
f
"evaluation:
{
psnr_avg
=
}
"
)
#
if step % 35_000 == 0 and step > 0:
#
# evaluation
#
radiance_field.eval()
#
psnrs = []
#
with torch.no_grad():
#
for data in tqdm.tqdm(test_dataloader):
#
# generate rays from data and the gt pixel color
#
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
#
pixels = data["pixels"].to(device)
#
render_bkgd = data["color_bkgd"].to(device)
#
# rendering
#
rgb, depth, acc, alive_ray_mask, _, _ = render_image(
#
radiance_field, rays, render_bkgd
#
)
#
mse = F.mse_loss(rgb, pixels)
#
psnr = -10.0 * torch.log(mse) / np.log(10.0)
#
psnrs.append(psnr.item())
#
psnr_avg = sum(psnrs) / len(psnrs)
#
print(f"evaluation: {psnr_avg=}")
tic_data
=
time
.
time
()
step
+=
1
...
...
nerfacc/cuda/csrc/intersection.cu
View file @
96211bba
...
...
@@ -19,8 +19,8 @@ inline __host__ __device__ void _ray_aabb_intersect(
if
(
tymin
>
tymax
)
__swap
(
tymin
,
tymax
);
if
(
tmin
>
tymax
||
tymin
>
tmax
){
*
near
=
std
::
numeric_limits
<
scalar_t
>::
max
()
;
*
far
=
std
::
numeric_limits
<
scalar_t
>::
max
()
;
*
near
=
1e10
;
*
far
=
1e10
;
return
;
}
...
...
@@ -32,8 +32,8 @@ inline __host__ __device__ void _ray_aabb_intersect(
if
(
tzmin
>
tzmax
)
__swap
(
tzmin
,
tzmax
);
if
(
tmin
>
tzmax
||
tzmin
>
tmax
){
*
near
=
std
::
numeric_limits
<
scalar_t
>::
max
()
;
*
far
=
std
::
numeric_limits
<
scalar_t
>::
max
()
;
*
near
=
1e10
;
*
far
=
1e10
;
return
;
}
...
...
@@ -103,7 +103,7 @@ std::vector<torch::Tensor> ray_aabb_intersect(
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
rays_o
.
scalar_type
(),
"ray_aabb_intersect"
,
([
&
]
{
kernel_ray_aabb_intersect
<
scalar_t
><<<
blocks
,
threads
>>>
(
kernel_ray_aabb_intersect
<
scalar_t
><<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
N
,
rays_o
.
data_ptr
<
scalar_t
>
(),
rays_d
.
data_ptr
<
scalar_t
>
(),
...
...
nerfacc/cuda/csrc/pybind.cu
View file @
96211bba
...
...
@@ -16,7 +16,7 @@ std::vector<torch::Tensor> ray_marching(
const
torch
::
Tensor
t_max
,
// density grid
const
torch
::
Tensor
aabb
,
const
torch
::
Tensor
resolution
,
const
pybind11
::
list
resolution
,
const
torch
::
Tensor
occ_binary
,
// sampling
const
int
max_total_samples
,
...
...
nerfacc/cuda/csrc/ray_gen.cu
0 → 100644
View file @
96211bba
nerfacc/cuda/csrc/ray_marching.cu
View file @
96211bba
#include <pybind11/pybind11.h>
#include "include/helpers_cuda.h"
inline
__device__
int
cascaded_grid_idx_at
(
const
float
x
,
const
float
y
,
const
float
z
,
const
int
*
resolution
,
const
float
*
aabb
const
int
resx
,
const
int
resy
,
const
int
resz
,
const
float
*
aabb
)
{
// TODO(ruilongli): if the x, y, z is outside the aabb, it will be clipped into aabb!!! We should just return false
int
ix
=
(
int
)(((
x
-
aabb
[
0
])
/
(
aabb
[
3
]
-
aabb
[
0
]))
*
res
olution
[
0
]
);
int
iy
=
(
int
)(((
y
-
aabb
[
1
])
/
(
aabb
[
4
]
-
aabb
[
1
]))
*
res
olution
[
1
]
);
int
iz
=
(
int
)(((
z
-
aabb
[
2
])
/
(
aabb
[
5
]
-
aabb
[
2
]))
*
res
olution
[
2
]
);
ix
=
__clamp
(
ix
,
0
,
res
olution
[
0
]
-
1
);
iy
=
__clamp
(
iy
,
0
,
res
olution
[
1
]
-
1
);
iz
=
__clamp
(
iz
,
0
,
res
olution
[
2
]
-
1
);
int
idx
=
ix
*
res
olution
[
1
]
*
resolution
[
2
]
+
iy
*
res
olution
[
2
]
+
iz
;
int
ix
=
(
int
)(((
x
-
aabb
[
0
])
/
(
aabb
[
3
]
-
aabb
[
0
]))
*
res
x
);
int
iy
=
(
int
)(((
y
-
aabb
[
1
])
/
(
aabb
[
4
]
-
aabb
[
1
]))
*
res
y
);
int
iz
=
(
int
)(((
z
-
aabb
[
2
])
/
(
aabb
[
5
]
-
aabb
[
2
]))
*
res
z
);
ix
=
__clamp
(
ix
,
0
,
res
x
-
1
);
iy
=
__clamp
(
iy
,
0
,
res
y
-
1
);
iz
=
__clamp
(
iz
,
0
,
res
z
-
1
);
int
idx
=
ix
*
res
x
*
resy
+
iy
*
res
z
+
iz
;
return
idx
;
}
inline
__device__
bool
grid_occupied_at
(
const
float
x
,
const
float
y
,
const
float
z
,
const
int
*
resolution
,
const
float
*
aabb
,
const
bool
*
occ_binary
const
int
resx
,
const
int
resy
,
const
int
resz
,
const
float
*
aabb
,
const
bool
*
occ_binary
)
{
int
idx
=
cascaded_grid_idx_at
(
x
,
y
,
z
,
res
olution
,
aabb
);
int
idx
=
cascaded_grid_idx_at
(
x
,
y
,
z
,
res
x
,
resy
,
resz
,
aabb
);
return
occ_binary
[
idx
];
}
...
...
@@ -28,13 +31,13 @@ inline __device__ float distance_to_next_voxel(
float
x
,
float
y
,
float
z
,
float
dir_x
,
float
dir_y
,
float
dir_z
,
float
idir_x
,
float
idir_y
,
float
idir_z
,
const
int
*
res
olution
const
int
res
x
,
const
int
resy
,
const
int
resz
)
{
// dda like step
// TODO: warning: expression has no effect?
x
,
y
,
z
=
res
olution
[
0
]
*
x
,
resolution
[
1
]
*
y
,
resolution
[
2
]
*
z
;
float
tx
=
((
floorf
(
x
+
0.5
f
+
0.5
f
*
__sign
(
dir_x
))
-
x
)
*
idir_x
)
/
res
olution
[
0
]
;
float
ty
=
((
floorf
(
y
+
0.5
f
+
0.5
f
*
__sign
(
dir_y
))
-
y
)
*
idir_y
)
/
res
olution
[
1
]
;
float
tz
=
((
floorf
(
z
+
0.5
f
+
0.5
f
*
__sign
(
dir_z
))
-
z
)
*
idir_z
)
/
res
olution
[
2
]
;
x
,
y
,
z
=
res
x
*
x
,
resy
*
y
,
resz
*
z
;
float
tx
=
((
floorf
(
x
+
0.5
f
+
0.5
f
*
__sign
(
dir_x
))
-
x
)
*
idir_x
)
/
res
x
;
float
ty
=
((
floorf
(
y
+
0.5
f
+
0.5
f
*
__sign
(
dir_y
))
-
y
)
*
idir_y
)
/
res
y
;
float
tz
=
((
floorf
(
z
+
0.5
f
+
0.5
f
*
__sign
(
dir_z
))
-
z
)
*
idir_z
)
/
res
z
;
float
t
=
min
(
min
(
tx
,
ty
),
tz
);
return
fmaxf
(
t
,
0.0
f
);
}
...
...
@@ -44,10 +47,11 @@ inline __device__ float advance_to_next_voxel(
float
x
,
float
y
,
float
z
,
float
dir_x
,
float
dir_y
,
float
dir_z
,
float
idir_x
,
float
idir_y
,
float
idir_z
,
const
int
*
resolution
,
float
dt_min
)
{
const
int
resx
,
const
int
resy
,
const
int
resz
,
float
dt_min
)
{
// Regular stepping (may be slower but matches non-empty space)
float
t_target
=
t
+
distance_to_next_voxel
(
x
,
y
,
z
,
dir_x
,
dir_y
,
dir_z
,
idir_x
,
idir_y
,
idir_z
,
res
olution
x
,
y
,
z
,
dir_x
,
dir_y
,
dir_z
,
idir_x
,
idir_y
,
idir_z
,
res
x
,
resy
,
resz
);
do
{
t
+=
dt_min
;
...
...
@@ -65,7 +69,9 @@ __global__ void kernel_raymarching(
const
float
*
t_max
,
// shape (n_rays,)
// density grid
const
float
*
aabb
,
// [min_x, min_y, min_z, max_x, max_y, max_y]
const
int
*
resolution
,
// [reso_x, reso_y, reso_z]
const
int
resx
,
const
int
resy
,
const
int
resz
,
const
bool
*
occ_binary
,
// shape (reso_x, reso_y, reso_z)
// sampling
const
int
max_total_samples
,
...
...
@@ -83,102 +89,102 @@ __global__ void kernel_raymarching(
)
{
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
// locate
rays_o
+=
i
*
3
;
rays_d
+=
i
*
3
;
t_min
+=
i
;
t_max
+=
i
;
//
//
locate
//
rays_o += i * 3;
//
rays_d += i * 3;
//
t_min += i;
//
t_max += i;
const
float
ox
=
rays_o
[
0
],
oy
=
rays_o
[
1
],
oz
=
rays_o
[
2
];
const
float
dx
=
rays_d
[
0
],
dy
=
rays_d
[
1
],
dz
=
rays_d
[
2
];
const
float
rdx
=
1
/
dx
,
rdy
=
1
/
dy
,
rdz
=
1
/
dz
;
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
//
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
//
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
//
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
//
const float near = t_min[0], far = t_max[0];
uint32_t
ray_idx
,
base
,
marching_samples
;
uint32_t
j
;
float
t0
,
t1
,
t_mid
;
//
uint32_t ray_idx, base, marching_samples;
//
uint32_t j;
//
float t0, t1, t_mid;
// first pass to compute an accurate number of steps
j
=
0
;
t0
=
near
;
// TODO(ruilongli): perturb `near` as in ngp_pl?
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
//
//
first pass to compute an accurate number of steps
//
j = 0;
//
t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
//
t1 = t0 + dt;
//
t_mid = (t0 + t1) * 0.5f;
while
(
t_mid
<
far
&&
j
<
max_per_ray_samples
)
{
// current center
const
float
x
=
ox
+
t_mid
*
dx
;
const
float
y
=
oy
+
t_mid
*
dy
;
const
float
z
=
oz
+
t_mid
*
dz
;
//
while (t_mid < far && j < max_per_ray_samples) {
//
// current center
//
const float x = ox + t_mid * dx;
//
const float y = oy + t_mid * dy;
//
const float z = oz + t_mid * dz;
if
(
grid_occupied_at
(
x
,
y
,
z
,
res
olution
,
aabb
,
occ_binary
))
{
++
j
;
// march to next sample
t0
=
t1
;
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
}
else
{
// march to next sample
t_mid
=
advance_to_next_voxel
(
t_mid
,
x
,
y
,
z
,
dx
,
dy
,
dz
,
rdx
,
rdy
,
rdz
,
res
olution
,
dt
);
t0
=
t_mid
-
dt
*
0.5
f
;
t1
=
t_mid
+
dt
*
0.5
f
;
}
}
if
(
j
==
0
)
return
;
//
if (grid_occupied_at(x, y, z, res
x, resy, resz
, aabb, occ_binary)) {
//
++j;
//
// march to next sample
//
t0 = t1;
//
t1 = t0 + dt;
//
t_mid = (t0 + t1) * 0.5f;
//
}
//
else {
//
// march to next sample
//
t_mid = advance_to_next_voxel(
//
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, res
x, resy, resz
, dt
//
);
//
t0 = t_mid - dt * 0.5f;
//
t1 = t_mid + dt * 0.5f;
//
}
//
}
//
if (j == 0) return;
marching_samples
=
j
;
base
=
atomicAdd
(
steps_counter
,
marching_samples
);
if
(
base
+
marching_samples
>
max_total_samples
)
return
;
ray_idx
=
atomicAdd
(
rays_counter
,
1
);
//
marching_samples = j;
//
base = atomicAdd(steps_counter, marching_samples);
//
if (base + marching_samples > max_total_samples) return;
//
ray_idx = atomicAdd(rays_counter, 1);
// locate
frustum_origins
+=
base
*
3
;
frustum_dirs
+=
base
*
3
;
frustum_starts
+=
base
;
frustum_ends
+=
base
;
//
//
locate
//
frustum_origins += base * 3;
//
frustum_dirs += base * 3;
//
frustum_starts += base;
//
frustum_ends += base;
// Second round
j
=
0
;
t0
=
near
;
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
/
2.
;
//
//
Second round
//
j = 0;
//
t0 = near;
//
t1 = t0 + dt;
//
t_mid = (t0 + t1) / 2.;
while
(
t_mid
<
far
&&
j
<
marching_samples
)
{
// current center
const
float
x
=
ox
+
t_mid
*
dx
;
const
float
y
=
oy
+
t_mid
*
dy
;
const
float
z
=
oz
+
t_mid
*
dz
;
//
while (t_mid < far && j < marching_samples) {
//
// current center
//
const float x = ox + t_mid * dx;
//
const float y = oy + t_mid * dy;
//
const float z = oz + t_mid * dz;
if
(
grid_occupied_at
(
x
,
y
,
z
,
res
olution
,
aabb
,
occ_binary
))
{
frustum_origins
[
j
*
3
+
0
]
=
ox
;
frustum_origins
[
j
*
3
+
1
]
=
oy
;
frustum_origins
[
j
*
3
+
2
]
=
oz
;
frustum_dirs
[
j
*
3
+
0
]
=
dx
;
frustum_dirs
[
j
*
3
+
1
]
=
dy
;
frustum_dirs
[
j
*
3
+
2
]
=
dz
;
frustum_starts
[
j
]
=
t0
;
frustum_ends
[
j
]
=
t1
;
++
j
;
// march to next sample
t0
=
t1
;
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
}
else
{
// march to next sample
t_mid
=
advance_to_next_voxel
(
t_mid
,
x
,
y
,
z
,
dx
,
dy
,
dz
,
rdx
,
rdy
,
rdz
,
res
olution
,
dt
);
t0
=
t_mid
-
dt
*
0.5
f
;
t1
=
t_mid
+
dt
*
0.5
f
;
}
}
//
if (grid_occupied_at(x, y, z, res
x, resy, resz
, aabb, occ_binary)) {
//
frustum_origins[j * 3 + 0] = ox;
//
frustum_origins[j * 3 + 1] = oy;
//
frustum_origins[j * 3 + 2] = oz;
//
frustum_dirs[j * 3 + 0] = dx;
//
frustum_dirs[j * 3 + 1] = dy;
//
frustum_dirs[j * 3 + 2] = dz;
//
frustum_starts[j] = t0;
//
frustum_ends[j] = t1;
//
++j;
//
// march to next sample
//
t0 = t1;
//
t1 = t0 + dt;
//
t_mid = (t0 + t1) * 0.5f;
//
}
//
else {
//
// march to next sample
//
t_mid = advance_to_next_voxel(
//
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, res
x, resy, resz
, dt
//
);
//
t0 = t_mid - dt * 0.5f;
//
t1 = t_mid + dt * 0.5f;
//
}
//
}
packed_info
[
ray_idx
*
3
+
0
]
=
i
;
// ray idx in {rays_o, rays_d}
packed_info
[
ray_idx
*
3
+
1
]
=
base
;
// point idx start.
packed_info
[
ray_idx
*
3
+
2
]
=
j
;
// point idx shift (actual marching samples).
//
packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
//
packed_info[ray_idx * 3 + 1] = base; // point idx start.
//
packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
return
;
}
...
...
@@ -220,67 +226,69 @@ std::vector<torch::Tensor> ray_marching(
const
torch
::
Tensor
t_max
,
// density grid
const
torch
::
Tensor
aabb
,
const
torch
::
Tensor
resolution
,
const
pybind11
::
list
resolution
,
const
torch
::
Tensor
occ_binary
,
// sampling
const
int
max_total_samples
,
const
int
max_per_ray_samples
,
const
float
dt
)
{
DEVICE_GUARD
(
rays_o
);
//
DEVICE_GUARD(rays_o);
CHECK_INPUT
(
rays_o
);
CHECK_INPUT
(
rays_d
);
CHECK_INPUT
(
t_min
);
CHECK_INPUT
(
t_max
);
CHECK_INPUT
(
aabb
);
CHECK_INPUT
(
resolution
);
CHECK_INPUT
(
occ_binary
);
// CHECK_INPUT(rays_o);
// CHECK_INPUT(rays_d);
// CHECK_INPUT(t_min);
// CHECK_INPUT(t_max);
// CHECK_INPUT(aabb);
// CHECK_INPUT(occ_binary);
const
int
n_rays
=
rays_o
.
size
(
0
);
//
const int n_rays = rays_o.size(0);
const
int
threads
=
256
;
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
// //
const int threads = 256;
// //
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch
::
Tensor
steps_counter
=
torch
::
zeros
(
{
1
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
torch
::
Tensor
rays_counter
=
torch
::
zeros
(
{
1
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
//
//
helper counter
//
torch::Tensor steps_counter = torch::zeros(
//
{1}, rays_o.options().dtype(torch::kInt32));
//
torch::Tensor rays_counter = torch::zeros(
//
{1}, rays_o.options().dtype(torch::kInt32));
// output frustum samples
torch
::
Tensor
packed_info
=
torch
::
zeros
(
{
n_rays
,
3
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
// ray_id, sample_id, num_samples
torch
::
Tensor
frustum_origins
=
torch
::
zeros
({
max_total_samples
,
3
},
rays_o
.
options
());
torch
::
Tensor
frustum_dirs
=
torch
::
zeros
({
max_total_samples
,
3
},
rays_o
.
options
());
torch
::
Tensor
frustum_starts
=
torch
::
zeros
({
max_total_samples
,
1
},
rays_o
.
options
());
torch
::
Tensor
frustum_ends
=
torch
::
zeros
({
max_total_samples
,
1
},
rays_o
.
options
());
//
//
output frustum samples
//
torch::Tensor packed_info = torch::zeros(
//
{n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
//
torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
//
torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
//
torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
//
torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
kernel_raymarching
<<<
blocks
,
threads
>>>
(
// rays
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_d
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
// density grid
aabb
.
data_ptr
<
float
>
(),
resolution
.
data_ptr
<
int
>
(),
occ_binary
.
data_ptr
<
bool
>
(),
// sampling
max_total_samples
,
max_per_ray_samples
,
dt
,
// writable helpers
steps_counter
.
data_ptr
<
int
>
(),
// total samples.
rays_counter
.
data_ptr
<
int
>
(),
// total rays.
packed_info
.
data_ptr
<
int
>
(),
frustum_origins
.
data_ptr
<
float
>
(),
frustum_dirs
.
data_ptr
<
float
>
(),
frustum_starts
.
data_ptr
<
float
>
(),
frustum_ends
.
data_ptr
<
float
>
()
);
// kernel_raymarching<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// // rays
// n_rays,
// rays_o.data_ptr<float>(),
// rays_d.data_ptr<float>(),
// t_min.data_ptr<float>(),
// t_max.data_ptr<float>(),
// // density grid
// aabb.data_ptr<float>(),
// resolution[0].cast<int>(),
// resolution[1].cast<int>(),
// resolution[2].cast<int>(),
// occ_binary.data_ptr<bool>(),
// // sampling
// max_total_samples,
// max_per_ray_samples,
// dt,
// // writable helpers
// steps_counter.data_ptr<int>(), // total samples.
// rays_counter.data_ptr<int>(), // total rays.
// packed_info.data_ptr<int>(),
// frustum_origins.data_ptr<float>(),
// frustum_dirs.data_ptr<float>(),
// frustum_starts.data_ptr<float>(),
// frustum_ends.data_ptr<float>()
// );
return
{
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
,
steps_counter
};
// return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
return
{};
}
nerfacc/volumetric_rendering.py
View file @
96211bba
...
...
@@ -22,7 +22,8 @@ def volumetric_rendering(
device
=
rays_o
.
device
if
render_bkgd
is
None
:
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
scene_resolution
=
torch
.
tensor
(
scene_resolution
,
dtype
=
torch
.
int
,
device
=
device
)
# scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
rays_o
=
rays_o
.
contiguous
()
rays_d
=
rays_d
.
contiguous
()
...
...
@@ -40,18 +41,17 @@ def volumetric_rendering(
)
with
torch
.
no_grad
():
# TODO: avoid clamp here. kinda stupid
t_min
,
t_max
=
ray_aabb_intersect
(
rays_o
,
rays_d
,
scene_aabb
)
t_min
=
torch
.
clamp
(
t_min
,
max
=
1e10
)
t_max
=
torch
.
clamp
(
t_max
,
max
=
1e10
)
#
t_min = torch.clamp(t_min, max=1e10)
#
t_max = torch.clamp(t_max, max=1e10)
(
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
,
steps_counter
,
#
packed_info,
#
frustum_origins,
#
frustum_dirs,
#
frustum_starts,
#
frustum_ends,
#
steps_counter,
)
=
ray_marching
(
# rays
rays_o
,
...
...
@@ -68,43 +68,43 @@ def volumetric_rendering(
render_step_size
,
)
# squeeze valid samples
total_samples
=
max
(
packed_info
[:,
-
1
].
sum
(),
1
)
total_samples
=
int
(
math
.
ceil
(
total_samples
/
128.0
))
*
128
frustum_origins
=
frustum_origins
[:
total_samples
]
frustum_dirs
=
frustum_dirs
[:
total_samples
]
frustum_starts
=
frustum_starts
[:
total_samples
]
frustum_ends
=
frustum_ends
[:
total_samples
]
#
# squeeze valid samples
#
total_samples = max(packed_info[:, -1].sum(), 1)
#
total_samples = int(math.ceil(total_samples / 128.0)) * 128
#
frustum_origins = frustum_origins[:total_samples]
#
frustum_dirs = frustum_dirs[:total_samples]
#
frustum_starts = frustum_starts[:total_samples]
#
frustum_ends = frustum_ends[:total_samples]
frustum_positions
=
(
frustum_origins
+
frustum_dirs
*
(
frustum_starts
+
frustum_ends
)
/
2.0
)
#
frustum_positions = (
#
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
#
)
query_results
=
query_fn
(
frustum_positions
,
frustum_dirs
,
**
kwargs
)
rgbs
,
densities
=
query_results
[
0
],
query_results
[
1
]
#
query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
#
rgbs, densities = query_results[0], query_results[1]
(
accumulated_weight
,
accumulated_depth
,
accumulated_color
,
alive_ray_mask
,
compact_steps_counter
,
)
=
VolumeRenderer
.
apply
(
packed_info
,
frustum_starts
,
frustum_ends
,
densities
.
contiguous
(),
rgbs
.
contiguous
(),
)
#
(
#
accumulated_weight,
#
accumulated_depth,
#
accumulated_color,
#
alive_ray_mask,
#
compact_steps_counter,
#
) = VolumeRenderer.apply(
#
packed_info,
#
frustum_starts,
#
frustum_ends,
#
densities.contiguous(),
#
rgbs.contiguous(),
#
)
accumulated_depth
=
torch
.
clip
(
accumulated_depth
,
t_min
[:,
None
],
t_max
[:,
None
])
accumulated_color
=
accumulated_color
+
render_bkgd
*
(
1.0
-
accumulated_weight
)
#
accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
#
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
return
(
accumulated_color
,
accumulated_depth
,
accumulated_weight
,
alive_ray_mask
,
steps_counter
,
compact_steps_counter
,
)
#
return (
#
accumulated_color,
#
accumulated_depth,
#
accumulated_weight,
#
alive_ray_mask,
#
steps_counter,
#
compact_steps_counter,
#
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment