Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
16324602
Commit
16324602
authored
Sep 11, 2022
by
Ruilong Li
Browse files
benchmark
parent
96211bba
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
257 additions
and
301 deletions
+257
-301
README.md
README.md
+6
-0
examples/datasets/nerf_synthetic.py
examples/datasets/nerf_synthetic.py
+2
-4
examples/radiance_fields/ngp.py
examples/radiance_fields/ngp.py
+0
-3
examples/trainval.py
examples/trainval.py
+60
-102
nerfacc/cuda/csrc/ray_marching.cu
nerfacc/cuda/csrc/ray_marching.cu
+136
-136
nerfacc/occupancy_field.py
nerfacc/occupancy_field.py
+7
-7
nerfacc/volumetric_rendering.py
nerfacc/volumetric_rendering.py
+46
-49
No files found.
README.md
View file @
16324602
...
@@ -10,6 +10,12 @@ python examples/trainval.py
...
@@ -10,6 +10,12 @@ python examples/trainval.py
## Performance Reference
## Performance Reference
| trainval (35k, 1<<16) | Lego | Mic | Materials |
| - | - | - | - |
| Time | 377s | 357s | 354s |
| PSNR | 36.08 | 36.58 | 29.63 |
Tested with the default settings on the Lego test set.
Tested with the default settings on the Lego test set.
| Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
| Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
...
...
examples/datasets/nerf_synthetic.py
View file @
16324602
...
@@ -187,9 +187,7 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -187,9 +187,7 @@ class SubjectLoader(torch.utils.data.Dataset):
camera_dirs
=
F
.
pad
(
camera_dirs
=
F
.
pad
(
torch
.
stack
(
torch
.
stack
(
[
[
(
x
-
self
.
K
[
0
,
2
]
+
0.5
)
(
x
-
self
.
K
[
0
,
2
]
+
0.5
)
/
self
.
K
[
0
,
0
],
/
self
.
K
[
0
,
0
]
*
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
),
(
y
-
self
.
K
[
1
,
2
]
+
0.5
)
(
y
-
self
.
K
[
1
,
2
]
+
0.5
)
/
self
.
K
[
1
,
1
]
/
self
.
K
[
1
,
1
]
*
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
),
*
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
),
...
@@ -197,7 +195,7 @@ class SubjectLoader(torch.utils.data.Dataset):
...
@@ -197,7 +195,7 @@ class SubjectLoader(torch.utils.data.Dataset):
dim
=-
1
,
dim
=-
1
,
),
),
(
0
,
1
),
(
0
,
1
),
value
=
1
,
value
=
(
-
1.0
if
self
.
OPENGL_CAMERA
else
1.0
)
,
)
# [num_rays, 3]
)
# [num_rays, 3]
# [n_cams, height, width, 3]
# [n_cams, height, width, 3]
...
...
examples/radiance_fields/ngp.py
View file @
16324602
...
@@ -98,7 +98,6 @@ class NGPradianceField(BaseRadianceField):
...
@@ -98,7 +98,6 @@ class NGPradianceField(BaseRadianceField):
},
},
)
)
@
torch
.
cuda
.
amp
.
autocast
()
def
query_density
(
self
,
x
,
return_feat
:
bool
=
False
):
def
query_density
(
self
,
x
,
return_feat
:
bool
=
False
):
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
x
=
(
x
-
bb_min
)
/
(
bb_max
-
bb_min
)
x
=
(
x
-
bb_min
)
/
(
bb_max
-
bb_min
)
...
@@ -119,7 +118,6 @@ class NGPradianceField(BaseRadianceField):
...
@@ -119,7 +118,6 @@ class NGPradianceField(BaseRadianceField):
else
:
else
:
return
density
return
density
@
torch
.
cuda
.
amp
.
autocast
()
def
_query_rgb
(
self
,
dir
,
embedding
):
def
_query_rgb
(
self
,
dir
,
embedding
):
# tcnn requires directions in the range [0, 1]
# tcnn requires directions in the range [0, 1]
if
self
.
use_viewdirs
:
if
self
.
use_viewdirs
:
...
@@ -131,7 +129,6 @@ class NGPradianceField(BaseRadianceField):
...
@@ -131,7 +129,6 @@ class NGPradianceField(BaseRadianceField):
rgb
=
self
.
mlp_head
(
h
).
view
(
list
(
embedding
.
shape
[:
-
1
])
+
[
3
]).
to
(
embedding
)
rgb
=
self
.
mlp_head
(
h
).
view
(
list
(
embedding
.
shape
[:
-
1
])
+
[
3
]).
to
(
embedding
)
return
rgb
return
rgb
@
torch
.
cuda
.
amp
.
autocast
()
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
...
examples/trainval.py
View file @
16324602
...
@@ -5,13 +5,15 @@ import numpy as np
...
@@ -5,13 +5,15 @@ import numpy as np
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
tqdm
import
tqdm
from
datasets.nerf_synthetic
import
Rays
,
SubjectLoader
,
namedtuple_map
from
datasets.nerf_synthetic
import
SubjectLoader
,
namedtuple_map
from
radiance_fields.ngp
import
NGPradianceField
from
radiance_fields.ngp
import
NGPradianceField
from
nerfacc
import
OccupancyField
,
volumetric_rendering
from
nerfacc
import
OccupancyField
,
volumetric_rendering
TARGET_SAMPLE_BATCH_SIZE
=
1
<<
16
def
render_image
(
radiance_field
,
rays
,
render_bkgd
):
def
render_image
(
radiance_field
,
rays
,
render_bkgd
,
render_step_size
):
"""Render the pixels of an image.
"""Render the pixels of an image.
Args:
Args:
...
@@ -32,7 +34,9 @@ def render_image(radiance_field, rays, render_bkgd):
...
@@ -32,7 +34,9 @@ def render_image(radiance_field, rays, render_bkgd):
num_rays
,
_
=
rays_shape
num_rays
,
_
=
rays_shape
results
=
[]
results
=
[]
chunk
=
torch
.
iinfo
(
torch
.
int32
).
max
if
radiance_field
.
training
else
81920
chunk
=
torch
.
iinfo
(
torch
.
int32
).
max
if
radiance_field
.
training
else
81920
render_est_n_samples
=
2
**
16
*
16
if
radiance_field
.
training
else
None
render_est_n_samples
=
(
TARGET_SAMPLE_BATCH_SIZE
*
16
if
radiance_field
.
training
else
None
)
for
i
in
range
(
0
,
num_rays
,
chunk
):
for
i
in
range
(
0
,
num_rays
,
chunk
):
chunk_rays
=
namedtuple_map
(
lambda
r
:
r
[
i
:
i
+
chunk
],
rays
)
chunk_rays
=
namedtuple_map
(
lambda
r
:
r
[
i
:
i
+
chunk
],
rays
)
chunk_results
=
volumetric_rendering
(
chunk_results
=
volumetric_rendering
(
...
@@ -45,6 +49,7 @@ def render_image(radiance_field, rays, render_bkgd):
...
@@ -45,6 +49,7 @@ def render_image(radiance_field, rays, render_bkgd):
render_bkgd
=
render_bkgd
,
render_bkgd
=
render_bkgd
,
render_n_samples
=
render_n_samples
,
render_n_samples
=
render_n_samples
,
render_est_n_samples
=
render_est_n_samples
,
# memory control: wrost case
render_est_n_samples
=
render_est_n_samples
,
# memory control: wrost case
render_step_size
=
render_step_size
,
)
)
results
.
append
(
chunk_results
)
results
.
append
(
chunk_results
)
rgb
,
depth
,
acc
,
alive_ray_mask
,
counter
,
compact_counter
=
[
rgb
,
depth
,
acc
,
alive_ray_mask
,
counter
,
compact_counter
=
[
...
@@ -64,13 +69,14 @@ if __name__ == "__main__":
...
@@ -64,13 +69,14 @@ if __name__ == "__main__":
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
device
=
"cuda:0"
device
=
"cuda:0"
scene
=
"lego"
# setup dataset
# setup dataset
train_dataset
=
SubjectLoader
(
train_dataset
=
SubjectLoader
(
subject_id
=
"mic"
,
subject_id
=
scene
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
split
=
"trainval"
,
split
=
"trainval"
,
num_rays
=
4096
00
,
num_rays
=
4096
,
)
)
train_dataset
.
images
=
train_dataset
.
images
.
to
(
device
)
train_dataset
.
images
=
train_dataset
.
images
.
to
(
device
)
...
@@ -85,7 +91,7 @@ if __name__ == "__main__":
...
@@ -85,7 +91,7 @@ if __name__ == "__main__":
)
)
test_dataset
=
SubjectLoader
(
test_dataset
=
SubjectLoader
(
subject_id
=
"mic"
,
subject_id
=
scene
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
split
=
"test"
,
split
=
"test"
,
num_rays
=
None
,
num_rays
=
None
,
...
@@ -112,7 +118,7 @@ if __name__ == "__main__":
...
@@ -112,7 +118,7 @@ if __name__ == "__main__":
render_n_samples
=
1024
render_n_samples
=
1024
render_step_size
=
(
render_step_size
=
(
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
)
)
.
item
()
optimizer
=
torch
.
optim
.
Adam
(
optimizer
=
torch
.
optim
.
Adam
(
radiance_field
.
parameters
(),
radiance_field
.
parameters
(),
...
@@ -144,123 +150,75 @@ if __name__ == "__main__":
...
@@ -144,123 +150,75 @@ if __name__ == "__main__":
occ_eval_fn
=
occ_eval_fn
,
aabb
=
scene_aabb
,
resolution
=
128
occ_eval_fn
=
occ_eval_fn
,
aabb
=
scene_aabb
,
resolution
=
128
).
to
(
device
)
).
to
(
device
)
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
# training
# training
step
=
0
step
=
0
tic
=
time
.
time
()
tic
=
time
.
time
()
data_time
=
0
data_time
=
0
tic_data
=
time
.
time
()
tic_data
=
time
.
time
()
weights_image_ids
=
torch
.
ones
((
len
(
train_dataset
.
images
),),
device
=
device
)
for
epoch
in
range
(
10000000
):
weights_xs
=
torch
.
ones
(
(
train_dataset
.
WIDTH
,),
device
=
device
,
)
weights_ys
=
torch
.
ones
(
(
train_dataset
.
HEIGHT
,),
device
=
device
,
)
for
epoch
in
range
(
40000000
):
data
=
train_dataset
[
0
]
for
i
in
range
(
len
(
train_dataset
)):
for
i
in
range
(
len
(
train_dataset
)):
data
=
train_dataset
[
i
]
data
=
train_dataset
[
i
]
data_time
+=
time
.
time
()
-
tic_data
data_time
+=
time
.
time
()
-
tic_data
if
step
>
35_000
:
print
(
"training stops"
)
exit
()
# generate rays from data and the gt pixel color
# generate rays from data and the gt pixel color
rays
=
namedtuple_map
(
lambda
x
:
x
.
to
(
device
),
data
[
"rays"
])
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels
=
data
[
"pixels"
].
to
(
device
)
# pixels = data["pixels"].to(device)
render_bkgd
=
data
[
"color_bkgd"
].
to
(
device
)
render_bkgd
=
data
[
"color_bkgd"
]
rays
=
data
[
"rays"
]
pixels
=
data
[
"pixels"
]
#
#
update occupancy grid
# update occupancy grid
#
occ_field.every_n_step(step)
occ_field
.
every_n_step
(
step
)
render_est_n_samples
=
2
**
16
*
16
if
radiance_field
.
training
else
None
rgb
,
depth
,
acc
,
alive_ray_mask
,
counter
,
compact_counter
=
render_image
(
volumetric_rendering
(
radiance_field
,
rays
,
render_bkgd
,
render_step_size
query_fn
=
radiance_field
.
forward
,
# {x, dir} -> {rgb, density}
rays_o
=
rays
.
origins
,
rays_d
=
rays
.
viewdirs
,
scene_aabb
=
occ_field
.
aabb
,
scene_occ_binary
=
occ_field
.
occ_grid_binary
,
scene_resolution
=
occ_field
.
resolution
,
render_bkgd
=
render_bkgd
,
render_n_samples
=
render_n_samples
,
render_est_n_samples
=
render_est_n_samples
,
# memory control: wrost case
)
)
num_rays
=
len
(
pixels
)
num_rays
=
int
(
num_rays
*
(
TARGET_SAMPLE_BATCH_SIZE
/
float
(
compact_counter
.
item
()))
)
train_dataset
.
update_num_rays
(
num_rays
)
# rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
# compute loss
# radiance_field, rays, render_bkgd
loss
=
F
.
mse_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
])
# )
# num_rays = len(pixels)
# num_rays = int(num_rays * (2**16 / float(compact_counter)))
# num_rays = int(math.ceil(num_rays / 128.0) * 128)
# train_dataset.update_num_rays(num_rays)
# # compute loss
# loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
#
optimizer.zero_grad()
optimizer
.
zero_grad
()
#
(loss * 128
.0
).backward()
(
loss
*
128
).
backward
()
#
optimizer.step()
optimizer
.
step
()
#
scheduler.step()
scheduler
.
step
()
if
step
%
5
0
==
0
:
if
step
%
10
0
==
0
:
elapsed_time
=
time
.
time
()
-
tic
elapsed_time
=
time
.
time
()
-
tic
print
(
print
(
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s (data=
{
data_time
:.
2
f
}
s) |
{
step
=
}
| "
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s (data=
{
data_time
:.
2
f
}
s) |
{
step
=
}
| "
#
f"loss={loss:.5f} | "
f
"loss=
{
loss
:.
5
f
}
| "
#
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f
"alive_ray_mask=
{
alive_ray_mask
.
long
().
sum
():
d
}
| "
#
f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
f
"counter=
{
counter
.
item
()
:
d
}
| compact_counter=
{
compact_counter
.
item
()
:
d
}
| num_rays=
{
len
(
pixels
):
d
}
"
)
)
# if step % 35_000 == 0 and step > 0:
# if time.time() - tic > 300:
# # evaluation
if
step
==
35_000
:
# radiance_field.eval()
print
(
"training stops"
)
# psnrs = []
# evaluation
# with torch.no_grad():
radiance_field
.
eval
()
# for data in tqdm.tqdm(test_dataloader):
psnrs
=
[]
# # generate rays from data and the gt pixel color
with
torch
.
no_grad
():
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
for
data
in
tqdm
.
tqdm
(
test_dataloader
):
# pixels = data["pixels"].to(device)
# generate rays from data and the gt pixel color
# render_bkgd = data["color_bkgd"].to(device)
rays
=
namedtuple_map
(
lambda
x
:
x
.
to
(
device
),
data
[
"rays"
])
# # rendering
pixels
=
data
[
"pixels"
].
to
(
device
)
# rgb, depth, acc, alive_ray_mask, _, _ = render_image(
render_bkgd
=
data
[
"color_bkgd"
].
to
(
device
)
# radiance_field, rays, render_bkgd
# rendering
# )
rgb
,
depth
,
acc
,
alive_ray_mask
,
_
,
_
=
render_image
(
# mse = F.mse_loss(rgb, pixels)
radiance_field
,
rays
,
render_bkgd
,
render_step_size
# psnr = -10.0 * torch.log(mse) / np.log(10.0)
)
# psnrs.append(psnr.item())
mse
=
F
.
mse_loss
(
rgb
,
pixels
)
# psnr_avg = sum(psnrs) / len(psnrs)
psnr
=
-
10.0
*
torch
.
log
(
mse
)
/
np
.
log
(
10.0
)
# print(f"evaluation: {psnr_avg=}")
psnrs
.
append
(
psnr
.
item
())
psnr_avg
=
sum
(
psnrs
)
/
len
(
psnrs
)
print
(
f
"evaluation:
{
psnr_avg
=
}
"
)
exit
()
tic_data
=
time
.
time
()
tic_data
=
time
.
time
()
step
+=
1
step
+=
1
# "train"
# elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
# evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
# "train" batch_over_images=True
# elapsed_time=335.21s (data=68.99s) | step=30000 | loss=0.00028
# evaluation: psnr_avg=33.74970862388611 (6.23 it/s)
# "train" batch_over_images=True, schedule
# elapsed_time=296.30s (data=54.38s) | step=30000 | loss=0.00022
# evaluation: psnr_avg=34.3978275680542 (6.22 it/s)
# "trainval"
# elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
# evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
# "trainval" batch_over_images=True, schedule
# elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
# evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
# "trainval" batch_over_images=True, schedule 2**18
# evaluation: psnr_avg=36.24 (6.75 it/s)
nerfacc/cuda/csrc/ray_marching.cu
View file @
16324602
...
@@ -14,7 +14,8 @@ inline __device__ int cascaded_grid_idx_at(
...
@@ -14,7 +14,8 @@ inline __device__ int cascaded_grid_idx_at(
ix
=
__clamp
(
ix
,
0
,
resx
-
1
);
ix
=
__clamp
(
ix
,
0
,
resx
-
1
);
iy
=
__clamp
(
iy
,
0
,
resy
-
1
);
iy
=
__clamp
(
iy
,
0
,
resy
-
1
);
iz
=
__clamp
(
iz
,
0
,
resz
-
1
);
iz
=
__clamp
(
iz
,
0
,
resz
-
1
);
int
idx
=
ix
*
resx
*
resy
+
iy
*
resz
+
iz
;
int
idx
=
ix
*
resy
*
resz
+
iy
*
resz
+
iz
;
// printf("(ix, iy, iz) = (%d, %d, %d)\n", ix, iy, iz);
return
idx
;
return
idx
;
}
}
...
@@ -89,102 +90,102 @@ __global__ void kernel_raymarching(
...
@@ -89,102 +90,102 @@ __global__ void kernel_raymarching(
)
{
)
{
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
//
//
locate
// locate
//
rays_o += i * 3;
rays_o
+=
i
*
3
;
//
rays_d += i * 3;
rays_d
+=
i
*
3
;
//
t_min += i;
t_min
+=
i
;
//
t_max += i;
t_max
+=
i
;
//
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const
float
ox
=
rays_o
[
0
],
oy
=
rays_o
[
1
],
oz
=
rays_o
[
2
];
//
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const
float
dx
=
rays_d
[
0
],
dy
=
rays_d
[
1
],
dz
=
rays_d
[
2
];
//
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const
float
rdx
=
1
/
dx
,
rdy
=
1
/
dy
,
rdz
=
1
/
dz
;
//
const float near = t_min[0], far = t_max[0];
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
//
uint32_t ray_idx, base, marching_samples;
uint32_t
ray_idx
,
base
,
marching_samples
;
//
uint32_t j;
uint32_t
j
;
//
float t0, t1, t_mid;
float
t0
,
t1
,
t_mid
;
//
//
first pass to compute an accurate number of steps
// first pass to compute an accurate number of steps
//
j = 0;
j
=
0
;
//
t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
t0
=
near
;
// TODO(ruilongli): perturb `near` as in ngp_pl?
//
t1 = t0 + dt;
t1
=
t0
+
dt
;
//
t_mid = (t0 + t1) * 0.5f;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
//
while (t_mid < far && j < max_per_ray_samples) {
while
(
t_mid
<
far
&&
j
<
max_per_ray_samples
)
{
//
// current center
// current center
//
const float x = ox + t_mid * dx;
const
float
x
=
ox
+
t_mid
*
dx
;
//
const float y = oy + t_mid * dy;
const
float
y
=
oy
+
t_mid
*
dy
;
//
const float z = oz + t_mid * dz;
const
float
z
=
oz
+
t_mid
*
dz
;
//
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
if
(
grid_occupied_at
(
x
,
y
,
z
,
resx
,
resy
,
resz
,
aabb
,
occ_binary
))
{
//
++j;
++
j
;
//
// march to next sample
// march to next sample
//
t0 = t1;
t0
=
t1
;
//
t1 = t0 + dt;
t1
=
t0
+
dt
;
//
t_mid = (t0 + t1) * 0.5f;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
//
}
}
//
else {
else
{
//
// march to next sample
// march to next sample
//
t_mid = advance_to_next_voxel(
t_mid
=
advance_to_next_voxel
(
//
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
t_mid
,
x
,
y
,
z
,
dx
,
dy
,
dz
,
rdx
,
rdy
,
rdz
,
resx
,
resy
,
resz
,
dt
//
);
);
//
t0 = t_mid - dt * 0.5f;
t0
=
t_mid
-
dt
*
0.5
f
;
//
t1 = t_mid + dt * 0.5f;
t1
=
t_mid
+
dt
*
0.5
f
;
//
}
}
//
}
}
//
if (j == 0) return;
if
(
j
==
0
)
return
;
//
marching_samples = j;
marching_samples
=
j
;
//
base = atomicAdd(steps_counter, marching_samples);
base
=
atomicAdd
(
steps_counter
,
marching_samples
);
//
if (base + marching_samples > max_total_samples) return;
if
(
base
+
marching_samples
>
max_total_samples
)
return
;
//
ray_idx = atomicAdd(rays_counter, 1);
ray_idx
=
atomicAdd
(
rays_counter
,
1
);
//
//
locate
// locate
//
frustum_origins += base * 3;
frustum_origins
+=
base
*
3
;
//
frustum_dirs += base * 3;
frustum_dirs
+=
base
*
3
;
//
frustum_starts += base;
frustum_starts
+=
base
;
//
frustum_ends += base;
frustum_ends
+=
base
;
//
//
Second round
// Second round
//
j = 0;
j
=
0
;
//
t0 = near;
t0
=
near
;
//
t1 = t0 + dt;
t1
=
t0
+
dt
;
//
t_mid = (t0 + t1) / 2.;
t_mid
=
(
t0
+
t1
)
/
2.
;
//
while (t_mid < far && j < marching_samples) {
while
(
t_mid
<
far
&&
j
<
marching_samples
)
{
//
// current center
// current center
//
const float x = ox + t_mid * dx;
const
float
x
=
ox
+
t_mid
*
dx
;
//
const float y = oy + t_mid * dy;
const
float
y
=
oy
+
t_mid
*
dy
;
//
const float z = oz + t_mid * dz;
const
float
z
=
oz
+
t_mid
*
dz
;
//
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
if
(
grid_occupied_at
(
x
,
y
,
z
,
resx
,
resy
,
resz
,
aabb
,
occ_binary
))
{
//
frustum_origins[j * 3 + 0] = ox;
frustum_origins
[
j
*
3
+
0
]
=
ox
;
//
frustum_origins[j * 3 + 1] = oy;
frustum_origins
[
j
*
3
+
1
]
=
oy
;
//
frustum_origins[j * 3 + 2] = oz;
frustum_origins
[
j
*
3
+
2
]
=
oz
;
//
frustum_dirs[j * 3 + 0] = dx;
frustum_dirs
[
j
*
3
+
0
]
=
dx
;
//
frustum_dirs[j * 3 + 1] = dy;
frustum_dirs
[
j
*
3
+
1
]
=
dy
;
//
frustum_dirs[j * 3 + 2] = dz;
frustum_dirs
[
j
*
3
+
2
]
=
dz
;
//
frustum_starts[j] = t0;
frustum_starts
[
j
]
=
t0
;
//
frustum_ends[j] = t1;
frustum_ends
[
j
]
=
t1
;
//
++j;
++
j
;
//
// march to next sample
// march to next sample
//
t0 = t1;
t0
=
t1
;
//
t1 = t0 + dt;
t1
=
t0
+
dt
;
//
t_mid = (t0 + t1) * 0.5f;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
//
}
}
//
else {
else
{
//
// march to next sample
// march to next sample
//
t_mid = advance_to_next_voxel(
t_mid
=
advance_to_next_voxel
(
//
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
t_mid
,
x
,
y
,
z
,
dx
,
dy
,
dz
,
rdx
,
rdy
,
rdz
,
resx
,
resy
,
resz
,
dt
//
);
);
//
t0 = t_mid - dt * 0.5f;
t0
=
t_mid
-
dt
*
0.5
f
;
//
t1 = t_mid + dt * 0.5f;
t1
=
t_mid
+
dt
*
0.5
f
;
//
}
}
//
}
}
//
packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
packed_info
[
ray_idx
*
3
+
0
]
=
i
;
// ray idx in {rays_o, rays_d}
//
packed_info[ray_idx * 3 + 1] = base; // point idx start.
packed_info
[
ray_idx
*
3
+
1
]
=
base
;
// point idx start.
//
packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
packed_info
[
ray_idx
*
3
+
2
]
=
j
;
// point idx shift (actual marching samples).
return
;
return
;
}
}
...
@@ -233,62 +234,61 @@ std::vector<torch::Tensor> ray_marching(
...
@@ -233,62 +234,61 @@ std::vector<torch::Tensor> ray_marching(
const
int
max_per_ray_samples
,
const
int
max_per_ray_samples
,
const
float
dt
const
float
dt
)
{
)
{
//
DEVICE_GUARD(rays_o);
DEVICE_GUARD
(
rays_o
);
//
CHECK_INPUT(rays_o);
CHECK_INPUT
(
rays_o
);
//
CHECK_INPUT(rays_d);
CHECK_INPUT
(
rays_d
);
//
CHECK_INPUT(t_min);
CHECK_INPUT
(
t_min
);
//
CHECK_INPUT(t_max);
CHECK_INPUT
(
t_max
);
//
CHECK_INPUT(aabb);
CHECK_INPUT
(
aabb
);
//
CHECK_INPUT(occ_binary);
CHECK_INPUT
(
occ_binary
);
//
const int n_rays = rays_o.size(0);
const
int
n_rays
=
rays_o
.
size
(
0
);
// //
const int threads = 256;
const
int
threads
=
256
;
// //
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
//
//
helper counter
// helper counter
//
torch::Tensor steps_counter = torch::zeros(
torch
::
Tensor
steps_counter
=
torch
::
zeros
(
//
{1}, rays_o.options().dtype(torch::kInt32));
{
1
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
//
torch::Tensor rays_counter = torch::zeros(
torch
::
Tensor
rays_counter
=
torch
::
zeros
(
//
{1}, rays_o.options().dtype(torch::kInt32));
{
1
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
//
//
output frustum samples
// output frustum samples
//
torch::Tensor packed_info = torch::zeros(
torch
::
Tensor
packed_info
=
torch
::
zeros
(
//
{n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
{
n_rays
,
3
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
// ray_id, sample_id, num_samples
//
torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
torch
::
Tensor
frustum_origins
=
torch
::
zeros
({
max_total_samples
,
3
},
rays_o
.
options
());
//
torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
torch
::
Tensor
frustum_dirs
=
torch
::
zeros
({
max_total_samples
,
3
},
rays_o
.
options
());
//
torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
torch
::
Tensor
frustum_starts
=
torch
::
zeros
({
max_total_samples
,
1
},
rays_o
.
options
());
//
torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
torch
::
Tensor
frustum_ends
=
torch
::
zeros
({
max_total_samples
,
1
},
rays_o
.
options
());
//
kernel_raymarching<<<blocks, threads
, 0, at::cuda::getCurrentCUDAStream()
>>>(
kernel_raymarching
<<<
blocks
,
threads
>>>
(
//
// rays
// rays
//
n_rays,
n_rays
,
//
rays_o.data_ptr<float>(),
rays_o
.
data_ptr
<
float
>
(),
//
rays_d.data_ptr<float>(),
rays_d
.
data_ptr
<
float
>
(),
//
t_min.data_ptr<float>(),
t_min
.
data_ptr
<
float
>
(),
//
t_max.data_ptr<float>(),
t_max
.
data_ptr
<
float
>
(),
//
// density grid
// density grid
//
aabb.data_ptr<float>(),
aabb
.
data_ptr
<
float
>
(),
//
resolution[0].cast<int>(),
resolution
[
0
].
cast
<
int
>
(),
//
resolution[1].cast<int>(),
resolution
[
1
].
cast
<
int
>
(),
//
resolution[2].cast<int>(),
resolution
[
2
].
cast
<
int
>
(),
//
occ_binary.data_ptr<bool>(),
occ_binary
.
data_ptr
<
bool
>
(),
//
// sampling
// sampling
//
max_total_samples,
max_total_samples
,
//
max_per_ray_samples,
max_per_ray_samples
,
//
dt,
dt
,
//
// writable helpers
// writable helpers
//
steps_counter.data_ptr<int>(), // total samples.
steps_counter
.
data_ptr
<
int
>
(),
// total samples.
//
rays_counter.data_ptr<int>(), // total rays.
rays_counter
.
data_ptr
<
int
>
(),
// total rays.
//
packed_info.data_ptr<int>(),
packed_info
.
data_ptr
<
int
>
(),
//
frustum_origins.data_ptr<float>(),
frustum_origins
.
data_ptr
<
float
>
(),
//
frustum_dirs.data_ptr<float>(),
frustum_dirs
.
data_ptr
<
float
>
(),
//
frustum_starts.data_ptr<float>(),
frustum_starts
.
data_ptr
<
float
>
(),
//
frustum_ends.data_ptr<float>()
frustum_ends
.
data_ptr
<
float
>
()
//
);
);
// return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
return
{
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
,
steps_counter
};
return
{};
}
}
nerfacc/occupancy_field.py
View file @
16324602
...
@@ -72,6 +72,7 @@ class OccupancyField(nn.Module):
...
@@ -72,6 +72,7 @@ class OccupancyField(nn.Module):
self
.
register_buffer
(
"aabb"
,
aabb
)
self
.
register_buffer
(
"aabb"
,
aabb
)
self
.
resolution
=
resolution
self
.
resolution
=
resolution
self
.
register_buffer
(
"resolution_tensor"
,
torch
.
tensor
(
resolution
))
self
.
num_dim
=
num_dim
self
.
num_dim
=
num_dim
self
.
num_cells
=
torch
.
tensor
(
resolution
).
prod
().
item
()
self
.
num_cells
=
torch
.
tensor
(
resolution
).
prod
().
item
()
...
@@ -107,7 +108,6 @@ class OccupancyField(nn.Module):
...
@@ -107,7 +108,6 @@ class OccupancyField(nn.Module):
if
n
<
len
(
occupied_indices
):
if
n
<
len
(
occupied_indices
):
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
device
)
selector
=
torch
.
randint
(
len
(
occupied_indices
),
(
n
,),
device
=
device
)
occupied_indices
=
occupied_indices
[
selector
]
occupied_indices
=
occupied_indices
[
selector
]
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
indices
=
torch
.
cat
([
uniform_indices
,
occupied_indices
],
dim
=
0
)
return
indices
return
indices
...
@@ -129,19 +129,19 @@ class OccupancyField(nn.Module):
...
@@ -129,19 +129,19 @@ class OccupancyField(nn.Module):
stage we change the sampling strategy to 1/4 unifromly sampled cells
stage we change the sampling strategy to 1/4 unifromly sampled cells
together with 1/4 occupied cells.
together with 1/4 occupied cells.
"""
"""
resolution
=
torch
.
tensor
(
self
.
resolution
).
to
(
self
.
occ_grid
.
device
)
# sample cells
# sample cells
if
step
<
warmup_steps
:
if
step
<
warmup_steps
:
indices
=
self
.
_get_all_cells
()
indices
=
self
.
_get_all_cells
()
else
:
else
:
N
=
resolution
.
prod
().
item
()
//
4
N
=
self
.
num_cells
//
4
indices
=
self
.
_sample_uniform_and_occupied_cells
(
N
)
indices
=
self
.
_sample_uniform_and_occupied_cells
(
N
)
# infer occupancy: density * step_size
# infer occupancy: density * step_size
tmp_occ_grid
=
-
torch
.
ones_like
(
self
.
occ_grid
)
tmp_occ_grid
=
-
torch
.
ones_like
(
self
.
occ_grid
)
grid_coords
=
self
.
grid_coords
[
indices
]
grid_coords
=
self
.
grid_coords
[
indices
]
x
=
(
grid_coords
+
torch
.
rand_like
(
grid_coords
.
float
()))
/
resolution
x
=
(
grid_coords
+
torch
.
rand_like
(
grid_coords
.
float
())
)
/
self
.
resolution_tensor
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
bb_min
,
bb_max
=
torch
.
split
(
self
.
aabb
,
[
self
.
num_dim
,
self
.
num_dim
],
dim
=
0
)
x
=
x
*
(
bb_max
-
bb_min
)
+
bb_min
x
=
x
*
(
bb_max
-
bb_min
)
+
bb_min
tmp_occ_grid
[
indices
]
=
self
.
occ_eval_fn
(
x
).
squeeze
(
-
1
)
tmp_occ_grid
[
indices
]
=
self
.
occ_eval_fn
(
x
).
squeeze
(
-
1
)
...
@@ -152,8 +152,8 @@ class OccupancyField(nn.Module):
...
@@ -152,8 +152,8 @@ class OccupancyField(nn.Module):
self
.
occ_grid
[
ema_mask
]
*
ema_decay
,
tmp_occ_grid
[
ema_mask
]
self
.
occ_grid
[
ema_mask
]
*
ema_decay
,
tmp_occ_grid
[
ema_mask
]
)
)
self
.
occ_grid_mean
=
self
.
occ_grid
.
mean
()
self
.
occ_grid_mean
=
self
.
occ_grid
.
mean
()
self
.
occ_grid_binary
=
self
.
occ_grid
>
min
(
self
.
occ_grid_binary
=
self
.
occ_grid
>
torch
.
clamp
(
self
.
occ_grid_mean
.
item
(),
occ_threshold
self
.
occ_grid_mean
,
max
=
occ_threshold
)
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
...
nerfacc/volumetric_rendering.py
View file @
16324602
...
@@ -16,6 +16,7 @@ def volumetric_rendering(
...
@@ -16,6 +16,7 @@ def volumetric_rendering(
render_bkgd
:
torch
.
Tensor
=
None
,
render_bkgd
:
torch
.
Tensor
=
None
,
render_n_samples
:
int
=
1024
,
render_n_samples
:
int
=
1024
,
render_est_n_samples
:
int
=
None
,
render_est_n_samples
:
int
=
None
,
render_step_size
:
int
=
None
,
**
kwargs
,
**
kwargs
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""A *fast* version of differentiable volumetric rendering."""
"""A *fast* version of differentiable volumetric rendering."""
...
@@ -23,8 +24,6 @@ def volumetric_rendering(
...
@@ -23,8 +24,6 @@ def volumetric_rendering(
if
render_bkgd
is
None
:
if
render_bkgd
is
None
:
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
render_bkgd
=
torch
.
ones
(
3
,
device
=
device
)
# scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
rays_o
=
rays_o
.
contiguous
()
rays_o
=
rays_o
.
contiguous
()
rays_d
=
rays_d
.
contiguous
()
rays_d
=
rays_d
.
contiguous
()
scene_aabb
=
scene_aabb
.
contiguous
()
scene_aabb
=
scene_aabb
.
contiguous
()
...
@@ -36,22 +35,22 @@ def volumetric_rendering(
...
@@ -36,22 +35,22 @@ def volumetric_rendering(
render_total_samples
=
n_rays
*
render_n_samples
render_total_samples
=
n_rays
*
render_n_samples
else
:
else
:
render_total_samples
=
render_est_n_samples
render_total_samples
=
render_est_n_samples
if
render_step_size
is
None
:
# Note: CPU<->GPU is not idea, try to pre-define it outside this function.
render_step_size
=
(
render_step_size
=
(
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
)
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
t_min
,
t_max
=
ray_aabb_intersect
(
rays_o
,
rays_d
,
scene_aabb
)
t_min
,
t_max
=
ray_aabb_intersect
(
rays_o
,
rays_d
,
scene_aabb
)
# t_min = torch.clamp(t_min, max=1e10)
# t_max = torch.clamp(t_max, max=1e10)
(
(
#
packed_info,
packed_info
,
#
frustum_origins,
frustum_origins
,
#
frustum_dirs,
frustum_dirs
,
#
frustum_starts,
frustum_starts
,
#
frustum_ends,
frustum_ends
,
#
steps_counter,
steps_counter
,
)
=
ray_marching
(
)
=
ray_marching
(
# rays
# rays
rays_o
,
rays_o
,
...
@@ -68,43 +67,41 @@ def volumetric_rendering(
...
@@ -68,43 +67,41 @@ def volumetric_rendering(
render_step_size
,
render_step_size
,
)
)
# # squeeze valid samples
# squeeze valid samples
# total_samples = max(packed_info[:, -1].sum(), 1)
total_samples
=
max
(
packed_info
[:,
-
1
].
sum
(),
1
)
# total_samples = int(math.ceil(total_samples / 128.0)) * 128
frustum_origins
=
frustum_origins
[:
total_samples
]
# frustum_origins = frustum_origins[:total_samples]
frustum_dirs
=
frustum_dirs
[:
total_samples
]
# frustum_dirs = frustum_dirs[:total_samples]
frustum_starts
=
frustum_starts
[:
total_samples
]
# frustum_starts = frustum_starts[:total_samples]
frustum_ends
=
frustum_ends
[:
total_samples
]
# frustum_ends = frustum_ends[:total_samples]
# frustum_positions = (
# frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
# )
# query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
frustum_positions
=
(
# rgbs, densities = query_results[0], query_results[1]
frustum_origins
+
frustum_dirs
*
(
frustum_starts
+
frustum_ends
)
/
2.0
)
# (
query_results
=
query_fn
(
frustum_positions
,
frustum_dirs
,
**
kwargs
)
# accumulated_weight,
rgbs
,
densities
=
query_results
[
0
],
query_results
[
1
]
# accumulated_depth,
(
# accumulated_color,
accumulated_weight
,
# alive_ray_mask,
accumulated_depth
,
# compact_steps_counter,
accumulated_color
,
# ) = VolumeRenderer.apply(
alive_ray_mask
,
# packed_info,
compact_steps_counter
,
# frustum_starts,
)
=
VolumeRenderer
.
apply
(
# frustum_ends,
packed_info
,
# densities.contiguous(),
frustum_starts
,
# rgbs.contiguous(),
frustum_ends
,
# )
densities
.
contiguous
(),
rgbs
.
contiguous
(),
)
#
accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
accumulated_depth
=
torch
.
clip
(
accumulated_depth
,
t_min
[:,
None
],
t_max
[:,
None
])
#
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
accumulated_color
=
accumulated_color
+
render_bkgd
*
(
1.0
-
accumulated_weight
)
#
return (
return
(
#
accumulated_color,
accumulated_color
,
#
accumulated_depth,
accumulated_depth
,
#
accumulated_weight,
accumulated_weight
,
#
alive_ray_mask,
alive_ray_mask
,
#
steps_counter,
steps_counter
,
#
compact_steps_counter,
compact_steps_counter
,
#
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment