Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
1aeee0a9
"examples/profiling/igraph_bench.py" did not exist on "0d6504434befdf609d34709891eecf85f27e0934"
Commit
1aeee0a9
authored
Nov 19, 2022
by
Ruilong Li
Browse files
cleanup marching; resampling steps limit
parent
ad2a0079
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
42 deletions
+25
-42
nerfacc/cuda/csrc/cdf.cu
nerfacc/cuda/csrc/cdf.cu
+2
-1
nerfacc/cuda/csrc/pybind.cu
nerfacc/cuda/csrc/pybind.cu
+0
-2
nerfacc/cuda/csrc/ray_marching.cu
nerfacc/cuda/csrc/ray_marching.cu
+5
-22
nerfacc/ray_marching.py
nerfacc/ray_marching.py
+0
-2
nerfacc/vol_rendering.py
nerfacc/vol_rendering.py
+18
-15
No files found.
nerfacc/cuda/csrc/cdf.cu
View file @
1aeee0a9
...
@@ -265,7 +265,8 @@ std::vector<torch::Tensor> ray_resampling(
...
@@ -265,7 +265,8 @@ std::vector<torch::Tensor> ray_resampling(
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
torch
::
Tensor
num_steps
=
torch
::
split
(
packed_info
,
1
,
1
)[
1
];
torch
::
Tensor
num_steps
=
torch
::
split
(
packed_info
,
1
,
1
)[
1
];
torch
::
Tensor
resample_num_steps
=
(
num_steps
>
0
).
to
(
num_steps
.
options
())
*
steps
;
// torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
torch
::
Tensor
resample_num_steps
=
torch
::
clamp
(
num_steps
,
0
,
steps
);
torch
::
Tensor
resample_cum_steps
=
resample_num_steps
.
cumsum
(
0
,
torch
::
kInt32
);
torch
::
Tensor
resample_cum_steps
=
resample_num_steps
.
cumsum
(
0
,
torch
::
kInt32
);
torch
::
Tensor
resample_packed_info
=
torch
::
cat
(
torch
::
Tensor
resample_packed_info
=
torch
::
cat
(
{
resample_cum_steps
-
resample_num_steps
,
resample_num_steps
},
1
);
{
resample_cum_steps
-
resample_num_steps
,
resample_num_steps
},
1
);
...
...
nerfacc/cuda/csrc/pybind.cu
View file @
1aeee0a9
...
@@ -14,8 +14,6 @@ std::vector<torch::Tensor> ray_aabb_intersect(
...
@@ -14,8 +14,6 @@ std::vector<torch::Tensor> ray_aabb_intersect(
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
// rays
// rays
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_d
,
const
torch
::
Tensor
t_min
,
const
torch
::
Tensor
t_min
,
const
torch
::
Tensor
t_max
,
const
torch
::
Tensor
t_max
,
// sampling
// sampling
...
...
nerfacc/cuda/csrc/ray_marching.cu
View file @
1aeee0a9
...
@@ -80,8 +80,6 @@ inline __device__ __host__ float advance_to_next_voxel(
...
@@ -80,8 +80,6 @@ inline __device__ __host__ float advance_to_next_voxel(
__global__
void
ray_marching_kernel
(
__global__
void
ray_marching_kernel
(
// rays info
// rays info
const
uint32_t
n_rays
,
const
uint32_t
n_rays
,
const
float
*
rays_o
,
// shape (n_rays, 3)
const
float
*
rays_d
,
// shape (n_rays, 3)
const
float
*
t_min
,
// shape (n_rays,)
const
float
*
t_min
,
// shape (n_rays,)
const
float
*
t_max
,
// shape (n_rays,)
const
float
*
t_max
,
// shape (n_rays,)
// sampling
// sampling
...
@@ -100,8 +98,6 @@ __global__ void ray_marching_kernel(
...
@@ -100,8 +98,6 @@ __global__ void ray_marching_kernel(
bool
is_first_round
=
(
packed_info
==
nullptr
);
bool
is_first_round
=
(
packed_info
==
nullptr
);
// locate
// locate
rays_o
+=
i
*
3
;
rays_d
+=
i
*
3
;
t_min
+=
i
;
t_min
+=
i
;
t_max
+=
i
;
t_max
+=
i
;
...
@@ -118,9 +114,6 @@ __global__ void ray_marching_kernel(
...
@@ -118,9 +114,6 @@ __global__ void ray_marching_kernel(
ray_indices
+=
base
;
ray_indices
+=
base
;
}
}
const
float3
origin
=
make_float3
(
rays_o
[
0
],
rays_o
[
1
],
rays_o
[
2
]);
const
float3
dir
=
make_float3
(
rays_d
[
0
],
rays_d
[
1
],
rays_d
[
2
]);
const
float3
inv_dir
=
1.0
f
/
dir
;
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
float
dt_min
=
step_size
;
float
dt_min
=
step_size
;
...
@@ -270,39 +263,31 @@ __global__ void ray_marching_with_grid_kernel(
...
@@ -270,39 +263,31 @@ __global__ void ray_marching_with_grid_kernel(
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
// rays
// rays
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_d
,
const
torch
::
Tensor
t_min
,
const
torch
::
Tensor
t_min
,
const
torch
::
Tensor
t_max
,
const
torch
::
Tensor
t_max
,
// sampling
// sampling
const
float
step_size
,
const
float
step_size
,
const
float
cone_angle
)
const
float
cone_angle
)
{
{
DEVICE_GUARD
(
rays_o
);
DEVICE_GUARD
(
t_min
);
CHECK_INPUT
(
rays_o
);
CHECK_INPUT
(
rays_d
);
CHECK_INPUT
(
t_min
);
CHECK_INPUT
(
t_min
);
CHECK_INPUT
(
t_max
);
CHECK_INPUT
(
t_max
);
TORCH_CHECK
(
rays_o
.
ndimension
()
==
2
&
rays_o
.
size
(
1
)
==
3
)
TORCH_CHECK
(
rays_d
.
ndimension
()
==
2
&
rays_d
.
size
(
1
)
==
3
)
TORCH_CHECK
(
t_min
.
ndimension
()
==
1
)
TORCH_CHECK
(
t_min
.
ndimension
()
==
1
)
TORCH_CHECK
(
t_max
.
ndimension
()
==
1
)
TORCH_CHECK
(
t_max
.
ndimension
()
==
1
)
const
int
n_rays
=
rays_o
.
size
(
0
);
const
int
n_rays
=
t_min
.
size
(
0
);
const
int
threads
=
256
;
const
int
threads
=
256
;
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
// helper counter
// helper counter
torch
::
Tensor
num_steps
=
torch
::
empty
(
torch
::
Tensor
num_steps
=
torch
::
empty
(
{
n_rays
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
{
n_rays
},
t_min
.
options
().
dtype
(
torch
::
kInt32
));
// count number of samples per ray
// count number of samples per ray
ray_marching_kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
ray_marching_kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
// rays
// rays
n_rays
,
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_d
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
// sampling
// sampling
...
@@ -320,15 +305,13 @@ std::vector<torch::Tensor> ray_marching(
...
@@ -320,15 +305,13 @@ std::vector<torch::Tensor> ray_marching(
// output samples starts and ends
// output samples starts and ends
int
total_steps
=
cum_steps
[
cum_steps
.
size
(
0
)
-
1
].
item
<
int
>
();
int
total_steps
=
cum_steps
[
cum_steps
.
size
(
0
)
-
1
].
item
<
int
>
();
torch
::
Tensor
t_starts
=
torch
::
empty
({
total_steps
,
1
},
rays_o
.
options
());
torch
::
Tensor
t_starts
=
torch
::
empty
({
total_steps
,
1
},
t_min
.
options
());
torch
::
Tensor
t_ends
=
torch
::
empty
({
total_steps
,
1
},
rays_o
.
options
());
torch
::
Tensor
t_ends
=
torch
::
empty
({
total_steps
,
1
},
t_min
.
options
());
torch
::
Tensor
ray_indices
=
torch
::
empty
({
total_steps
},
cum_steps
.
options
());
torch
::
Tensor
ray_indices
=
torch
::
empty
({
total_steps
},
cum_steps
.
options
());
ray_marching_kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
ray_marching_kernel
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
// rays
// rays
n_rays
,
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_d
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
// sampling
// sampling
...
...
nerfacc/ray_marching.py
View file @
1aeee0a9
...
@@ -231,8 +231,6 @@ def ray_marching(
...
@@ -231,8 +231,6 @@ def ray_marching(
# marching
# marching
packed_info
,
ray_indices
,
t_starts
,
t_ends
=
_C
.
ray_marching
(
packed_info
,
ray_indices
,
t_starts
,
t_ends
=
_C
.
ray_marching
(
# rays
# rays
rays_o
.
contiguous
(),
rays_d
.
contiguous
(),
t_min
.
contiguous
(),
t_min
.
contiguous
(),
t_max
.
contiguous
(),
t_max
.
contiguous
(),
# sampling
# sampling
...
...
nerfacc/vol_rendering.py
View file @
1aeee0a9
...
@@ -500,21 +500,24 @@ def render_visibility(
...
@@ -500,21 +500,24 @@ def render_visibility(
"""
"""
assert
(
assert
(
ray_indices
is
not
None
or
packed_info
is
not
None
alphas
.
dim
()
==
2
and
alphas
.
shape
[
-
1
]
==
1
),
"Either ray_indices or packed_info should be provided."
),
"alphas should be a 2D tensor with shape (n_samples, 1)."
if
ray_indices
is
not
None
and
_C
.
is_cub_available
():
visibility
=
alphas
>=
alpha_thre
transmittance
=
_RenderingTransmittanceFromAlphaCUB
.
apply
(
if
early_stop_eps
>
0
:
ray_indices
,
alphas
assert
(
)
ray_indices
is
not
None
or
packed_info
is
not
None
else
:
),
"Either ray_indices or packed_info should be provided."
if
packed_info
is
None
:
if
ray_indices
is
not
None
and
_C
.
is_cub_available
():
packed_info
=
pack_info
(
ray_indices
,
n_rays
=
n_rays
)
transmittance
=
_RenderingTransmittanceFromAlphaCUB
.
apply
(
transmittance
=
_RenderingTransmittanceFromAlphaNaive
.
apply
(
ray_indices
,
alphas
packed_info
,
alphas
)
)
else
:
visibility
=
transmittance
>=
early_stop_eps
if
packed_info
is
None
:
if
alpha_thre
>
0
:
packed_info
=
pack_info
(
ray_indices
,
n_rays
=
n_rays
)
visibility
=
visibility
&
(
alphas
>=
alpha_thre
)
transmittance
=
_RenderingTransmittanceFromAlphaNaive
.
apply
(
packed_info
,
alphas
)
visibility
=
visibility
&
(
transmittance
>=
early_stop_eps
)
visibility
=
visibility
.
squeeze
(
-
1
)
visibility
=
visibility
.
squeeze
(
-
1
)
return
visibility
return
visibility
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment