OpenDAS / nerfacc · Commits

Commit 2e7ad6e0, authored Nov 11, 2022 by Ruilong Li
Parent: b4286720

    proposal seems working

Showing 5 changed files with 606 additions and 36 deletions (+606 −36)
examples/train_ngp_nerf_proposal.py   +440   −0
nerfacc/cdf.py                          +1   −1
nerfacc/cuda/csrc/cdf.cu               +26  −19
nerfacc/ray_marching.py                +25   −2
tests/test_resampling.py              +114  −14
examples/train_ngp_nerf_proposal.py  (new file, 0 → 100644)
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import
argparse
import
math
import
os
import
random
import
time
from
typing
import
Optional
import
imageio
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
tqdm
from
datasets.utils
import
Rays
,
namedtuple_map
from
radiance_fields.ngp
import
NGPradianceField
from
utils
import
set_random_seed
from
nerfacc
import
ContractionType
,
ray_marching
,
rendering
from
nerfacc.cuda
import
ray_pdf_query
def
set_random_seed
(
seed
):
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
torch
.
manual_seed
(
seed
)
def
render_image
(
# scene
radiance_field
:
torch
.
nn
.
Module
,
proposal_nets
:
torch
.
nn
.
Module
,
rays
:
Rays
,
scene_aabb
:
torch
.
Tensor
,
# rendering options
near_plane
:
Optional
[
float
]
=
None
,
far_plane
:
Optional
[
float
]
=
None
,
render_step_size
:
float
=
1e-3
,
render_bkgd
:
Optional
[
torch
.
Tensor
]
=
None
,
cone_angle
:
float
=
0.0
,
alpha_thre
:
float
=
0.0
,
# test options
test_chunk_size
:
int
=
8192
,
):
"""Render the pixels of an image."""
rays_shape
=
rays
.
origins
.
shape
if
len
(
rays_shape
)
==
3
:
height
,
width
,
_
=
rays_shape
num_rays
=
height
*
width
rays
=
namedtuple_map
(
lambda
r
:
r
.
reshape
([
num_rays
]
+
list
(
r
.
shape
[
2
:])),
rays
)
else
:
num_rays
,
_
=
rays_shape
def
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
,
net
=
None
):
ray_indices
=
ray_indices
.
long
()
t_origins
=
chunk_rays
.
origins
[
ray_indices
]
t_dirs
=
chunk_rays
.
viewdirs
[
ray_indices
]
positions
=
t_origins
+
t_dirs
*
(
t_starts
+
t_ends
)
/
2.0
if
net
is
not
None
:
return
net
.
query_density
(
positions
)
else
:
return
radiance_field
.
query_density
(
positions
)
def
rgb_sigma_fn
(
t_starts
,
t_ends
,
ray_indices
):
ray_indices
=
ray_indices
.
long
()
t_origins
=
chunk_rays
.
origins
[
ray_indices
]
t_dirs
=
chunk_rays
.
viewdirs
[
ray_indices
]
positions
=
t_origins
+
t_dirs
*
(
t_starts
+
t_ends
)
/
2.0
return
radiance_field
(
positions
,
t_dirs
)
results
=
[]
chunk
=
(
torch
.
iinfo
(
torch
.
int32
).
max
if
radiance_field
.
training
else
test_chunk_size
)
for
i
in
range
(
0
,
num_rays
,
chunk
):
chunk_rays
=
namedtuple_map
(
lambda
r
:
r
[
i
:
i
+
chunk
],
rays
)
packed_info
,
t_starts
,
t_ends
,
proposal_sample_list
=
ray_marching
(
chunk_rays
.
origins
,
chunk_rays
.
viewdirs
,
scene_aabb
=
scene_aabb
,
grid
=
None
,
proposal_nets
=
proposal_nets
,
sigma_fn
=
sigma_fn
,
near_plane
=
near_plane
,
far_plane
=
far_plane
,
render_step_size
=
render_step_size
,
stratified
=
radiance_field
.
training
,
cone_angle
=
cone_angle
,
alpha_thre
=
alpha_thre
,
)
rgb
,
opacity
,
depth
,
weights
=
rendering
(
rgb_sigma_fn
,
packed_info
,
t_starts
,
t_ends
,
render_bkgd
=
render_bkgd
,
)
if
radiance_field
.
training
:
proposal_sample_list
.
append
(
(
packed_info
,
t_starts
,
t_ends
,
weights
)
)
chunk_results
=
[
rgb
,
opacity
,
depth
,
len
(
t_starts
)]
results
.
append
(
chunk_results
)
colors
,
opacities
,
depths
,
n_rendering_samples
=
[
torch
.
cat
(
r
,
dim
=
0
)
if
isinstance
(
r
[
0
],
torch
.
Tensor
)
else
r
for
r
in
zip
(
*
results
)
]
return
(
colors
.
view
((
*
rays_shape
[:
-
1
],
-
1
)),
opacities
.
view
((
*
rays_shape
[:
-
1
],
-
1
)),
depths
.
view
((
*
rays_shape
[:
-
1
],
-
1
)),
sum
(
n_rendering_samples
),
proposal_sample_list
if
radiance_field
.
training
else
None
,
)
if
__name__
==
"__main__"
:
device
=
"cuda:0"
set_random_seed
(
42
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--train_split"
,
type
=
str
,
default
=
"trainval"
,
choices
=
[
"train"
,
"trainval"
],
help
=
"which train split to use"
,
)
parser
.
add_argument
(
"--scene"
,
type
=
str
,
default
=
"lego"
,
choices
=
[
# nerf synthetic
"chair"
,
"drums"
,
"ficus"
,
"hotdog"
,
"lego"
,
"materials"
,
"mic"
,
"ship"
,
# mipnerf360 unbounded
"garden"
,
"bicycle"
,
"bonsai"
,
"counter"
,
"kitchen"
,
"room"
,
"stump"
,
],
help
=
"which scene to use"
,
)
parser
.
add_argument
(
"--aabb"
,
type
=
lambda
s
:
[
float
(
item
)
for
item
in
s
.
split
(
","
)],
default
=
"-1.5,-1.5,-1.5,1.5,1.5,1.5"
,
help
=
"delimited list input"
,
)
parser
.
add_argument
(
"--test_chunk_size"
,
type
=
int
,
default
=
8192
,
)
parser
.
add_argument
(
"--unbounded"
,
action
=
"store_true"
,
help
=
"whether to use unbounded rendering"
,
)
parser
.
add_argument
(
"--auto_aabb"
,
action
=
"store_true"
,
help
=
"whether to automatically compute the aabb"
,
)
parser
.
add_argument
(
"--cone_angle"
,
type
=
float
,
default
=
0.0
)
args
=
parser
.
parse_args
()
render_n_samples
=
256
# setup the dataset
train_dataset_kwargs
=
{}
test_dataset_kwargs
=
{}
if
args
.
unbounded
:
from
datasets.nerf_360_v2
import
SubjectLoader
data_root_fp
=
"/home/ruilongli/data/360_v2/"
target_sample_batch_size
=
1
<<
20
train_dataset_kwargs
=
{
"color_bkgd_aug"
:
"random"
,
"factor"
:
4
}
test_dataset_kwargs
=
{
"factor"
:
4
}
else
:
from
datasets.nerf_synthetic
import
SubjectLoader
data_root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
target_sample_batch_size
=
1
<<
20
train_dataset
=
SubjectLoader
(
subject_id
=
args
.
scene
,
root_fp
=
data_root_fp
,
split
=
args
.
train_split
,
num_rays
=
target_sample_batch_size
//
render_n_samples
,
**
train_dataset_kwargs
,
)
train_dataset
.
images
=
train_dataset
.
images
.
to
(
device
)
train_dataset
.
camtoworlds
=
train_dataset
.
camtoworlds
.
to
(
device
)
train_dataset
.
K
=
train_dataset
.
K
.
to
(
device
)
test_dataset
=
SubjectLoader
(
subject_id
=
args
.
scene
,
root_fp
=
data_root_fp
,
split
=
"test"
,
num_rays
=
None
,
**
test_dataset_kwargs
,
)
test_dataset
.
images
=
test_dataset
.
images
.
to
(
device
)
test_dataset
.
camtoworlds
=
test_dataset
.
camtoworlds
.
to
(
device
)
test_dataset
.
K
=
test_dataset
.
K
.
to
(
device
)
if
args
.
auto_aabb
:
camera_locs
=
torch
.
cat
(
[
train_dataset
.
camtoworlds
,
test_dataset
.
camtoworlds
]
)[:,
:
3
,
-
1
]
args
.
aabb
=
torch
.
cat
(
[
camera_locs
.
min
(
dim
=
0
).
values
,
camera_locs
.
max
(
dim
=
0
).
values
]
).
tolist
()
print
(
"Using auto aabb"
,
args
.
aabb
)
# setup the scene bounding box.
if
args
.
unbounded
:
print
(
"Using unbounded rendering"
)
contraction_type
=
ContractionType
.
UN_BOUNDED_SPHERE
# contraction_type = ContractionType.UN_BOUNDED_TANH
scene_aabb
=
None
near_plane
=
0.2
far_plane
=
1e4
render_step_size
=
1e-2
alpha_thre
=
1e-2
else
:
contraction_type
=
ContractionType
.
AABB
scene_aabb
=
torch
.
tensor
(
args
.
aabb
,
dtype
=
torch
.
float32
,
device
=
device
)
near_plane
=
None
far_plane
=
None
render_step_size
=
(
(
scene_aabb
[
3
:]
-
scene_aabb
[:
3
]).
max
()
*
math
.
sqrt
(
3
)
/
render_n_samples
).
item
()
alpha_thre
=
0.0
proposal_nets
=
torch
.
nn
.
ModuleList
(
[
NGPradianceField
(
aabb
=
args
.
aabb
,
use_viewdirs
=
False
,
hidden_dim
=
16
,
max_res
=
64
,
geo_feat_dim
=
0
,
n_levels
=
5
,
log2_hashmap_size
=
17
,
),
# NGPradianceField(
# aabb=args.aabb,
# use_viewdirs=False,
# hidden_dim=16,
# max_res=256,
# geo_feat_dim=0,
# n_levels=5,
# log2_hashmap_size=17,
# ),
]
).
to
(
device
)
# setup the radiance field we want to train.
max_steps
=
20000
grad_scaler
=
torch
.
cuda
.
amp
.
GradScaler
(
2
**
10
)
radiance_field
=
NGPradianceField
(
aabb
=
args
.
aabb
,
unbounded
=
args
.
unbounded
,
).
to
(
device
)
optimizer
=
torch
.
optim
.
Adam
(
list
(
radiance_field
.
parameters
())
+
list
(
proposal_nets
.
parameters
()),
lr
=
1e-2
,
eps
=
1e-15
,
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer
,
milestones
=
[
max_steps
//
2
,
max_steps
*
3
//
4
,
max_steps
*
9
//
10
],
gamma
=
0.33
,
)
# training
step
=
0
tic
=
time
.
time
()
for
epoch
in
range
(
10000000
):
for
i
in
range
(
len
(
train_dataset
)):
radiance_field
.
train
()
data
=
train_dataset
[
i
]
render_bkgd
=
data
[
"color_bkgd"
]
rays
=
data
[
"rays"
]
pixels
=
data
[
"pixels"
]
# render
(
rgb
,
acc
,
depth
,
n_rendering_samples
,
proposal_sample_list
,
)
=
render_image
(
radiance_field
,
proposal_nets
,
rays
,
scene_aabb
,
# rendering options
near_plane
=
near_plane
,
far_plane
=
far_plane
,
render_step_size
=
render_step_size
,
render_bkgd
=
render_bkgd
,
cone_angle
=
args
.
cone_angle
,
alpha_thre
=
alpha_thre
,
)
if
n_rendering_samples
==
0
:
continue
# dynamic batch size for rays to keep sample batch size constant.
num_rays
=
len
(
pixels
)
num_rays
=
int
(
num_rays
*
(
target_sample_batch_size
/
float
(
n_rendering_samples
))
)
train_dataset
.
update_num_rays
(
num_rays
)
alive_ray_mask
=
acc
.
squeeze
(
-
1
)
>
0
# compute loss
loss
=
F
.
smooth_l1_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
])
(
packed_info
,
t_starts
,
t_ends
,
weights
,
)
=
proposal_sample_list
[
-
1
]
for
(
proposal_packed_info
,
proposal_t_starts
,
proposal_t_ends
,
proposal_weights
,
)
in
proposal_sample_list
[:
-
1
]:
proposal_weights_gt
=
ray_pdf_query
(
packed_info
,
t_starts
,
t_ends
,
weights
.
detach
(),
proposal_packed_info
,
proposal_t_starts
,
proposal_t_ends
,
).
detach
()
torch
.
cuda
.
synchronize
()
loss_interval
=
(
torch
.
clamp
(
proposal_weights_gt
-
proposal_weights
,
min
=
0
)
)
**
2
/
(
proposal_weights
+
torch
.
finfo
(
torch
.
float32
).
eps
)
loss_interval
=
loss_interval
.
mean
()
loss
+=
loss_interval
*
1.0
optimizer
.
zero_grad
()
# do not unscale it because we are using Adam.
grad_scaler
.
scale
(
loss
).
backward
()
optimizer
.
step
()
scheduler
.
step
()
if
step
%
100
==
0
:
elapsed_time
=
time
.
time
()
-
tic
loss
=
F
.
mse_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
])
print
(
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s | step=
{
step
}
| "
f
"loss=
{
loss
:.
5
f
}
| loss_interval=
{
loss_interval
:.
5
f
}
"
f
"alive_ray_mask=
{
alive_ray_mask
.
long
().
sum
():
d
}
| "
f
"n_rendering_samples=
{
n_rendering_samples
:
d
}
| num_rays=
{
len
(
pixels
):
d
}
|"
)
if
step
>=
0
and
step
%
1000
==
0
and
step
>
0
:
# evaluation
radiance_field
.
eval
()
psnrs
=
[]
with
torch
.
no_grad
():
for
i
in
tqdm
.
tqdm
(
range
(
len
(
test_dataset
))):
data
=
test_dataset
[
i
]
render_bkgd
=
data
[
"color_bkgd"
]
rays
=
data
[
"rays"
]
pixels
=
data
[
"pixels"
]
# rendering
rgb
,
acc
,
depth
,
_
,
_
=
render_image
(
radiance_field
,
proposal_nets
,
rays
,
scene_aabb
,
# rendering options
near_plane
=
near_plane
,
far_plane
=
far_plane
,
render_step_size
=
render_step_size
,
render_bkgd
=
render_bkgd
,
cone_angle
=
args
.
cone_angle
,
alpha_thre
=
alpha_thre
,
# test options
test_chunk_size
=
args
.
test_chunk_size
,
)
mse
=
F
.
mse_loss
(
rgb
,
pixels
)
psnr
=
-
10.0
*
torch
.
log
(
mse
)
/
np
.
log
(
10.0
)
psnrs
.
append
(
psnr
.
item
())
imageio
.
imwrite
(
"acc_binary_test.png"
,
((
acc
>
0
).
float
().
cpu
().
numpy
()
*
255
).
astype
(
np
.
uint8
),
)
imageio
.
imwrite
(
"rgb_test.png"
,
(
rgb
.
cpu
().
numpy
()
*
255
).
astype
(
np
.
uint8
),
)
break
psnr_avg
=
sum
(
psnrs
)
/
len
(
psnrs
)
print
(
f
"evaluation: psnr_avg=
{
psnr_avg
}
"
)
train_dataset
.
training
=
True
if
step
==
max_steps
:
print
(
"training stops"
)
exit
()
step
+=
1
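
Note (reading aid, not part of the commit): the proposal supervision in the training loop above reduces, per proposal level, to a one-sided penalty on proposal weights that underestimate the detached weights queried via ray_pdf_query. A minimal standalone sketch of that term, with a hypothetical helper name:

import torch

def proposal_interval_loss(proposal_weights: torch.Tensor,
                           proposal_weights_gt: torch.Tensor) -> torch.Tensor:
    # Penalize a proposal sample only where it underestimates the detached
    # target weight; overestimates are not penalized, mirroring the clamp above.
    eps = torch.finfo(torch.float32).eps
    excess = torch.clamp(proposal_weights_gt - proposal_weights, min=0)
    return (excess**2 / (proposal_weights + eps)).mean()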
nerfacc/cdf.py
@@ -37,7 +37,7 @@ def ray_resampling(
         resampled_t_starts,
         resampled_t_ends,
     ) = _C.ray_resampling(
-        packed_info.contiguous(),
+        packed_info.contiguous().int(),
         t_starts.contiguous(),
         t_ends.contiguous(),
         weights.contiguous(),
nerfacc/cuda/csrc/cdf.cu
@@ -94,7 +94,7 @@ __global__ void cdf_resampling_kernel(
     const int *packed_info,   // input ray & point indices.
     const scalar_t *starts,   // input start t
     const scalar_t *ends,     // input end t
-    const scalar_t *weights,  // transmittance weights
+    const scalar_t *w,        // transmittance weights
     const int *resample_packed_info,
     scalar_t *resample_starts,
     scalar_t *resample_ends)
@@ -111,25 +111,26 @@ __global__ void cdf_resampling_kernel(
     starts += base;
     ends += base;
-    weights += base;
+    w += base;
     resample_starts += resample_base;
     resample_ends += resample_base;

     // normalize weights **per ray**
-    scalar_t weights_sum = 0.0f;
+    scalar_t w_sum = 0.0f;
     for (int j = 0; j < steps; j++)
-        weights_sum += weights[j];
-    scalar_t padding = fmaxf(1e-5f - weights_sum, 0.0f);
-    scalar_t padding_step = padding / steps;
-    weights_sum += padding;
+        w_sum += w[j];
+    // scalar_t padding = fmaxf(1e-10f - weights_sum, 0.0f);
+    // scalar_t padding_step = padding / steps;
+    // weights_sum += padding;

-    int num_bins = resample_steps + 1;
-    scalar_t cdf_step_size = (1.0f - 1.0 / num_bins) / resample_steps;
+    int num_endpoints = resample_steps + 1;
+    scalar_t cdf_pad = 1.0f / (2 * num_endpoints);
+    scalar_t cdf_step_size = (1.0f - 2 * cdf_pad) / resample_steps;

     int idx = 0, j = 0;
-    scalar_t cdf_prev = 0.0f, cdf_next = (weights[idx] + padding_step) / weights_sum;
-    scalar_t cdf_u = 1.0 / (2 * num_bins);
-    while (j < num_bins)
+    scalar_t cdf_prev = 0.0f, cdf_next = w[idx] / w_sum;
+    scalar_t cdf_u = cdf_pad;
+    while (j < num_endpoints)
     {
         if (cdf_u < cdf_next)
         {
@@ -137,26 +138,32 @@ __global__ void cdf_resampling_kernel(
             // resample in this interval
             scalar_t scaling = (ends[idx] - starts[idx]) / (cdf_next - cdf_prev);
             scalar_t t = (cdf_u - cdf_prev) * scaling + starts[idx];
-            if (j < num_bins - 1)
+            // if (j == 100) {
+            //     printf(
+            //         "cdf_u: %.10f, cdf_next: %.10f, cdf_prev: %.10f, scaling: %.10f, t: %.10f, starts[idx]: %.10f, ends[idx]: %.10f\n",
+            //         cdf_u, cdf_next, cdf_prev, scaling, t, starts[idx], ends[idx]);
+            // }
+            if (j < num_endpoints - 1)
                 resample_starts[j] = t;
             if (j > 0)
                 resample_ends[j - 1] = t;
             // going further to next resample
-            cdf_u += cdf_step_size;
+            // cdf_u += cdf_step_size;
             j += 1;
+            cdf_u = j * cdf_step_size + cdf_pad;
         }
         else
         {
             // going to next interval
             idx += 1;
             cdf_prev = cdf_next;
-            cdf_next += (weights[idx] + padding_step) / weights_sum;
+            cdf_next += w[idx] / w_sum;
         }
     }
-    if (j != num_bins)
-    {
-        printf("Error: %d %d %f\n", j, num_bins, weights_sum);
-    }
+    // if (j != num_endpoints)
+    // {
+    //     printf("Error: %d %d %f\n", j, num_endpoints, weights_sum);
+    // }
     return;
 }
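
Note (reading aid, not part of the commit): the rewritten kernel places num_endpoints uniformly spaced CDF levels at cdf_pad, cdf_pad + cdf_step_size, ..., 1 - cdf_pad and inverts the per-ray CDF of the raw (no longer padded) weights at those levels; adjacent levels become the start and end of each resampled interval. A dense, single-ray PyTorch sketch under that reading (names are hypothetical; the real kernel operates on packed rays and its zero-weight guards are omitted):

import torch

def resample_endpoints_single_ray(t_starts, t_ends, w, num_endpoints):
    # Assumes contiguous bins (t_ends[i] == t_starts[i + 1]) and w.sum() > 0.
    edges = torch.cat([t_starts, t_ends[-1:]])  # bin edges, length len(w) + 1
    cdf = torch.cat([torch.zeros_like(w[:1]), torch.cumsum(w / w.sum(), dim=0)])
    cdf_pad = 1.0 / (2 * num_endpoints)
    cdf_step = (1.0 - 2 * cdf_pad) / (num_endpoints - 1)  # resample_steps = num_endpoints - 1
    cdf_u = (
        torch.arange(num_endpoints, device=w.device, dtype=w.dtype) * cdf_step
        + cdf_pad
    )
    # Locate the bin containing each level: cdf[idx] <= cdf_u < cdf[idx + 1].
    idx = torch.clamp(torch.searchsorted(cdf, cdf_u, right=True) - 1, 0, len(w) - 1)
    scaling = (edges[idx + 1] - edges[idx]) / (cdf[idx + 1] - cdf[idx])
    t = (cdf_u - cdf[idx]) * scaling + edges[idx]
    return t[:-1], t[1:]  # resampled interval starts and ends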
nerfacc/ray_marching.py
@@ -4,10 +4,12 @@ import torch
 import nerfacc.cuda as _C

+from .cdf import ray_resampling
 from .contraction import ContractionType
 from .grid import Grid
 from .intersection import ray_aabb_intersect
-from .vol_rendering import render_visibility
+from .pack import unpack_info
+from .vol_rendering import render_visibility, render_weight_from_density


 @torch.no_grad()
@@ -24,6 +26,7 @@ def ray_marching(
     # sigma/alpha function for skipping invisible space
     sigma_fn: Optional[Callable] = None,
     alpha_fn: Optional[Callable] = None,
+    proposal_nets: Optional[torch.nn.Module] = None,
     early_stop_eps: float = 1e-4,
     alpha_thre: float = 0.0,
     # rendering options
@@ -189,6 +192,23 @@ def ray_marching(
         cone_angle,
     )

+    if proposal_nets is not None:
+        proposal_sample_list = []
+        # resample with proposal nets
+        for net, num_samples in zip(proposal_nets, [48]):
+            ray_indices = unpack_info(packed_info)
+            with torch.enable_grad():
+                sigmas = sigma_fn(t_starts, t_ends, ray_indices.long(), net=net)
+                weights = render_weight_from_density(
+                    packed_info, t_starts, t_ends, sigmas, early_stop_eps=0
+                )
+                proposal_sample_list.append(
+                    (packed_info, t_starts, t_ends, weights)
+                )
+            packed_info, t_starts, t_ends = ray_resampling(
+                packed_info, t_starts, t_ends, weights, n_samples=num_samples
+            )
+
     # skip invisible space
     if sigma_fn is not None or alpha_fn is not None:
         # Query sigma without gradients
@@ -218,4 +238,7 @@ def ray_marching(
         t_ends[masks],
     )

-    return ray_indices, t_starts, t_ends
+    if proposal_nets is not None:
+        return packed_info, t_starts, t_ends, proposal_sample_list
+    else:
+        return packed_info, t_starts, t_ends
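
Note (reading aid, not part of the commit): the proposal branch added above turns each proposal network's densities into transmittance weights before resampling. A dense, single-ray simplification of that weight computation (our own sketch; render_weight_from_density itself operates on packed rays and supports early stopping):

import torch

def weights_from_density_single_ray(t_starts, t_ends, sigmas):
    deltas = t_ends - t_starts
    alphas = 1.0 - torch.exp(-sigmas * deltas)  # per-sample opacity
    # Transmittance: probability of reaching each sample without absorption.
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alphas[:1]), 1.0 - alphas[:-1]]), dim=0
    )
    return alphas * trans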
tests/test_resampling.py
 import pytest
 import torch
+from functorch import vmap

 from nerfacc import pack_info, ray_marching, ray_resampling
 from nerfacc.cuda import ray_pdf_query

 device = "cuda:0"
 batch_size = 128
+eps = torch.finfo(torch.float32).eps
+
+
+def _interp(x, xp, fp):
+    """One-dimensional linear interpolation for monotonically increasing sample
+    points.
+
+    Returns the one-dimensional piecewise linear interpolant to a function with
+    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.
+
+    Args:
+        x: the :math:`x`-coordinates at which to evaluate the interpolated
+            values.
+        xp: the :math:`x`-coordinates of the data points, must be increasing.
+        fp: the :math:`y`-coordinates of the data points, same length as `xp`.
+
+    Returns:
+        the interpolated values, same size as `x`.
+    """
+    xp = xp.contiguous()
+    x = x.contiguous()
+    m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])
+    b = fp[:-1] - (m * xp[:-1])
+    indices = torch.searchsorted(xp, x, right=True) - 1
+    indices = torch.clamp(indices, 0, len(m) - 1)
+    return m[indices] * x + b[indices]
+
+
+def _integrate_weights(w):
+    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.
+
+    The output's size on the last dimension is one greater than that of the input,
+    because we're computing the integral corresponding to the endpoints of a step
+    function, not the integral of the interior/bin values.
+
+    Args:
+        w: Tensor, which will be integrated along the last axis. This is assumed to
+            sum to 1 along the last axis, and this function will (silently) break if
+            that is not the case.
+
+    Returns:
+        cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1
+    """
+    cw = torch.clamp(torch.cumsum(w[..., :-1], dim=-1), max=1)
+    shape = cw.shape[:-1] + (1,)
+    # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
+    zeros = torch.zeros(shape, device=w.device)
+    ones = torch.ones(shape, device=w.device)
+    cw0 = torch.cat([zeros, cw, ones], dim=-1)
+    return cw0
+
+
+def _invert_cdf(u, t, w_logits):
+    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1)."""
+    # Compute the PDF and CDF for each weight vector.
+    w = torch.softmax(w_logits, dim=-1)
+    # w = torch.exp(w_logits)
+    # w = w / torch.sum(w, dim=-1, keepdim=True)
+    cw = _integrate_weights(w)
+    # Interpolate into the inverse CDF.
+    t_new = vmap(_interp)(u, cw, t)
+    return t_new
+
+
+def _resampling(t, w_logits, num_samples):
+    """Piecewise-Constant PDF sampling from a step function.
+
+    Args:
+        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted).
+        w_logits: [..., num_bins], logits corresponding to bin weights.
+        num_samples: int, the number of samples.
+
+    returns:
+        t_samples: [..., num_samples], the sampled t values
+    """
+    pad = 1 / (2 * num_samples)
+    u = torch.linspace(pad, 1.0 - pad - eps, num_samples, device=device)
+    u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
+    return _invert_cdf(u, t, w_logits)


 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_resampling():
-    rays_o = torch.rand((batch_size, 3), device=device)
-    rays_d = torch.randn((batch_size, 3), device=device)
-    rays_d = rays_d / rays_d.norm(dim=-1, keepdim=True)
-    ray_indices, t_starts, t_ends = ray_marching(
-        rays_o,
-        rays_d,
-        near_plane=0.1,
-        far_plane=1.0,
-        render_step_size=1e-3,
-    )
-    packed_info = pack_info(ray_indices, n_rays=batch_size)
-    weights = torch.rand((t_starts.shape[0],), device=device)
-    packed_info, t_starts, t_ends = ray_resampling(
-        packed_info, t_starts, t_ends, weights, n_samples=32
-    )
-    assert t_starts.shape == t_ends.shape == (batch_size * 32, 1)
+    batch_size = 1024
+    num_bins = 128
+    num_samples = 128
+    t = torch.randn((batch_size, num_bins + 1), device=device)
+    t = torch.sort(t, dim=-1).values
+    w_logits = torch.randn((batch_size, num_bins), device=device) * 0.1
+    w = torch.softmax(w_logits, dim=-1)
+    masks = w_logits > 0
+    w_logits[~masks] = -torch.inf
+    t_samples = _resampling(t, w_logits, num_samples + 1)
+    t_starts = t[:, :-1][masks].unsqueeze(-1)
+    t_ends = t[:, 1:][masks].unsqueeze(-1)
+    w_logits = w_logits[masks]
+    w = w[masks]
+    num_steps = masks.long().sum(dim=-1)
+    cum_steps = torch.cumsum(num_steps, dim=0)
+    packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1).int()
+    _, t_starts, t_ends = ray_resampling(
+        packed_info, t_starts, t_ends, w, num_samples
+    )
+    # print(
+    #     (t_starts.view(batch_size, num_samples) - t_samples[:, :-1])
+    #     .abs()
+    #     .max(),
+    #     (t_ends.view(batch_size, num_samples) - t_samples[:, 1:]).abs().max(),
+    # )
+    assert torch.allclose(
+        t_starts.view(batch_size, num_samples), t_samples[:, :-1], atol=1e-3
+    )
+    assert torch.allclose(
+        t_ends.view(batch_size, num_samples), t_samples[:, 1:], atol=1e-3
+    )


 def test_pdf_query():
 ...