Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
ad2a0079
Commit
ad2a0079
authored
Nov 16, 2022
by
Ruilong Li
Browse files
proposal_nets_require_grads: 7k, 226s, 34.96db, loss 57
parent
6f7f9fb0
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
373 additions
and
193 deletions
+373
-193
examples/train_ngp_nerf_proposal.py
examples/train_ngp_nerf_proposal.py
+93
-79
nerfacc/cuda/__init__.py
nerfacc/cuda/__init__.py
+1
-0
nerfacc/cuda/csrc/pybind.cu
nerfacc/cuda/csrc/pybind.cu
+10
-0
nerfacc/cuda/csrc/ray_marching.cu
nerfacc/cuda/csrc/ray_marching.cu
+157
-2
nerfacc/ray_marching.py
nerfacc/ray_marching.py
+112
-112
No files found.
examples/train_ngp_nerf_proposal.py
View file @
ad2a0079
...
@@ -42,6 +42,7 @@ def render_image(
...
@@ -42,6 +42,7 @@ def render_image(
render_bkgd
:
Optional
[
torch
.
Tensor
]
=
None
,
render_bkgd
:
Optional
[
torch
.
Tensor
]
=
None
,
cone_angle
:
float
=
0.0
,
cone_angle
:
float
=
0.0
,
alpha_thre
:
float
=
0.0
,
alpha_thre
:
float
=
0.0
,
proposal_nets_require_grads
:
bool
=
True
,
# test options
# test options
test_chunk_size
:
int
=
8192
,
test_chunk_size
:
int
=
8192
,
):
):
...
@@ -94,6 +95,7 @@ def render_image(
...
@@ -94,6 +95,7 @@ def render_image(
stratified
=
radiance_field
.
training
,
stratified
=
radiance_field
.
training
,
cone_angle
=
cone_angle
,
cone_angle
=
cone_angle
,
alpha_thre
=
alpha_thre
,
alpha_thre
=
alpha_thre
,
proposal_nets_require_grads
=
proposal_nets_require_grads
,
)
)
rgb
,
opacity
,
depth
,
weights
=
rendering
(
rgb
,
opacity
,
depth
,
weights
=
rendering
(
t_starts
,
t_starts
,
...
@@ -312,6 +314,8 @@ if __name__ == "__main__":
...
@@ -312,6 +314,8 @@ if __name__ == "__main__":
radiance_field
.
train
()
radiance_field
.
train
()
proposal_nets
.
train
()
proposal_nets
.
train
()
# @profile
def
_train
():
data
=
train_dataset
[
i
]
data
=
train_dataset
[
i
]
render_bkgd
=
data
[
"color_bkgd"
]
render_bkgd
=
data
[
"color_bkgd"
]
...
@@ -337,9 +341,10 @@ if __name__ == "__main__":
...
@@ -337,9 +341,10 @@ if __name__ == "__main__":
render_bkgd
=
render_bkgd
,
render_bkgd
=
render_bkgd
,
cone_angle
=
args
.
cone_angle
,
cone_angle
=
args
.
cone_angle
,
alpha_thre
=
min
(
alpha_thre
,
alpha_thre
*
step
/
1000
),
alpha_thre
=
min
(
alpha_thre
,
alpha_thre
*
step
/
1000
),
proposal_nets_require_grads
=
(
step
<
100
or
step
%
16
==
0
),
)
)
if
n_rendering_samples
==
0
:
#
if n_rendering_samples == 0:
continue
#
continue
# dynamic batch size for rays to keep sample batch size constant.
# dynamic batch size for rays to keep sample batch size constant.
num_rays
=
len
(
pixels
)
num_rays
=
len
(
pixels
)
...
@@ -351,7 +356,9 @@ if __name__ == "__main__":
...
@@ -351,7 +356,9 @@ if __name__ == "__main__":
alive_ray_mask
=
acc
.
squeeze
(
-
1
)
>
0
alive_ray_mask
=
acc
.
squeeze
(
-
1
)
>
0
# compute loss
# compute loss
loss
=
F
.
smooth_l1_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
])
loss
=
F
.
smooth_l1_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
]
)
(
(
packed_info
,
packed_info
,
...
@@ -377,7 +384,9 @@ if __name__ == "__main__":
...
@@ -377,7 +384,9 @@ if __name__ == "__main__":
).
detach
()
).
detach
()
loss_interval
=
(
loss_interval
=
(
torch
.
clamp
(
proposal_weights_gt
-
proposal_weights
,
min
=
0
)
torch
.
clamp
(
proposal_weights_gt
-
proposal_weights
,
min
=
0
)
)
**
2
/
(
proposal_weights
+
torch
.
finfo
(
torch
.
float32
).
eps
)
)
**
2
/
(
proposal_weights
+
torch
.
finfo
(
torch
.
float32
).
eps
)
loss_interval
=
loss_interval
.
mean
()
loss_interval
=
loss_interval
.
mean
()
loss
+=
loss_interval
*
1.0
loss
+=
loss_interval
*
1.0
...
@@ -390,7 +399,9 @@ if __name__ == "__main__":
...
@@ -390,7 +399,9 @@ if __name__ == "__main__":
if
step
%
100
==
0
:
if
step
%
100
==
0
:
elapsed_time
=
time
.
time
()
-
tic
elapsed_time
=
time
.
time
()
-
tic
loss
=
F
.
mse_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
])
loss
=
F
.
mse_loss
(
rgb
[
alive_ray_mask
],
pixels
[
alive_ray_mask
]
)
print
(
print
(
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s | step=
{
step
}
| "
f
"elapsed_time=
{
elapsed_time
:.
2
f
}
s | step=
{
step
}
| "
f
"loss=
{
loss
:.
5
f
}
| loss_interval=
{
loss_interval
:.
5
f
}
"
f
"loss=
{
loss
:.
5
f
}
| loss_interval=
{
loss_interval
:.
5
f
}
"
...
@@ -398,6 +409,8 @@ if __name__ == "__main__":
...
@@ -398,6 +409,8 @@ if __name__ == "__main__":
f
"n_rendering_samples=
{
n_rendering_samples
:
d
}
| num_rays=
{
len
(
pixels
):
d
}
|"
f
"n_rendering_samples=
{
n_rendering_samples
:
d
}
| num_rays=
{
len
(
pixels
):
d
}
|"
)
)
_train
()
if
step
>=
0
and
step
%
1000
==
0
and
step
>
0
:
if
step
>=
0
and
step
%
1000
==
0
and
step
>
0
:
# evaluation
# evaluation
radiance_field
.
eval
()
radiance_field
.
eval
()
...
@@ -424,6 +437,7 @@ if __name__ == "__main__":
...
@@ -424,6 +437,7 @@ if __name__ == "__main__":
render_bkgd
=
render_bkgd
,
render_bkgd
=
render_bkgd
,
cone_angle
=
args
.
cone_angle
,
cone_angle
=
args
.
cone_angle
,
alpha_thre
=
alpha_thre
,
alpha_thre
=
alpha_thre
,
proposal_nets_require_grads
=
False
,
# test options
# test options
test_chunk_size
=
args
.
test_chunk_size
,
test_chunk_size
=
args
.
test_chunk_size
,
)
)
...
...
nerfacc/cuda/__init__.py
View file @
ad2a0079
...
@@ -23,6 +23,7 @@ grid_query = _make_lazy_cuda_func("grid_query")
...
@@ -23,6 +23,7 @@ grid_query = _make_lazy_cuda_func("grid_query")
ray_aabb_intersect
=
_make_lazy_cuda_func
(
"ray_aabb_intersect"
)
ray_aabb_intersect
=
_make_lazy_cuda_func
(
"ray_aabb_intersect"
)
ray_marching
=
_make_lazy_cuda_func
(
"ray_marching"
)
ray_marching
=
_make_lazy_cuda_func
(
"ray_marching"
)
ray_marching_with_grid
=
_make_lazy_cuda_func
(
"ray_marching_with_grid"
)
ray_resampling
=
_make_lazy_cuda_func
(
"ray_resampling"
)
ray_resampling
=
_make_lazy_cuda_func
(
"ray_resampling"
)
ray_pdf_query
=
_make_lazy_cuda_func
(
"ray_pdf_query"
)
ray_pdf_query
=
_make_lazy_cuda_func
(
"ray_pdf_query"
)
...
...
nerfacc/cuda/csrc/pybind.cu
View file @
ad2a0079
...
@@ -13,6 +13,15 @@ std::vector<torch::Tensor> ray_aabb_intersect(
...
@@ -13,6 +13,15 @@ std::vector<torch::Tensor> ray_aabb_intersect(
const
torch
::
Tensor
aabb
);
const
torch
::
Tensor
aabb
);
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
// rays
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_d
,
const
torch
::
Tensor
t_min
,
const
torch
::
Tensor
t_max
,
// sampling
const
float
step_size
,
const
float
cone_angle
);
std
::
vector
<
torch
::
Tensor
>
ray_marching_with_grid
(
// rays
// rays
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_d
,
const
torch
::
Tensor
rays_d
,
...
@@ -153,6 +162,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
...
@@ -153,6 +162,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// marching
// marching
m
.
def
(
"ray_aabb_intersect"
,
&
ray_aabb_intersect
);
m
.
def
(
"ray_aabb_intersect"
,
&
ray_aabb_intersect
);
m
.
def
(
"ray_marching"
,
&
ray_marching
);
m
.
def
(
"ray_marching"
,
&
ray_marching
);
m
.
def
(
"ray_marching_with_grid"
,
&
ray_marching_with_grid
);
m
.
def
(
"ray_resampling"
,
&
ray_resampling
);
m
.
def
(
"ray_resampling"
,
&
ray_resampling
);
m
.
def
(
"ray_pdf_query"
,
&
ray_pdf_query
);
m
.
def
(
"ray_pdf_query"
,
&
ray_pdf_query
);
...
...
nerfacc/cuda/csrc/ray_marching.cu
View file @
ad2a0079
...
@@ -76,7 +76,85 @@ inline __device__ __host__ float advance_to_next_voxel(
...
@@ -76,7 +76,85 @@ inline __device__ __host__ float advance_to_next_voxel(
// Raymarching
// Raymarching
// -------------------------------------------------------------------------------
// -------------------------------------------------------------------------------
__global__
void
ray_marching_kernel
(
__global__
void
ray_marching_kernel
(
// rays info
const
uint32_t
n_rays
,
const
float
*
rays_o
,
// shape (n_rays, 3)
const
float
*
rays_d
,
// shape (n_rays, 3)
const
float
*
t_min
,
// shape (n_rays,)
const
float
*
t_max
,
// shape (n_rays,)
// sampling
const
float
step_size
,
const
float
cone_angle
,
const
int
*
packed_info
,
// first round outputs
int
*
num_steps
,
// second round outputs
int
*
ray_indices
,
float
*
t_starts
,
float
*
t_ends
)
{
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
bool
is_first_round
=
(
packed_info
==
nullptr
);
// locate
rays_o
+=
i
*
3
;
rays_d
+=
i
*
3
;
t_min
+=
i
;
t_max
+=
i
;
if
(
is_first_round
)
{
num_steps
+=
i
;
}
else
{
int
base
=
packed_info
[
i
*
2
+
0
];
int
steps
=
packed_info
[
i
*
2
+
1
];
t_starts
+=
base
;
t_ends
+=
base
;
ray_indices
+=
base
;
}
const
float3
origin
=
make_float3
(
rays_o
[
0
],
rays_o
[
1
],
rays_o
[
2
]);
const
float3
dir
=
make_float3
(
rays_d
[
0
],
rays_d
[
1
],
rays_d
[
2
]);
const
float3
inv_dir
=
1.0
f
/
dir
;
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
float
dt_min
=
step_size
;
float
dt_max
=
1e10
f
;
int
j
=
0
;
float
t0
=
near
;
float
dt
=
calc_dt
(
t0
,
cone_angle
,
dt_min
,
dt_max
);
float
t1
=
t0
+
dt
;
float
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
while
(
t_mid
<
far
)
{
if
(
!
is_first_round
)
{
t_starts
[
j
]
=
t0
;
t_ends
[
j
]
=
t1
;
ray_indices
[
j
]
=
i
;
}
++
j
;
// march to next sample
t0
=
t1
;
t1
=
t0
+
calc_dt
(
t0
,
cone_angle
,
dt_min
,
dt_max
);
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
}
if
(
is_first_round
)
{
*
num_steps
=
j
;
}
return
;
}
__global__
void
ray_marching_with_grid_kernel
(
// rays info
// rays info
const
uint32_t
n_rays
,
const
uint32_t
n_rays
,
const
float
*
rays_o
,
// shape (n_rays, 3)
const
float
*
rays_o
,
// shape (n_rays, 3)
...
@@ -189,7 +267,84 @@ __global__ void ray_marching_kernel(
...
@@ -189,7 +267,84 @@ __global__ void ray_marching_kernel(
return
;
return
;
}
}
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
// rays
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_d
,
const
torch
::
Tensor
t_min
,
const
torch
::
Tensor
t_max
,
// sampling
const
float
step_size
,
const
float
cone_angle
)
{
DEVICE_GUARD
(
rays_o
);
CHECK_INPUT
(
rays_o
);
CHECK_INPUT
(
rays_d
);
CHECK_INPUT
(
t_min
);
CHECK_INPUT
(
t_max
);
TORCH_CHECK
(
rays_o
.
ndimension
()
==
2
&
rays_o
.
size
(
1
)
==
3
)
TORCH_CHECK
(
rays_d
.
ndimension
()
==
2
&
rays_d
.
size
(
1
)
==
3
)
TORCH_CHECK
(
t_min
.
ndimension
()
==
1
)
TORCH_CHECK
(
t_max
.
ndimension
()
==
1
)
const
int
n_rays
=
rays_o
.
size
(
0
);
const
int
threads
=
256
;
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
// helper counter
torch
::
Tensor
num_steps
=
torch
::
empty
(
{
n_rays
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
// count number of samples per ray
ray_marching_kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
// rays
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_d
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
// sampling
step_size
,
cone_angle
,
nullptr
,
/* packed_info */
// outputs
num_steps
.
data_ptr
<
int
>
(),
nullptr
,
/* ray_indices */
nullptr
,
/* t_starts */
nullptr
/* t_ends */
);
torch
::
Tensor
cum_steps
=
num_steps
.
cumsum
(
0
,
torch
::
kInt32
);
torch
::
Tensor
packed_info
=
torch
::
stack
({
cum_steps
-
num_steps
,
num_steps
},
1
);
// output samples starts and ends
int
total_steps
=
cum_steps
[
cum_steps
.
size
(
0
)
-
1
].
item
<
int
>
();
torch
::
Tensor
t_starts
=
torch
::
empty
({
total_steps
,
1
},
rays_o
.
options
());
torch
::
Tensor
t_ends
=
torch
::
empty
({
total_steps
,
1
},
rays_o
.
options
());
torch
::
Tensor
ray_indices
=
torch
::
empty
({
total_steps
},
cum_steps
.
options
());
ray_marching_kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
// rays
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_d
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
// sampling
step_size
,
cone_angle
,
packed_info
.
data_ptr
<
int
>
(),
// outputs
nullptr
,
/* num_steps */
ray_indices
.
data_ptr
<
int
>
(),
t_starts
.
data_ptr
<
float
>
(),
t_ends
.
data_ptr
<
float
>
());
return
{
packed_info
,
ray_indices
,
t_starts
,
t_ends
};
}
std
::
vector
<
torch
::
Tensor
>
ray_marching_with_grid
(
// rays
// rays
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_d
,
const
torch
::
Tensor
rays_d
,
...
@@ -230,7 +385,7 @@ std::vector<torch::Tensor> ray_marching(
...
@@ -230,7 +385,7 @@ std::vector<torch::Tensor> ray_marching(
{
n_rays
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
{
n_rays
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
// count number of samples per ray
// count number of samples per ray
ray_marching_kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
ray_marching_
with_grid_
kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
// rays
// rays
n_rays
,
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_o
.
data_ptr
<
float
>
(),
...
@@ -261,7 +416,7 @@ std::vector<torch::Tensor> ray_marching(
...
@@ -261,7 +416,7 @@ std::vector<torch::Tensor> ray_marching(
torch
::
Tensor
t_ends
=
torch
::
empty
({
total_steps
,
1
},
rays_o
.
options
());
torch
::
Tensor
t_ends
=
torch
::
empty
({
total_steps
,
1
},
rays_o
.
options
());
torch
::
Tensor
ray_indices
=
torch
::
empty
({
total_steps
},
cum_steps
.
options
());
torch
::
Tensor
ray_indices
=
torch
::
empty
({
total_steps
},
cum_steps
.
options
());
ray_marching_kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
ray_marching_
with_grid_
kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
// rays
// rays
n_rays
,
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_o
.
data_ptr
<
float
>
(),
...
...
nerfacc/ray_marching.py
View file @
ad2a0079
...
@@ -5,11 +5,55 @@ import torch
...
@@ -5,11 +5,55 @@ import torch
import
nerfacc.cuda
as
_C
import
nerfacc.cuda
as
_C
from
.cdf
import
ray_resampling
from
.cdf
import
ray_resampling
from
.contraction
import
ContractionType
from
.grid
import
Grid
from
.grid
import
Grid
from
.intersection
import
ray_aabb_intersect
from
.intersection
import
ray_aabb_intersect
from
.pack
import
pack_info
,
unpack_info
from
.pack
import
pack_info
,
unpack_info
from
.vol_rendering
import
render_visibility
,
render_weight_from_density
from
.vol_rendering
import
(
render_visibility
,
render_weight_from_alpha
,
render_weight_from_density
,
)
@
torch
.
no_grad
()
def
maybe_filter
(
t_starts
:
torch
.
Tensor
,
t_ends
:
torch
.
Tensor
,
ray_indices
:
torch
.
Tensor
,
n_rays
:
int
,
# sigma/alpha function for skipping invisible space
sigma_fn
:
Optional
[
Callable
]
=
None
,
alpha_fn
:
Optional
[
Callable
]
=
None
,
net
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
early_stop_eps
:
float
=
1e-4
,
alpha_thre
:
float
=
0.0
,
):
alphas
=
None
if
sigma_fn
is
not
None
:
alpha_fn
=
lambda
*
args
:
1.0
-
torch
.
exp
(
-
sigma_fn
(
*
args
)
*
(
t_ends
-
t_starts
)
)
if
alpha_fn
is
not
None
:
alphas
=
alpha_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
)
assert
(
alphas
.
shape
==
t_starts
.
shape
),
"alphas must have shape of (N, 1)! Got {}"
.
format
(
alphas
.
shape
)
# Compute visibility of the samples, and filter out invisible samples
masks
=
render_visibility
(
alphas
,
ray_indices
=
ray_indices
,
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
n_rays
=
n_rays
,
)
ray_indices
,
t_starts
,
t_ends
,
alphas
=
(
ray_indices
[
masks
],
t_starts
[
masks
],
t_ends
[
masks
],
alphas
[
masks
],
)
return
ray_indices
,
t_starts
,
t_ends
,
alphas
@
torch
.
no_grad
()
@
torch
.
no_grad
()
...
@@ -30,6 +74,7 @@ def ray_marching(
...
@@ -30,6 +74,7 @@ def ray_marching(
proposal_nets
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
proposal_nets
:
Optional
[
torch
.
nn
.
Module
]
=
None
,
early_stop_eps
:
float
=
1e-4
,
early_stop_eps
:
float
=
1e-4
,
alpha_thre
:
float
=
0.0
,
alpha_thre
:
float
=
0.0
,
proposal_nets_require_grads
:
bool
=
True
,
# rendering options
# rendering options
near_plane
:
Optional
[
float
]
=
None
,
near_plane
:
Optional
[
float
]
=
None
,
far_plane
:
Optional
[
float
]
=
None
,
far_plane
:
Optional
[
float
]
=
None
,
...
@@ -132,6 +177,9 @@ def ray_marching(
...
@@ -132,6 +177,9 @@ def ray_marching(
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
"""
"""
torch
.
cuda
.
synchronize
()
n_rays
=
rays_o
.
shape
[
0
]
if
not
rays_o
.
is_cuda
:
if
not
rays_o
.
is_cuda
:
raise
NotImplementedError
(
"Only support cuda inputs."
)
raise
NotImplementedError
(
"Only support cuda inputs."
)
if
alpha_fn
is
not
None
and
sigma_fn
is
not
None
:
if
alpha_fn
is
not
None
and
sigma_fn
is
not
None
:
...
@@ -163,31 +211,30 @@ def ray_marching(
...
@@ -163,31 +211,30 @@ def ray_marching(
# use grid for skipping if given
# use grid for skipping if given
if
grid
is
not
None
:
if
grid
is
not
None
:
grid_roi_aabb
=
grid
.
roi_aabb
# marching with grid-based skipping
grid_binary
=
grid
.
binary
packed_info
,
ray_indices
,
t_starts
,
t_ends
=
_C
.
ray_marching_with_grid
(
contraction_type
=
grid
.
contraction_type
.
to_cpp_version
()
# rays
else
:
rays_o
.
contiguous
(),
grid_roi_aabb
=
torch
.
tensor
(
rays_d
.
contiguous
(),
[
-
1e10
,
-
1e10
,
-
1e10
,
1e10
,
1e10
,
1e10
],
t_min
.
contiguous
(),
dtype
=
torch
.
float32
,
t_max
.
contiguous
(),
device
=
rays_o
.
device
,
# coontraction and grid
)
grid
.
roi_aabb
.
contiguous
(),
grid_binary
=
torch
.
ones
(
grid
.
binary
.
contiguous
(),
[
1
,
1
,
1
],
dtype
=
torch
.
bool
,
device
=
rays_o
.
device
grid
.
contraction_type
.
to_cpp_version
(),
# sampling
render_step_size
,
cone_angle
,
)
)
contraction_type
=
ContractionType
.
AABB
.
to_cpp_version
()
# marching with grid-based skipping
else
:
# marching
packed_info
,
ray_indices
,
t_starts
,
t_ends
=
_C
.
ray_marching
(
packed_info
,
ray_indices
,
t_starts
,
t_ends
=
_C
.
ray_marching
(
# rays
# rays
rays_o
.
contiguous
(),
rays_o
.
contiguous
(),
rays_d
.
contiguous
(),
rays_d
.
contiguous
(),
t_min
.
contiguous
(),
t_min
.
contiguous
(),
t_max
.
contiguous
(),
t_max
.
contiguous
(),
# coontraction and grid
grid_roi_aabb
.
contiguous
(),
grid_binary
.
contiguous
(),
contraction_type
,
# sampling
# sampling
render_step_size
,
render_step_size
,
cone_angle
,
cone_angle
,
...
@@ -197,96 +244,49 @@ def ray_marching(
...
@@ -197,96 +244,49 @@ def ray_marching(
if
proposal_nets
is
not
None
:
if
proposal_nets
is
not
None
:
# resample with proposal nets
# resample with proposal nets
for
net
,
num_samples
in
zip
(
proposal_nets
,
[
32
]):
for
net
,
num_samples
in
zip
(
proposal_nets
,
[
32
]):
with
torch
.
no_grad
():
ray_indices
,
t_starts
,
t_ends
,
alphas
=
maybe_filter
(
# skip invisible space
t_starts
=
t_starts
,
if
sigma_fn
is
not
None
or
alpha_fn
is
not
None
:
t_ends
=
t_ends
,
# Query sigma without gradients
if
sigma_fn
is
not
None
:
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
assert
(
sigmas
.
shape
==
t_starts
.
shape
),
"sigmas must have shape of (N, 1)! Got {}"
.
format
(
sigmas
.
shape
)
alphas
=
1.0
-
torch
.
exp
(
-
sigmas
*
(
t_ends
-
t_starts
))
elif
alpha_fn
is
not
None
:
alphas
=
alpha_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
assert
(
alphas
.
shape
==
t_starts
.
shape
),
"alphas must have shape of (N, 1)! Got {}"
.
format
(
alphas
.
shape
)
# Compute visibility of the samples, and filter out invisible samples
masks
=
render_visibility
(
alphas
,
ray_indices
=
ray_indices
,
ray_indices
=
ray_indices
,
n_rays
=
n_rays
,
sigma_fn
=
sigma_fn
,
alpha_fn
=
alpha_fn
,
net
=
net
,
early_stop_eps
=
early_stop_eps
,
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
alpha_thre
=
alpha_thre
,
n_rays
=
rays_o
.
shape
[
0
],
)
ray_indices
,
t_starts
,
t_ends
=
(
ray_indices
[
masks
],
t_starts
[
masks
],
t_ends
[
masks
],
)
)
# print(
packed_info
=
pack_info
(
ray_indices
,
n_rays
=
n_rays
)
# alphas.shape,
# masks.float().sum(),
# alphas.min(),
# alphas.max(),
# )
if
proposal_nets_require_grads
:
with
torch
.
enable_grad
():
with
torch
.
enable_grad
():
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
(),
net
=
net
)
weights
=
render_weight_from_density
(
weights
=
render_weight_from_density
(
t_starts
,
t_ends
,
sigmas
,
ray_indices
=
ray_indices
t_starts
,
t_ends
,
sigmas
,
ray_indices
=
ray_indices
)
)
packed_info
=
pack_info
(
ray_indices
,
n_rays
=
rays_o
.
shape
[
0
])
proposal_sample_list
.
append
(
proposal_sample_list
.
append
(
(
packed_info
,
t_starts
,
t_ends
,
weights
)
(
packed_info
,
t_starts
,
t_ends
,
weights
)
)
)
else
:
weights
=
render_weight_from_alpha
(
alphas
,
ray_indices
=
ray_indices
)
packed_info
,
t_starts
,
t_ends
=
ray_resampling
(
packed_info
,
t_starts
,
t_ends
=
ray_resampling
(
packed_info
,
t_starts
,
t_ends
,
weights
,
n_samples
=
num_samples
packed_info
,
t_starts
,
t_ends
,
weights
,
n_samples
=
num_samples
)
)
ray_indices
=
unpack_info
(
packed_info
,
n_samples
=
t_starts
.
shape
[
0
])
ray_indices
=
unpack_info
(
packed_info
,
n_samples
=
t_starts
.
shape
[
0
])
with
torch
.
no_grad
():
ray_indices
,
t_starts
,
t_ends
,
_
=
maybe_filter
(
# skip invisible space
t_starts
=
t_starts
,
if
sigma_fn
is
not
None
or
alpha_fn
is
not
None
:
t_ends
=
t_ends
,
# Query sigma without gradients
if
sigma_fn
is
not
None
:
sigmas
=
sigma_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
())
assert
(
sigmas
.
shape
==
t_starts
.
shape
),
"sigmas must have shape of (N, 1)! Got {}"
.
format
(
sigmas
.
shape
)
alphas
=
1.0
-
torch
.
exp
(
-
sigmas
*
(
t_ends
-
t_starts
))
elif
alpha_fn
is
not
None
:
alphas
=
alpha_fn
(
t_starts
,
t_ends
,
ray_indices
.
long
())
assert
(
alphas
.
shape
==
t_starts
.
shape
),
"alphas must have shape of (N, 1)! Got {}"
.
format
(
alphas
.
shape
)
# Compute visibility of the samples, and filter out invisible samples
masks
=
render_visibility
(
alphas
,
ray_indices
=
ray_indices
,
ray_indices
=
ray_indices
,
n_rays
=
n_rays
,
sigma_fn
=
sigma_fn
,
alpha_fn
=
alpha_fn
,
net
=
None
,
early_stop_eps
=
early_stop_eps
,
early_stop_eps
=
early_stop_eps
,
alpha_thre
=
alpha_thre
,
alpha_thre
=
alpha_thre
,
n_rays
=
rays_o
.
shape
[
0
],
)
ray_indices
,
t_starts
,
t_ends
=
(
ray_indices
[
masks
],
t_starts
[
masks
],
t_ends
[
masks
],
)
)
if
proposal_nets
is
not
None
:
if
proposal_nets
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment