Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nerfacc
Commits
c3ab153e
Commit
c3ab153e
authored
Sep 12, 2022
by
Ruilong Li
Browse files
cumsum for marching
parent
86b90ea6
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
151 additions
and
142 deletions
+151
-142
README.md
README.md
+3
-3
examples/trainval.py
examples/trainval.py
+2
-2
nerfacc/cuda/csrc/pybind.cu
nerfacc/cuda/csrc/pybind.cu
+28
-15
nerfacc/cuda/csrc/ray_marching.cu
nerfacc/cuda/csrc/ray_marching.cu
+96
-85
nerfacc/cuda/csrc/vol_rendering.cu
nerfacc/cuda/csrc/vol_rendering.cu
+8
-11
nerfacc/cuda/csrc/volumetric_weights.cu
nerfacc/cuda/csrc/volumetric_weights.cu
+13
-15
nerfacc/volumetric_rendering.py
nerfacc/volumetric_rendering.py
+1
-11
No files found.
README.md
View file @
c3ab153e
...
@@ -12,9 +12,9 @@ python examples/trainval.py
...
@@ -12,9 +12,9 @@ python examples/trainval.py
| trainval (35k, 1<<16) | Lego | Mic | Materials |
| trainval (35k, 1<<16) | Lego | Mic | Materials |
| - | - | - | - |
| - | - | - | - |
| Time | 3
77
s | 357s | 354s |
| Time | 3
25
s
| 357s
| 354s
|
| PSNR | 36.0
8
| 36.5
8
| 29.63 |
| PSNR | 36.
2
0 | 36.5
5
| 29.63 |
| FPS | 12.56 | 25.54 |
Tested with the default settings on the Lego test set.
Tested with the default settings on the Lego test set.
...
...
examples/trainval.py
View file @
c3ab153e
...
@@ -90,14 +90,14 @@ if __name__ == "__main__":
...
@@ -90,14 +90,14 @@ if __name__ == "__main__":
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
device
=
"cuda:0"
device
=
"cuda:0"
scene
=
"
lego
"
scene
=
"
materials
"
# setup dataset
# setup dataset
train_dataset
=
SubjectLoader
(
train_dataset
=
SubjectLoader
(
subject_id
=
scene
,
subject_id
=
scene
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
root_fp
=
"/home/ruilongli/data/nerf_synthetic/"
,
split
=
"trainval"
,
split
=
"trainval"
,
num_rays
=
4096
,
num_rays
=
1024
,
)
)
train_dataset
.
images
=
train_dataset
.
images
.
to
(
device
)
train_dataset
.
images
=
train_dataset
.
images
.
to
(
device
)
...
...
nerfacc/cuda/csrc/pybind.cu
View file @
c3ab153e
...
@@ -8,21 +8,21 @@ std::vector<torch::Tensor> ray_aabb_intersect(
...
@@ -8,21 +8,21 @@ std::vector<torch::Tensor> ray_aabb_intersect(
);
);
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
//
std::vector<torch::Tensor> ray_marching(
// rays
//
// rays
const
torch
::
Tensor
rays_o
,
//
const torch::Tensor rays_o,
const
torch
::
Tensor
rays_d
,
//
const torch::Tensor rays_d,
const
torch
::
Tensor
t_min
,
//
const torch::Tensor t_min,
const
torch
::
Tensor
t_max
,
//
const torch::Tensor t_max,
// density grid
//
// density grid
const
torch
::
Tensor
aabb
,
//
const torch::Tensor aabb,
const
pybind11
::
list
resolution
,
//
const pybind11::list resolution,
const
torch
::
Tensor
occ_binary
,
//
const torch::Tensor occ_binary,
// sampling
//
// sampling
const
int
max_total_samples
,
//
const int max_total_samples,
const
int
max_per_ray_samples
,
//
const int max_per_ray_samples,
const
float
dt
//
const float dt
);
//
);
std
::
vector
<
torch
::
Tensor
>
volumetric_rendering_inference
(
std
::
vector
<
torch
::
Tensor
>
volumetric_rendering_inference
(
torch
::
Tensor
packed_info
,
torch
::
Tensor
packed_info
,
...
@@ -69,6 +69,19 @@ torch::Tensor volumetric_weights_backward(
...
@@ -69,6 +69,19 @@ torch::Tensor volumetric_weights_backward(
torch
::
Tensor
sigmas
torch
::
Tensor
sigmas
);
);
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
// rays
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_d
,
const
torch
::
Tensor
t_min
,
const
torch
::
Tensor
t_max
,
// density grid
const
torch
::
Tensor
aabb
,
const
pybind11
::
list
resolution
,
const
torch
::
Tensor
occ_binary
,
// sampling
const
float
dt
);
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
{
...
...
nerfacc/cuda/csrc/ray_marching.cu
View file @
c3ab153e
...
@@ -61,7 +61,8 @@ inline __device__ float advance_to_next_voxel(
...
@@ -61,7 +61,8 @@ inline __device__ float advance_to_next_voxel(
}
}
__global__
void
kernel_raymarching
(
__global__
void
marching_steps_kernel
(
// rays info
// rays info
const
uint32_t
n_rays
,
const
uint32_t
n_rays
,
const
float
*
rays_o
,
// shape (n_rays, 3)
const
float
*
rays_o
,
// shape (n_rays, 3)
...
@@ -75,18 +76,9 @@ __global__ void kernel_raymarching(
...
@@ -75,18 +76,9 @@ __global__ void kernel_raymarching(
const
int
resz
,
const
int
resz
,
const
bool
*
occ_binary
,
// shape (reso_x, reso_y, reso_z)
const
bool
*
occ_binary
,
// shape (reso_x, reso_y, reso_z)
// sampling
// sampling
const
int
max_total_samples
,
const
int
max_per_ray_samples
,
const
float
dt
,
const
float
dt
,
// writable helpers
// outputs
int
*
steps_counter
,
int
*
num_steps
int
*
rays_counter
,
// frustrum outputs
int
*
packed_info
,
float
*
frustum_origins
,
float
*
frustum_dirs
,
float
*
frustum_starts
,
float
*
frustum_ends
)
{
)
{
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
...
@@ -95,23 +87,19 @@ __global__ void kernel_raymarching(
...
@@ -95,23 +87,19 @@ __global__ void kernel_raymarching(
rays_d
+=
i
*
3
;
rays_d
+=
i
*
3
;
t_min
+=
i
;
t_min
+=
i
;
t_max
+=
i
;
t_max
+=
i
;
num_steps
+=
i
;
const
float
ox
=
rays_o
[
0
],
oy
=
rays_o
[
1
],
oz
=
rays_o
[
2
];
const
float
ox
=
rays_o
[
0
],
oy
=
rays_o
[
1
],
oz
=
rays_o
[
2
];
const
float
dx
=
rays_d
[
0
],
dy
=
rays_d
[
1
],
dz
=
rays_d
[
2
];
const
float
dx
=
rays_d
[
0
],
dy
=
rays_d
[
1
],
dz
=
rays_d
[
2
];
const
float
rdx
=
1
/
dx
,
rdy
=
1
/
dy
,
rdz
=
1
/
dz
;
const
float
rdx
=
1
/
dx
,
rdy
=
1
/
dy
,
rdz
=
1
/
dz
;
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
uint32_t
ray_idx
,
base
,
marching_samples
;
int
j
=
0
;
uint32_t
j
;
float
t0
=
near
;
// TODO(ruilongli): perturb `near` as in ngp_pl?
float
t0
,
t1
,
t_mid
;
float
t1
=
t0
+
dt
;
float
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
// first pass to compute an accurate number of steps
j
=
0
;
t0
=
near
;
// TODO(ruilongli): perturb `near` as in ngp_pl?
t1
=
t0
+
dt
;
t_mid
=
(
t0
+
t1
)
*
0.5
f
;
while
(
t_mid
<
far
&&
j
<
max_per_ray_samples
)
{
while
(
t_mid
<
far
)
{
// current center
// current center
const
float
x
=
ox
+
t_mid
*
dx
;
const
float
x
=
ox
+
t_mid
*
dx
;
const
float
y
=
oy
+
t_mid
*
dy
;
const
float
y
=
oy
+
t_mid
*
dy
;
...
@@ -135,10 +123,47 @@ __global__ void kernel_raymarching(
...
@@ -135,10 +123,47 @@ __global__ void kernel_raymarching(
}
}
if
(
j
==
0
)
return
;
if
(
j
==
0
)
return
;
marching_samples
=
j
;
num_steps
[
0
]
=
j
;
base
=
atomicAdd
(
steps_counter
,
marching_samples
);
return
;
if
(
base
+
marching_samples
>
max_total_samples
)
return
;
}
ray_idx
=
atomicAdd
(
rays_counter
,
1
);
__global__
void
marching_forward_kernel
(
// rays info
const
uint32_t
n_rays
,
const
float
*
rays_o
,
// shape (n_rays, 3)
const
float
*
rays_d
,
// shape (n_rays, 3)
const
float
*
t_min
,
// shape (n_rays,)
const
float
*
t_max
,
// shape (n_rays,)
// density grid
const
float
*
aabb
,
// [min_x, min_y, min_z, max_x, max_y, max_y]
const
int
resx
,
const
int
resy
,
const
int
resz
,
const
bool
*
occ_binary
,
// shape (reso_x, reso_y, reso_z)
// sampling
const
float
dt
,
const
int
*
packed_info
,
// frustrum outputs
float
*
frustum_origins
,
float
*
frustum_dirs
,
float
*
frustum_starts
,
float
*
frustum_ends
)
{
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
// locate
rays_o
+=
i
*
3
;
rays_d
+=
i
*
3
;
t_min
+=
i
;
t_max
+=
i
;
int
base
=
packed_info
[
i
*
2
+
0
];
int
steps
=
packed_info
[
i
*
2
+
1
];
const
float
ox
=
rays_o
[
0
],
oy
=
rays_o
[
1
],
oz
=
rays_o
[
2
];
const
float
dx
=
rays_d
[
0
],
dy
=
rays_d
[
1
],
dz
=
rays_d
[
2
];
const
float
rdx
=
1
/
dx
,
rdy
=
1
/
dy
,
rdz
=
1
/
dz
;
const
float
near
=
t_min
[
0
],
far
=
t_max
[
0
];
// locate
// locate
frustum_origins
+=
base
*
3
;
frustum_origins
+=
base
*
3
;
...
@@ -146,13 +171,12 @@ __global__ void kernel_raymarching(
...
@@ -146,13 +171,12 @@ __global__ void kernel_raymarching(
frustum_starts
+=
base
;
frustum_starts
+=
base
;
frustum_ends
+=
base
;
frustum_ends
+=
base
;
// Second round
int
j
=
0
;
j
=
0
;
float
t0
=
near
;
t0
=
near
;
float
t1
=
t0
+
dt
;
t1
=
t0
+
dt
;
float
t_mid
=
(
t0
+
t1
)
/
2.
;
t_mid
=
(
t0
+
t1
)
/
2.
;
while
(
t_mid
<
far
&&
j
<
marching_samples
)
{
while
(
t_mid
<
far
)
{
// current center
// current center
const
float
x
=
ox
+
t_mid
*
dx
;
const
float
x
=
ox
+
t_mid
*
dx
;
const
float
y
=
oy
+
t_mid
*
dy
;
const
float
y
=
oy
+
t_mid
*
dy
;
...
@@ -182,43 +206,13 @@ __global__ void kernel_raymarching(
...
@@ -182,43 +206,13 @@ __global__ void kernel_raymarching(
t1
=
t_mid
+
dt
*
0.5
f
;
t1
=
t_mid
+
dt
*
0.5
f
;
}
}
}
}
if
(
j
!=
steps
)
{
packed_info
[
ray_idx
*
3
+
0
]
=
i
;
// ray idx in {rays_o, rays_d}
printf
(
"WTF %d v.s. %d
\n
"
,
j
,
steps
);
packed_info
[
ray_idx
*
3
+
1
]
=
base
;
// point idx start.
}
packed_info
[
ray_idx
*
3
+
2
]
=
j
;
// point idx shift (actual marching samples).
return
;
return
;
}
}
/**
* @brief Sample points by ray marching.
*
* @param rays_o Ray origins Shape of [n_rays, 3].
* @param rays_d Normalized ray directions. Shape of [n_rays, 3].
* @param t_min Near planes of rays. Shape of [n_rays].
* @param t_max Far planes of rays. Shape of [n_rays].
* @param grid_center Density grid center. TODO: support 3-dims.
* @param grid_scale Density grid base level scale. TODO: support 3-dims.
* @param grid_cascades Density grid levels.
* @param grid_size Density grid resolution.
* @param grid_bitfield Density grid uint8 bit field.
* @param marching_steps Marching steps during inference.
* @param max_total_samples Maximum total number of samples in this batch.
* @param max_ray_samples Used to define the minimal step size: SQRT3() / max_ray_samples.
* @param cone_angle 0. for nerf-synthetic and 1./256 for real scenes.
* @param step_scale Scale up the step size by this much. Usually equals to scene scale.
* @return std::vector<torch::Tensor>
* - packed_info: Stores how to index the ray samples from the returned values.
* Shape of [n_rays, 3]. First value is the ray index. Second value is the sample
* start index in the results for this ray. Third value is the number of samples for
* this ray. Note for rays that have zero samples, we simply skip them so the `packed_info`
* has some zero padding in the end.
* - origins: Ray origins for those samples. [max_total_samples, 3]
* - dirs: Ray directions for those samples. [max_total_samples, 3]
* - starts: Where the frustum-shape sample starts along a ray. [max_total_samples, 1]
* - ends: Where the frustum-shape sample ends along a ray. [max_total_samples, 1]
*/
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
std
::
vector
<
torch
::
Tensor
>
ray_marching
(
// rays
// rays
const
torch
::
Tensor
rays_o
,
const
torch
::
Tensor
rays_o
,
...
@@ -230,8 +224,6 @@ std::vector<torch::Tensor> ray_marching(
...
@@ -230,8 +224,6 @@ std::vector<torch::Tensor> ray_marching(
const
pybind11
::
list
resolution
,
const
pybind11
::
list
resolution
,
const
torch
::
Tensor
occ_binary
,
const
torch
::
Tensor
occ_binary
,
// sampling
// sampling
const
int
max_total_samples
,
const
int
max_per_ray_samples
,
const
float
dt
const
float
dt
)
{
)
{
DEVICE_GUARD
(
rays_o
);
DEVICE_GUARD
(
rays_o
);
...
@@ -249,20 +241,43 @@ std::vector<torch::Tensor> ray_marching(
...
@@ -249,20 +241,43 @@ std::vector<torch::Tensor> ray_marching(
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
const
int
blocks
=
CUDA_N_BLOCKS_NEEDED
(
n_rays
,
threads
);
// helper counter
// helper counter
torch
::
Tensor
steps_counter
=
torch
::
zeros
(
torch
::
Tensor
num_steps
=
torch
::
zeros
(
{
1
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
{
n_rays
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
torch
::
Tensor
rays_counter
=
torch
::
zeros
(
{
1
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
// count number of samples per ray
marching_steps_kernel
<<<
blocks
,
threads
>>>
(
// rays
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_d
.
data_ptr
<
float
>
(),
t_min
.
data_ptr
<
float
>
(),
t_max
.
data_ptr
<
float
>
(),
// density grid
aabb
.
data_ptr
<
float
>
(),
resolution
[
0
].
cast
<
int
>
(),
resolution
[
1
].
cast
<
int
>
(),
resolution
[
2
].
cast
<
int
>
(),
occ_binary
.
data_ptr
<
bool
>
(),
// sampling
dt
,
// writable helpers
num_steps
.
data_ptr
<
int
>
()
);
torch
::
Tensor
cum_steps
=
num_steps
.
cumsum
(
0
,
torch
::
kInt32
);
torch
::
Tensor
packed_info
=
torch
::
stack
({
cum_steps
-
num_steps
,
num_steps
},
1
);
// std::cout << "num_steps" << num_steps.dtype() << std::endl;
// std::cout << "cum_steps" << cum_steps.dtype() << std::endl;
// std::cout << "packed_info" << packed_info.dtype() << std::endl;
// output frustum samples
// output frustum samples
torch
::
Tensor
packed_info
=
torch
::
zeros
(
int
total_steps
=
cum_steps
[
cum_steps
.
size
(
0
)
-
1
].
item
<
int
>
();
{
n_rays
,
3
},
rays_o
.
options
().
dtype
(
torch
::
kInt32
));
// ray_id, sample_id, num_samples
torch
::
Tensor
frustum_origins
=
torch
::
zeros
({
total_steps
,
3
},
rays_o
.
options
());
torch
::
Tensor
frustum_origins
=
torch
::
zeros
({
max_total_samples
,
3
},
rays_o
.
options
());
torch
::
Tensor
frustum_dirs
=
torch
::
zeros
({
total_steps
,
3
},
rays_o
.
options
());
torch
::
Tensor
frustum_dirs
=
torch
::
zeros
({
max_total_samples
,
3
},
rays_o
.
options
());
torch
::
Tensor
frustum_starts
=
torch
::
zeros
({
total_steps
,
1
},
rays_o
.
options
());
torch
::
Tensor
frustum_starts
=
torch
::
zeros
({
max_total_samples
,
1
},
rays_o
.
options
());
torch
::
Tensor
frustum_ends
=
torch
::
zeros
({
total_steps
,
1
},
rays_o
.
options
());
torch
::
Tensor
frustum_ends
=
torch
::
zeros
({
max_total_samples
,
1
},
rays_o
.
options
());
kernel_raymarching
<<<
blocks
,
threads
>>>
(
marching_forward_kernel
<<<
blocks
,
threads
>>>
(
// rays
// rays
n_rays
,
n_rays
,
rays_o
.
data_ptr
<
float
>
(),
rays_o
.
data_ptr
<
float
>
(),
...
@@ -276,19 +291,15 @@ std::vector<torch::Tensor> ray_marching(
...
@@ -276,19 +291,15 @@ std::vector<torch::Tensor> ray_marching(
resolution
[
2
].
cast
<
int
>
(),
resolution
[
2
].
cast
<
int
>
(),
occ_binary
.
data_ptr
<
bool
>
(),
occ_binary
.
data_ptr
<
bool
>
(),
// sampling
// sampling
max_total_samples
,
max_per_ray_samples
,
dt
,
dt
,
// writable helpers
packed_info
.
data_ptr
<
int
>
(),
steps_counter
.
data_ptr
<
int
>
(),
// total samples.
// outputs
rays_counter
.
data_ptr
<
int
>
(),
// total rays.
packed_info
.
data_ptr
<
int
>
(),
frustum_origins
.
data_ptr
<
float
>
(),
frustum_origins
.
data_ptr
<
float
>
(),
frustum_dirs
.
data_ptr
<
float
>
(),
frustum_dirs
.
data_ptr
<
float
>
(),
frustum_starts
.
data_ptr
<
float
>
(),
frustum_starts
.
data_ptr
<
float
>
(),
frustum_ends
.
data_ptr
<
float
>
()
frustum_ends
.
data_ptr
<
float
>
()
);
);
return
{
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
,
steps_counter
};
return
{
packed_info
,
frustum_origins
,
frustum_dirs
,
frustum_starts
,
frustum_ends
};
}
}
nerfacc/cuda/csrc/vol_rendering.cu
View file @
c3ab153e
...
@@ -16,10 +16,9 @@ __global__ void volumetric_rendering_inference_kernel(
...
@@ -16,10 +16,9 @@ __global__ void volumetric_rendering_inference_kernel(
CUDA_GET_THREAD_ID
(
thread_id
,
n_rays
);
CUDA_GET_THREAD_ID
(
thread_id
,
n_rays
);
// locate
// locate
const
int
i
=
packed_info
[
thread_id
*
3
+
0
];
// ray idx in {rays_o, rays_d}
const
int
base
=
packed_info
[
thread_id
*
2
+
0
];
// point idx start.
const
int
base
=
packed_info
[
thread_id
*
3
+
1
];
// point idx start.
const
int
steps
=
packed_info
[
thread_id
*
2
+
1
];
// point idx shift.
const
int
numsteps
=
packed_info
[
thread_id
*
3
+
2
];
// point idx shift.
if
(
steps
==
0
)
return
;
if
(
numsteps
==
0
)
return
;
starts
+=
base
;
starts
+=
base
;
ends
+=
base
;
ends
+=
base
;
...
@@ -29,7 +28,7 @@ __global__ void volumetric_rendering_inference_kernel(
...
@@ -29,7 +28,7 @@ __global__ void volumetric_rendering_inference_kernel(
scalar_t
T
=
1.
f
;
scalar_t
T
=
1.
f
;
scalar_t
EPSILON
=
1e-4
f
;
scalar_t
EPSILON
=
1e-4
f
;
int
j
=
0
;
int
j
=
0
;
for
(;
j
<
num
steps
;
++
j
)
{
for
(;
j
<
steps
;
++
j
)
{
if
(
T
<
EPSILON
)
{
if
(
T
<
EPSILON
)
{
break
;
break
;
}
}
...
@@ -46,10 +45,8 @@ __global__ void volumetric_rendering_inference_kernel(
...
@@ -46,10 +45,8 @@ __global__ void volumetric_rendering_inference_kernel(
compact_selector
[
k
]
=
base
+
k
;
compact_selector
[
k
]
=
base
+
k
;
}
}
compact_packed_info
+=
thread_id
*
3
;
compact_packed_info
[
thread_id
*
2
+
0
]
=
compact_base
;
// compact point idx start.
compact_packed_info
[
0
]
=
i
;
// ray idx in {rays_o, rays_d}
compact_packed_info
[
thread_id
*
2
+
1
]
=
j
;
// compact point idx shift.
compact_packed_info
[
1
]
=
compact_base
;
// compact point idx start.
compact_packed_info
[
2
]
=
j
;
// compact point idx shift.
}
}
...
@@ -201,7 +198,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference(
...
@@ -201,7 +198,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference(
CHECK_INPUT
(
starts
);
CHECK_INPUT
(
starts
);
CHECK_INPUT
(
ends
);
CHECK_INPUT
(
ends
);
CHECK_INPUT
(
sigmas
);
CHECK_INPUT
(
sigmas
);
TORCH_CHECK
(
packed_info
.
ndimension
()
==
2
&
packed_info
.
size
(
1
)
==
3
);
TORCH_CHECK
(
packed_info
.
ndimension
()
==
2
&
packed_info
.
size
(
1
)
==
2
);
TORCH_CHECK
(
starts
.
ndimension
()
==
2
&
starts
.
size
(
1
)
==
1
);
TORCH_CHECK
(
starts
.
ndimension
()
==
2
&
starts
.
size
(
1
)
==
1
);
TORCH_CHECK
(
ends
.
ndimension
()
==
2
&
ends
.
size
(
1
)
==
1
);
TORCH_CHECK
(
ends
.
ndimension
()
==
2
&
ends
.
size
(
1
)
==
1
);
TORCH_CHECK
(
sigmas
.
ndimension
()
==
2
&
sigmas
.
size
(
1
)
==
1
);
TORCH_CHECK
(
sigmas
.
ndimension
()
==
2
&
sigmas
.
size
(
1
)
==
1
);
...
@@ -217,7 +214,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference(
...
@@ -217,7 +214,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference(
{
1
},
packed_info
.
options
().
dtype
(
torch
::
kInt32
));
{
1
},
packed_info
.
options
().
dtype
(
torch
::
kInt32
));
// outputs
// outputs
torch
::
Tensor
compact_packed_info
=
torch
::
zeros
({
n_rays
,
3
},
packed_info
.
options
());
torch
::
Tensor
compact_packed_info
=
torch
::
zeros
({
n_rays
,
2
},
packed_info
.
options
());
torch
::
Tensor
compact_selector
=
-
torch
::
ones
({
n_samples
},
packed_info
.
options
());
torch
::
Tensor
compact_selector
=
-
torch
::
ones
({
n_samples
},
packed_info
.
options
());
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
...
...
nerfacc/cuda/csrc/volumetric_weights.cu
View file @
c3ab153e
...
@@ -13,13 +13,12 @@ __global__ void volumetric_weights_forward_kernel(
...
@@ -13,13 +13,12 @@ __global__ void volumetric_weights_forward_kernel(
int
*
samples_ray_ids
,
// output
int
*
samples_ray_ids
,
// output
bool
*
mask
// output
bool
*
mask
// output
)
{
)
{
CUDA_GET_THREAD_ID
(
thread_id
,
n_rays
);
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
// locate
// locate
const
int
i
=
packed_info
[
thread_id
*
3
+
0
];
// ray idx in {rays_o, rays_d}
const
int
base
=
packed_info
[
i
*
2
+
0
];
// point idx start.
const
int
base
=
packed_info
[
thread_id
*
3
+
1
];
// point idx start.
const
int
steps
=
packed_info
[
i
*
2
+
1
];
// point idx shift.
const
int
numsteps
=
packed_info
[
thread_id
*
3
+
2
];
// point idx shift.
if
(
steps
==
0
)
return
;
if
(
numsteps
==
0
)
return
;
starts
+=
base
;
starts
+=
base
;
ends
+=
base
;
ends
+=
base
;
...
@@ -28,14 +27,14 @@ __global__ void volumetric_weights_forward_kernel(
...
@@ -28,14 +27,14 @@ __global__ void volumetric_weights_forward_kernel(
samples_ray_ids
+=
base
;
samples_ray_ids
+=
base
;
mask
+=
i
;
mask
+=
i
;
for
(
int
j
=
0
;
j
<
num
steps
;
++
j
)
{
for
(
int
j
=
0
;
j
<
steps
;
++
j
)
{
samples_ray_ids
[
j
]
=
i
;
samples_ray_ids
[
j
]
=
i
;
}
}
// accumulated rendering
// accumulated rendering
scalar_t
T
=
1.
f
;
scalar_t
T
=
1.
f
;
scalar_t
EPSILON
=
1e-4
f
;
scalar_t
EPSILON
=
1e-4
f
;
for
(
int
j
=
0
;
j
<
num
steps
;
++
j
)
{
for
(
int
j
=
0
;
j
<
steps
;
++
j
)
{
if
(
T
<
EPSILON
)
{
if
(
T
<
EPSILON
)
{
break
;
break
;
}
}
...
@@ -60,13 +59,12 @@ __global__ void volumetric_weights_backward_kernel(
...
@@ -60,13 +59,12 @@ __global__ void volumetric_weights_backward_kernel(
const
scalar_t
*
grad_weights
,
// input
const
scalar_t
*
grad_weights
,
// input
scalar_t
*
grad_sigmas
// output
scalar_t
*
grad_sigmas
// output
)
{
)
{
CUDA_GET_THREAD_ID
(
thread_id
,
n_rays
);
CUDA_GET_THREAD_ID
(
i
,
n_rays
);
// locate
// locate
// const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const
int
base
=
packed_info
[
i
*
2
+
0
];
// point idx start.
const
int
base
=
packed_info
[
thread_id
*
3
+
1
];
// point idx start.
const
int
steps
=
packed_info
[
i
*
2
+
1
];
// point idx shift.
const
int
numsteps
=
packed_info
[
thread_id
*
3
+
2
];
// point idx shift.
if
(
steps
==
0
)
return
;
if
(
numsteps
==
0
)
return
;
starts
+=
base
;
starts
+=
base
;
ends
+=
base
;
ends
+=
base
;
...
@@ -76,14 +74,14 @@ __global__ void volumetric_weights_backward_kernel(
...
@@ -76,14 +74,14 @@ __global__ void volumetric_weights_backward_kernel(
grad_sigmas
+=
base
;
grad_sigmas
+=
base
;
scalar_t
accum
=
0
;
scalar_t
accum
=
0
;
for
(
int
j
=
0
;
j
<
num
steps
;
++
j
)
{
for
(
int
j
=
0
;
j
<
steps
;
++
j
)
{
accum
+=
grad_weights
[
j
]
*
weights
[
j
];
accum
+=
grad_weights
[
j
]
*
weights
[
j
];
}
}
// backward of accumulated rendering
// backward of accumulated rendering
scalar_t
T
=
1.
f
;
scalar_t
T
=
1.
f
;
scalar_t
EPSILON
=
1e-4
f
;
scalar_t
EPSILON
=
1e-4
f
;
for
(
int
j
=
0
;
j
<
num
steps
;
++
j
)
{
for
(
int
j
=
0
;
j
<
steps
;
++
j
)
{
if
(
T
<
EPSILON
)
{
if
(
T
<
EPSILON
)
{
break
;
break
;
}
}
...
@@ -108,7 +106,7 @@ std::vector<torch::Tensor> volumetric_weights_forward(
...
@@ -108,7 +106,7 @@ std::vector<torch::Tensor> volumetric_weights_forward(
CHECK_INPUT
(
starts
);
CHECK_INPUT
(
starts
);
CHECK_INPUT
(
ends
);
CHECK_INPUT
(
ends
);
CHECK_INPUT
(
sigmas
);
CHECK_INPUT
(
sigmas
);
TORCH_CHECK
(
packed_info
.
ndimension
()
==
2
&
packed_info
.
size
(
1
)
==
3
);
TORCH_CHECK
(
packed_info
.
ndimension
()
==
2
&
packed_info
.
size
(
1
)
==
2
);
TORCH_CHECK
(
starts
.
ndimension
()
==
2
&
starts
.
size
(
1
)
==
1
);
TORCH_CHECK
(
starts
.
ndimension
()
==
2
&
starts
.
size
(
1
)
==
1
);
TORCH_CHECK
(
ends
.
ndimension
()
==
2
&
ends
.
size
(
1
)
==
1
);
TORCH_CHECK
(
ends
.
ndimension
()
==
2
&
ends
.
size
(
1
)
==
1
);
TORCH_CHECK
(
sigmas
.
ndimension
()
==
2
&
sigmas
.
size
(
1
)
==
1
);
TORCH_CHECK
(
sigmas
.
ndimension
()
==
2
&
sigmas
.
size
(
1
)
==
1
);
...
...
nerfacc/volumetric_rendering.py
View file @
c3ab153e
...
@@ -54,7 +54,6 @@ def volumetric_rendering(
...
@@ -54,7 +54,6 @@ def volumetric_rendering(
frustum_dirs
,
frustum_dirs
,
frustum_starts
,
frustum_starts
,
frustum_ends
,
frustum_ends
,
steps_counter
,
)
=
ray_marching
(
)
=
ray_marching
(
# rays
# rays
rays_o
,
rays_o
,
...
@@ -66,22 +65,13 @@ def volumetric_rendering(
...
@@ -66,22 +65,13 @@ def volumetric_rendering(
scene_resolution
,
scene_resolution
,
scene_occ_binary
,
scene_occ_binary
,
# sampling
# sampling
render_total_samples
,
render_n_samples
,
render_step_size
,
render_step_size
,
)
)
# squeeze valid samples
total_samples
=
max
(
packed_info
[:,
-
1
].
sum
(),
1
)
total_samples
=
int
(
math
.
ceil
(
total_samples
/
256.0
))
*
256
frustum_origins
=
frustum_origins
[:
total_samples
]
frustum_dirs
=
frustum_dirs
[:
total_samples
]
frustum_starts
=
frustum_starts
[:
total_samples
]
frustum_ends
=
frustum_ends
[:
total_samples
]
frustum_positions
=
(
frustum_positions
=
(
frustum_origins
+
frustum_dirs
*
(
frustum_starts
+
frustum_ends
)
/
2.0
frustum_origins
+
frustum_dirs
*
(
frustum_starts
+
frustum_ends
)
/
2.0
)
)
steps_counter
=
packed_info
[:,
-
1
].
sum
(
0
,
keepdim
=
True
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
densities
=
query_fn
(
densities
=
query_fn
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment