Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
16324602
Commit
16324602
authored
Sep 11, 2022
by
Ruilong Li
Browse files
benchmark
parent
96211bba
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
257 additions
and
301 deletions
+257
-301
README.md
README.md
+6
-0
examples/datasets/nerf_synthetic.py
examples/datasets/nerf_synthetic.py
+2
-4
examples/radiance_fields/ngp.py
examples/radiance_fields/ngp.py
+0
-3
examples/trainval.py
examples/trainval.py
+60
-102
nerfacc/cuda/csrc/ray_marching.cu
nerfacc/cuda/csrc/ray_marching.cu
+136
-136
nerfacc/occupancy_field.py
nerfacc/occupancy_field.py
+7
-7
nerfacc/volumetric_rendering.py
nerfacc/volumetric_rendering.py
+46
-49
No files found.
README.md
View file @
16324602
...
...
@@ -10,6 +10,12 @@ python examples/trainval.py
## Performance Reference
| trainval (35k, 1<<16) | Lego | Mic | Materials |
| - | - | - | - |
| Time | 377s | 357s | 354s |
| PSNR | 36.08 | 36.58 | 29.63 |
Tested with the default settings on the Lego test set.
| Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
...
...
examples/datasets/nerf_synthetic.py
View file @
16324602
...
...
@@ -187,9 +187,7 @@ class SubjectLoader(torch.utils.data.Dataset):
camera_dirs
=
F
.
pad
(
torch
.
stack
(
[
(
x
-
self
.
K
[
0
,
2
]
+
0.5
)
/
self
.
K
[
0
,
0
]
*
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
),
(
x
-
self
.
K
[
0
,
2
]
+
0.5
)
/
self
.
K
[
0
,
0
],
(
y
-
self
.
K
[
1
,
2
]
+
0.5
)
/
self
.
K
[
1
,
1
]
*
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
),
...
...
@@ -197,7 +195,7 @@ class SubjectLoader(torch.utils.data.Dataset):
dim
=-
1
,
),
(
0
,
1
),
value
=
1
,
value
=
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
)
,
)
# [num_rays, 3]
# [n_cams, height, width, 3]
...
...
examples/radiance_fields/ngp.py
View file @
16324602
...
...
@@ -98,7 +98,6 @@ class NGPradianceField(BaseRadianceField):
},
)
@
torch
.
cuda
.
amp
.
autocast
()
def
query_density
(
self
,
x
,
return_feat
:
bool
=
False
):
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
x
=
(
x
-
bb_min
)
/
(
bb_max
-
bb_min
)
...
...
@@ -119,7 +118,6 @@ class NGPradianceField(BaseRadianceField):
else
:
return
density
@
torch
.
cuda
.
amp
.
autocast
()
def
_query_rgb
(
self
,
dir
,
embedding
):
# tcnn requires directions in the range [0, 1]
if
self
.
use_viewdirs
:
...
...
@@ -131,7 +129,6 @@ class NGPradianceField(BaseRadianceField):
rgb
=
self
.
mlp_head
(
h
).
view
(
list
(
embedding
.
shape
[:
-
1
])
+
[
3
]).
to
(
embedding
)
return
rgb
@
torch
.
cuda
.
amp
.
autocast
()
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
examples/trainval.py
View file @
16324602
...
...
@@ -5,13 +5,15 @@ import numpy as np
import
torch
import
torch.nn.functional
as
F
import
tqdm
from
datasets.nerf_synthetic
import
Rays
,
SubjectLoader
,
namedtuple_map
from
datasets.nerf_synthetic
import
SubjectLoader
,
namedtuple_map
from
radiance_fields.ngp
import
NGPradianceField
from
nerfacc
import
OccupancyField
,
volumetric_rendering
TARGET_SAMPLE_BATCH_SIZE
=
1
<<
16
def
render_image
(
radiance_field
,
rays
,
render_bkgd
):
def
render_image
(
radiance_field
,
rays
,
render_bkgd
,
render_step_size
):
"""Render the pixels of an image.
Args:
...
...
@@ -32,7 +34,9 @@ def render_image(radiance_field, rays, render_bkgd):
num_rays
,
_
=
rays_shape
results
=
[]
chunk
=
torch
.
iinfo
(
torch
.
int32
).
max
if
radiance_field
.
training
else
81920
render_est_n_samples
=
2
**
16
*
16
if
radiance_field
.
training
else
None
render_est_n_samples
=
(
TARGET_SAMPLE_BATCH_SIZE
*
16
if
radiance_field
.
training
else
None
)
for
i
in
range
(
0
,
num_rays
,
chunk
):
chunk_rays
=
namedtuple_map
(
lambda
r
:
r
[
i
:
i
+
chunk
],
rays
)
chunk_results
=
volumetric_rendering
(
...
...
@@ -45,6 +49,7 @@ def render_image(radiance_field, rays, render_bkgd):
render_bkgd
=
render_bkgd
,
render_n_samples
=
render_n_samples
,
render_est_n_samples
=
render_est_n_samples
,
# memory control: wrost case
render_step_size
=
render_step_size
,
)
results
.
append
(
chunk_results
)
rgb
,
depth
,
acc
,
alive_ray_mask
,
counter
,
compact_counter
=
[
...
...
@@ -64,13 +69,14 @@ if __name__ == "__main__":
torch
.
manual_seed
(
42
)
device
=
"cuda:0"
scene
=
"lego"
# setup dataset
train_dataset
=
SubjectLoader
(
subject_id
=
"mic"
,
subject_id
=
scene
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
split
=
"trainval"
,
num_rays
=
4096
00
,
num_rays
=
4096
,
)
train_dataset
.
images
=
train_dataset
.
images
.
to
(
device
)
...
...
@@ -85,7 +91,7 @@ if __name__ == "__main__":
)
test_dataset
=
SubjectLoader
(
subject_id
=
"mic"
,
subject_id
=
scene
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
split
=
"test"
,
num_rays
=
None
,
...
...
@@ -112,7 +118,7 @@ if __name__ == "__main__":
render_n_samples
=
1024
render_step_size
=
(
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
)
)
.
item
()
optimizer
=
torch
.
optim
.
Adam
(
radiance_field
.
parameters
(),
...
...
@@ -144,123 +150,75 @@ if __name__ == "__main__":
occ_eval_fn
=
occ_eval_fn
,
aabb
=
scene_aabb
,
resolution
=
128
).
to
(
device
)
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
# training
step
=
0
tic
=
time
.
time
()
data_time
=
0
tic_data
=
time
.
time
()
weights_image_ids
=
torch
.
ones
((
len
(
train_dataset
.
images
),),
device
=
device
)
weights_xs
=
torch
.
ones
(
(
train_dataset
.
WIDTH
,),
device
=
device
,
)
weights_ys
=
torch
.
ones
(
(
train_dataset
.
HEIGHT
,),
device
=
device
,
)
for
epoch
in
range
(
40000000
):
data
=
train_dataset
[
0
]
for
epoch
in
range
(
10000000
):
for
i
in
range
(
len
(
train_dataset
)):
data
=
train_dataset
[
i
]
data_time
+=
time
.
time
()
-
tic_data
if
step
>
35_000
:
print
(
"training stops"
)
exit
()
# generate rays from data and the gt pixel color
rays
=
namedtuple_map
(
lambda
x
:
x
.
to
(
device
),
data
[
"rays"
])
pixels
=
data
[
"pixels"
].
to
(
device
)
render_bkgd
=
data
[
"color_bkgd"
].
to
(
device
)
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
# pixels = data["pixels"].to(device)
render_bkgd
=
data
[
"color_bkgd"
]
rays
=
data
[
"rays"
]
pixels
=
data
[
"pixels"
]
#
#
update occupancy grid
#
occ_field.every_n_step(step)
# update occupancy grid
occ_field
.
every_n_step
(
step
)
render_est_n_samples
=
2
**
16
*
16
if
radiance_field
.
training
else
None
volumetric_rendering
(
query_fn
=
radiance_field
.
forward
,
# {x, dir} -> {rgb, density}
rays_o
=
rays
.
origins
,
rays_d
=
rays
.
viewdirs
,
scene_aabb
=
occ_field
.
aabb
,
scene_occ_binary
=
occ_field
.
occ_grid_binary
,
scene_resolution
=
occ_field
.
resolution
,
render_bkgd
=
render_bkgd
,
render_n_samples
=
render_n_samples
,
render_est_n_samples
=
render_est_n_samples
,
# memory control: wrost case
rgb
,
depth
,
acc
,
alive_ray_mask
,
counter
,
compact_counter
=
render_image
(
radiance_field
,
rays
,
render_bkgd
,
render_step_size
)
num_rays
=
len
(
pixels
)
num_rays
=
int
(
num_rays
*
(
TARGET_SAMPLE_BATCH_SIZE
/
float
(
compact_counter
.
item
()))
)
train_dataset
.
update_num_rays
(
num_rays
)
# rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
# radiance_field, rays, render_bkgd
# )
# num_rays = len(pixels)
# num_rays = int(num_rays * (2**16 / float(compact_counter)))
# num_rays = int(math.ceil(num_rays / 128.0) * 128)
# train_dataset.update_num_rays(num_rays)
# # compute loss
# loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
# compute loss
loss
=
F
.
mse_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
])
#
optimizer.zero_grad()
#
(loss * 128
.0
).backward()
#
optimizer.step()
#
scheduler.step()
optimizer
.
zero_grad
()
(
loss
*
128
).
backward
()
optimizer
.
step
()
scheduler
.
step
()
if
step
%
5
0
==
0
:
if
step
%
10
0
==
0
:
elapsed_time
=
time
.
time
()
-
tic
print
(
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s (data=
{
data_time
:.
2
f
}
s) |
{
step
=
}
| "
#
f"loss={loss:.5f} | "
#
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
#
f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
f
"loss=
{
loss
:.
5
f
}
| "
f
"alive_ray_mask=
{
alive_ray_mask
.
long
().
sum
():
d
}
| "
f
"counter=
{
counter
.
item
()
:
d
}
| compact_counter=
{
compact_counter
.
item
()
:
d
}
| num_rays=
{
len
(
pixels
):
d
}
"
)
# if step % 35_000 == 0 and step > 0:
# # evaluation
# radiance_field.eval()
# psnrs = []
# with torch.no_grad():
# for data in tqdm.tqdm(test_dataloader):
# # generate rays from data and the gt pixel color
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
# pixels = data["pixels"].to(device)
# render_bkgd = data["color_bkgd"].to(device)
# # rendering
# rgb, depth, acc, alive_ray_mask, _, _ = render_image(
# radiance_field, rays, render_bkgd
# )
# mse = F.mse_loss(rgb, pixels)
# psnr = -10.0 * torch.log(mse) / np.log(10.0)
# psnrs.append(psnr.item())
# psnr_avg = sum(psnrs) / len(psnrs)
# print(f"evaluation: {psnr_avg=}")
# if time.time() - tic > 300:
if
step
==
35_000
:
print
(
"training stops"
)
# evaluation
radiance_field
.
eval
()
psnrs
=
[]
with
torch
.
no_grad
():
for
data
in
tqdm
.
tqdm
(
test_dataloader
):
# generate rays from data and the gt pixel color
rays
=
namedtuple_map
(
lambda
x
:
x
.
to
(
device
),
data
[
"rays"
])
pixels
=
data
[
"pixels"
].
to
(
device
)
render_bkgd
=
data
[
"color_bkgd"
].
to
(
device
)
# rendering
rgb
,
depth
,
acc
,
alive_ray_mask
,
_
,
_
=
render_image
(
radiance_field
,
rays
,
render_bkgd
,
render_step_size
)
mse
=
F
.
mse_loss
(
rgb
,
pixels
)
psnr
=
-
10.0
*
torch
.
log
(
mse
)
/
np
.
log
(
10.0
)
psnrs
.
append
(
psnr
.
item
())
psnr_avg
=
sum
(
psnrs
)
/
len
(
psnrs
)
print
(
f
"evaluation:
{
psnr_avg
=
}
"
)
exit
()
tic_data
=
time
.
time
()
step
+=
1
# "train"
# elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
# evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
# "train" batch_over_images=True
# elapsed_time=335.21s (data=68.99s) | step=30000 | loss=0.00028
# evaluation: psnr_avg=33.74970862388611 (6.23 it/s)
# "train" batch_over_images=True, schedule
# elapsed_time=296.30s (data=54.38s) | step=30000 | loss=0.00022
# evaluation: psnr_avg=34.3978275680542 (6.22 it/s)
# "trainval"
# elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
# evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
# "trainval" batch_over_images=True, schedule
# elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
# evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
# "trainval" batch_over_images=True, schedule 2**18
# evaluation: psnr_avg=36.24 (6.75 it/s)
nerfacc/cuda/csrc/ray_marching.cu
View file @
16324602
...
...
@@ -14,7 +14,8 @@ inline __device__ int cascaded_grid_idx_at(
ix
=
__clamp
(
ix
,
0
,
resx
-
1
);
iy
=
__clamp
(
iy
,
0
,
resy
-
1
);
iz
=
__clamp
(
iz
,
0
,
resz
-
1
);
int
idx
=
ix
*
resx
*
resy
+
iy
*
resz
+
iz
;
int
idx
=
ix
*
resy
*
resz
+
iy
*
resz
+
iz
;
// printf("(ix, iy, iz) = (%d, %d, %d)\n", ix, iy, iz);
return
idx
;
}
...
...
@@ -89,102 +90,102 @@ __global__ void kernel_raymarching(
)
{
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
//
//
locate
//
rays_o += i * 3;
//
rays_d += i * 3;
//
t_min += i;
//
t_max += i;
// locate
rays_o
+=
i
*
3
;
rays_d
+=
i
*
3
;
t_min
+=
i
;
t_max
+=
i
;
//
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
//
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
//
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
//
const float near = t_min[0], far = t_max[0];
const
float
ox
=
rays_o
[
0
],
oy
=
rays_o
[
1
],
oz
=
rays_o
[
2
];
const
float
dx
=
rays_d
[
0
],
dy
=
rays_d
[
1
],
dz
=
rays_d
[
2
];
const
float
rdx
=
1
/
dx
,
rdy
=
1
/
dy
,
rdz
=
1
/
dz
;
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
//
uint32_t ray_idx, base, marching_samples;
//
uint32_t j;
//
float t0, t1, t_mid;
uint32_t
ray_idx
,
base
,
marching_samples
;
uint32_t
j
;
float
t0
,
t1
,
t_mid
;
//
//
first pass to compute an accurate number of steps
//
j = 0;
//
t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
//
t1 = t0 + dt;
//
t_mid = (t0 + t1) * 0.5f;
// first pass to compute an accurate number of steps
j
=
0
;
t0
=
near
;
// TODO(ruilongli): perturb `near` as in ngp_pl?
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
//
while (t_mid < far && j < max_per_ray_samples) {
//
// current center
//
const float x = ox + t_mid * dx;
//
const float y = oy + t_mid * dy;
//
const float z = oz + t_mid * dz;
while
(
t_mid
<
far
&&
j
<
max_per_ray_samples
)
{
// current center
const
float
x
=
ox
+
t_mid
*
dx
;
const
float
y
=
oy
+
t_mid
*
dy
;
const
float
z
=
oz
+
t_mid
*
dz
;
//
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
//
++j;
//
// march to next sample
//
t0 = t1;
//
t1 = t0 + dt;
//
t_mid = (t0 + t1) * 0.5f;
//
}
//
else {
//
// march to next sample
//
t_mid = advance_to_next_voxel(
//
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
//
);
//
t0 = t_mid - dt * 0.5f;
//
t1 = t_mid + dt * 0.5f;
//
}
//
}
//
if (j == 0) return;
if
(
grid_occupied_at
(
x
,
y
,
z
,
resx
,
resy
,
resz
,
aabb
,
occ_binary
))
{
++
j
;
// march to next sample
t0
=
t1
;
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
}
else
{
// march to next sample
t_mid
=
advance_to_next_voxel
(
t_mid
,
x
,
y
,
z
,
dx
,
dy
,
dz
,
rdx
,
rdy
,
rdz
,
resx
,
resy
,
resz
,
dt
);
t0
=
t_mid
-
dt
*
0.5
f
;
t1
=
t_mid
+
dt
*
0.5
f
;
}
}
if
(
j
==
0
)
return
;
//
marching_samples = j;
//
base = atomicAdd(steps_counter, marching_samples);
//
if (base + marching_samples > max_total_samples) return;
//
ray_idx = atomicAdd(rays_counter, 1);
marching_samples
=
j
;
base
=
atomicAdd
(
steps_counter
,
marching_samples
);
if
(
base
+
marching_samples
>
max_total_samples
)
return
;
ray_idx
=
atomicAdd
(
rays_counter
,
1
);
//
//
locate
//
frustum_origins += base * 3;
//
frustum_dirs += base * 3;
//
frustum_starts += base;
//
frustum_ends += base;
// locate
frustum_origins
+=
base
*
3
;
frustum_dirs
+=
base
*
3
;
frustum_starts
+=
base
;
frustum_ends
+=
base
;
//
//
Second round
//
j = 0;
//
t0 = near;
//
t1 = t0 + dt;
//
t_mid = (t0 + t1) / 2.;
// Second round
j
=
0
;
t0
=
near
;
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
/
2.
;
//
while (t_mid < far && j < marching_samples) {
//
// current center
//
const float x = ox + t_mid * dx;
//
const float y = oy + t_mid * dy;
//
const float z = oz + t_mid * dz;
while
(
t_mid
<
far
&&
j
<
marching_samples
)
{
// current center
const
float
x
=
ox
+
t_mid
*
dx
;
const
float
y
=
oy
+
t_mid
*
dy
;
const
float
z
=
oz
+
t_mid
*
dz
;
//
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
//
frustum_origins[j * 3 + 0] = ox;
//
frustum_origins[j * 3 + 1] = oy;
//
frustum_origins[j * 3 + 2] = oz;
//
frustum_dirs[j * 3 + 0] = dx;
//
frustum_dirs[j * 3 + 1] = dy;
//
frustum_dirs[j * 3 + 2] = dz;
//
frustum_starts[j] = t0;
//
frustum_ends[j] = t1;
//
++j;
//
// march to next sample
//
t0 = t1;
//
t1 = t0 + dt;
//
t_mid = (t0 + t1) * 0.5f;
//
}
//
else {
//
// march to next sample
//
t_mid = advance_to_next_voxel(
//
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
//
);
//
t0 = t_mid - dt * 0.5f;
//
t1 = t_mid + dt * 0.5f;
//
}
//
}
if
(
grid_occupied_at
(
x
,
y
,
z
,
resx
,
resy
,
resz
,
aabb
,
occ_binary
))
{
frustum_origins
[
j
*
3
+
0
]
=
ox
;
frustum_origins
[
j
*
3
+
1
]
=
oy
;
frustum_origins
[
j
*
3
+
2
]
=
oz
;
frustum_dirs
[
j
*
3
+
0
]
=
dx
;
frustum_dirs
[
j
*
3
+
1
]
=
dy
;
frustum_dirs
[
j
*
3
+
2
]
=
dz
;
frustum_starts
[
j
]
=
t0
;
frustum_ends
[
j
]
=
t1
;
++
j
;
// march to next sample
t0
=
t1
;
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
}
else
{
// march to next sample
t_mid
=
advance_to_next_voxel
(
t_mid
,
x
,
y
,
z
,
dx
,
dy
,
dz
,
rdx
,
rdy
,
rdz
,
resx
,
resy
,
resz
,
dt
);
t0
=
t_mid
-
dt
*
0.5
f
;
t1
=
t_mid
+
dt
*
0.5
f
;
}
}
//
packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
//
packed_info[ray_idx * 3 + 1] = base; // point idx start.
//
packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
packed_info
[
ray_idx
*
3
+
0
]
=
i
;
// ray idx in {rays_o, rays_d}
packed_info
[
ray_idx
*
3
+
1
]
=
base
;
// point idx start.
packed_info
[
ray_idx
*
3
+
2
]
=
j
;
// point idx shift (actual marching samples).
return
;
}
...
...
@@ -233,62 +234,61 @@ std::vector<torch::Tensor> ray_marching(
const
int
max_per_ray_samples
,
const
float
dt
)
{
//
DEVICE_GUARD(rays_o);
DEVICE_GUARD
(
rays_o
);
//
CHECK_INPUT(rays_o);
//
CHECK_INPUT(rays_d);
//
CHECK_INPUT(t_min);
//
CHECK_INPUT(t_max);
//
CHECK_INPUT(aabb);
//
CHECK_INPUT(occ_binary);
CHECK_INPUT
(
rays_o
);
CHECK_INPUT
(
rays_d
);
CHECK_INPUT
(
t_min
);
CHECK_INPUT
(
t_max
);
CHECK_INPUT
(
aabb
);
CHECK_INPUT
(
occ_binary
);
//
const int n_rays = rays_o.size(0);
const
int
n_rays
=
rays_o
.
size
(
0
);
// //
const int threads = 256;
// //
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
const
int
threads
=
256
;
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
//
//
helper counter
//
torch::Tensor steps_counter = torch::zeros(
//
{1}, rays_o.options().dtype(torch::kInt32));
//
torch::Tensor rays_counter = torch::zeros(
//
{1}, rays_o.options().dtype(torch::kInt32));
// helper counter
torch
::
Tensor
steps_counter
=
torch
::
zeros
(
{
1
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
torch
::
Tensor
rays_counter
=
torch
::
zeros
(
{
1
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
//
//
output frustum samples
//
torch::Tensor packed_info = torch::zeros(
//
{n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
//
torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
//
torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
//
torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
//
torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
// output frustum samples
torch
::
Tensor
packed_info
=
torch
::
zeros
(
{
n_rays
,
3
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
// ray_id, sample_id, num_samples
torch
::
Tensor
frustum_origins
=
torch
::
zeros
({
max_total_samples
,
3
},
rays_o
.
options
());
torch
::
Tensor
frustum_dirs
=
torch
::
zeros
({
max_total_samples
,
3
},
rays_o
.
options
());
torch
::
Tensor
frustum_starts
=
torch
::
zeros
({
max_total_samples
,
1
},
rays_o
.
options
());
torch
::
Tensor
frustum_ends
=
torch
::
zeros
({
max_total_samples
,
1
},
rays_o
.
options
());
//
kernel_raymarching<<<blocks, threads
, 0, at::cuda::getCurrentCUDAStream()
>>>(
//
// rays
//
n_rays,
//
rays_o.data_ptr<float>(),
//
rays_d.data_ptr<float>(),
//
t_min.data_ptr<float>(),
//
t_max.data_ptr<float>(),
//
// density grid
//
aabb.data_ptr<float>(),
//
resolution[0].cast<int>(),
//
resolution[1].cast<int>(),
//
resolution[2].cast<int>(),
//
occ_binary.data_ptr<bool>(),
//
// sampling
//
max_total_samples,
//
max_per_ray_samples,
//
dt,
//
// writable helpers
//
steps_counter.data_ptr<int>(), // total samples.
//
rays_counter.data_ptr<int>(), // total rays.
//
packed_info.data_ptr<int>(),
//
frustum_origins.data_ptr<float>(),
//
frustum_dirs.data_ptr<float>(),
//
frustum_starts.data_ptr<float>(),
//
frustum_ends.data_ptr<float>()
//
);
kernel_raymarching
<<<
blocks
,
threads
>>>
(
// rays
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_d
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
// density grid
aabb
.
data_ptr
<
float
>
(),
resolution
[
0
].
cast
<
int
>
(),
resolution
[
1
].
cast
<
int
>
(),
resolution
[
2
].
cast
<
int
>
(),
occ_binary
.
data_ptr
<
bool
>
(),
// sampling
max_total_samples
,
max_per_ray_samples
,
dt
,
// writable helpers
steps_counter
.
data_ptr
<
int
>
(),
// total samples.
rays_counter
.
data_ptr
<
int
>
(),
// total rays.
packed_info
.
data_ptr
<
int
>
(),
frustum_origins
.
data_ptr
<
float
>
(),
frustum_dirs
.
data_ptr
<
float
>
(),
frustum_starts
.
data_ptr
<
float
>
(),
frustum_ends
.
data_ptr
<
float
>
()
);
// return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
return
{};
return
{
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
,
steps_counter
};
}
nerfacc/occupancy_field.py
View file @
16324602
...
...
@@ -72,6 +72,7 @@ class OccupancyField(nn.Module):
self
.
register_buffer
(
"aabb"
,
aabb
)
self
.
resolution
=
resolution
self
.
register_buffer
(
"resolution_tensor"
,
torch
.
tensor
(
resolution
))
self
.
num_dim
=
num_dim
self
.
num_cells
=
torch
.
tensor
(
resolution
).
prod
().
item
()
...
...
@@ -107,7 +108,6 @@ class OccupancyField(nn.Module):
if
n
<
len
(
occupied_indices
):
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
device
)
occupied_indices
=
occupied_indices
[
selector
]
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
return
indices
...
...
@@ -129,19 +129,19 @@ class OccupancyField(nn.Module):
stage we change the sampling strategy to 1/4 unifromly sampled cells
together with 1/4 occupied cells.
"""
resolution
=
torch
.
tensor
(
self
.
resolution
).
to
(
self
.
occ_grid
.
device
)
# sample cells
if
step
<
warmup_steps
:
indices
=
self
.
_get_all_cells
()
else
:
N
=
resolution
.
prod
().
item
()
//
4
N
=
self
.
num_cells
//
4
indices
=
self
.
_sample_uniform_and_occupied_cells
(
N
)
# infer occupancy: density * step_size
tmp_occ_grid
=
-
torch
.
ones_like
(
self
.
occ_grid
)
grid_coords
=
self
.
grid_coords
[
indices
]
x
=
(
grid_coords
+
torch
.
rand_like
(
grid_coords
.
float
()))
/
resolution
x
=
(
grid_coords
+
torch
.
rand_like
(
grid_coords
.
float
())
)
/
self
.
resolution_tensor
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
x
=
x
*
(
bb_max
-
bb_min
)
+
bb_min
tmp_occ_grid
[
indices
]
=
self
.
occ_eval_fn
(
x
).
squeeze
(
-
1
)
...
...
@@ -152,8 +152,8 @@ class OccupancyField(nn.Module):
self
.
occ_grid
[
ema_mask
]
*
ema_decay
,
tmp_occ_grid
[
ema_mask
]
)
self
.
occ_grid_mean
=
self
.
occ_grid
.
mean
()
self
.
occ_grid_binary
=
self
.
occ_grid
>
min
(
self
.
occ_grid_mean
.
item
(),
occ_threshold
self
.
occ_grid_binary
=
self
.
occ_grid
>
torch
.
clamp
(
self
.
occ_grid_mean
,
max
=
occ_threshold
)
@
torch
.
no_grad
()
...
...
nerfacc/volumetric_rendering.py
View file @
16324602
...
...
@@ -16,6 +16,7 @@ def volumetric_rendering(
render_bkgd
:
torch
.
Tensor
=
None
,
render_n_samples
:
int
=
1024
,
render_est_n_samples
:
int
=
None
,
render_step_size
:
int
=
None
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""A *fast* version of differentiable volumetric rendering."""
...
...
@@ -23,8 +24,6 @@ def volumetric_rendering(
if
render_bkgd
is
None
:
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
# scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
rays_o
=
rays_o
.
contiguous
()
rays_d
=
rays_d
.
contiguous
()
scene_aabb
=
scene_aabb
.
contiguous
()
...
...
@@ -36,22 +35,22 @@ def volumetric_rendering(
render_total_samples
=
n_rays
*
render_n_samples
else
:
render_total_samples
=
render_est_n_samples
render_step_size
=
(
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
)
if
render_step_size
is
None
:
# Note: CPU<->GPU is not idea, try to pre-define it outside this function.
render_step_size
=
(
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
)
with
torch
.
no_grad
():
t_min
,
t_max
=
ray_aabb_intersect
(
rays_o
,
rays_d
,
scene_aabb
)
# t_min = torch.clamp(t_min, max=1e10)
# t_max = torch.clamp(t_max, max=1e10)
(
#
packed_info,
#
frustum_origins,
#
frustum_dirs,
#
frustum_starts,
#
frustum_ends,
#
steps_counter,
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
,
steps_counter
,
)
=
ray_marching
(
# rays
rays_o
,
...
...
@@ -68,43 +67,41 @@ def volumetric_rendering(
render_step_size
,
)
# # squeeze valid samples
# total_samples = max(packed_info[:, -1].sum(), 1)
# total_samples = int(math.ceil(total_samples / 128.0)) * 128
# frustum_origins = frustum_origins[:total_samples]
# frustum_dirs = frustum_dirs[:total_samples]
# frustum_starts = frustum_starts[:total_samples]
# frustum_ends = frustum_ends[:total_samples]
# frustum_positions = (
# frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
# )
# squeeze valid samples
total_samples
=
max
(
packed_info
[:,
-
1
].
sum
(),
1
)
frustum_origins
=
frustum_origins
[:
total_samples
]
frustum_dirs
=
frustum_dirs
[:
total_samples
]
frustum_starts
=
frustum_starts
[:
total_samples
]
frustum_ends
=
frustum_ends
[:
total_samples
]
# query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
# rgbs, densities = query_results[0], query_results[1]
frustum_positions
=
(
frustum_origins
+
frustum_dirs
*
(
frustum_starts
+
frustum_ends
)
/
2.0
)
# (
# accumulated_weight,
# accumulated_depth,
# accumulated_color,
# alive_ray_mask,
# compact_steps_counter,
# ) = VolumeRenderer.apply(
# packed_info,
# frustum_starts,
# frustum_ends,
# densities.contiguous(),
# rgbs.contiguous(),
# )
query_results
=
query_fn
(
frustum_positions
,
frustum_dirs
,
**
kwargs
)
rgbs
,
densities
=
query_results
[
0
],
query_results
[
1
]
(
accumulated_weight
,
accumulated_depth
,
accumulated_color
,
alive_ray_mask
,
compact_steps_counter
,
)
=
VolumeRenderer
.
apply
(
packed_info
,
frustum_starts
,
frustum_ends
,
densities
.
contiguous
(),
rgbs
.
contiguous
(),
)
#
accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
#
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
accumulated_depth
=
torch
.
clip
(
accumulated_depth
,
t_min
[:,
None
],
t_max
[:,
None
])
accumulated_color
=
accumulated_color
+
render_bkgd
*
(
1.0
-
accumulated_weight
)
#
return (
#
accumulated_color,
#
accumulated_depth,
#
accumulated_weight,
#
alive_ray_mask,
#
steps_counter,
#
compact_steps_counter,
#
)
return
(
accumulated_color
,
accumulated_depth
,
accumulated_weight
,
alive_ray_mask
,
steps_counter
,
compact_steps_counter
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment