Commit c3ab153e authored by Ruilong Li's avatar Ruilong Li
Browse files

cumsum for marching

parent 86b90ea6
......@@ -12,9 +12,9 @@ python examples/trainval.py
| trainval (35k, 1<<16) | Lego | Mic | Materials |
| - | - | - | - |
| Time | 377s | 357s | 354s |
| PSNR | 36.08 | 36.58 | 29.63 |
| Time | 325s | 357s | 354s |
| PSNR | 36.20 | 36.55 | 29.63 |
| FPS | 12.56 | 25.54 |
Tested with the default settings on the Lego test set.
......
......@@ -90,14 +90,14 @@ if __name__ == "__main__":
torch.manual_seed(42)
device = "cuda:0"
scene = "lego"
scene = "materials"
# setup dataset
train_dataset = SubjectLoader(
subject_id=scene,
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="trainval",
num_rays=4096,
num_rays=1024,
)
train_dataset.images = train_dataset.images.to(device)
......
......@@ -8,21 +8,21 @@ std::vector<torch::Tensor> ray_aabb_intersect(
);
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// density grid
const torch::Tensor aabb,
const pybind11::list resolution,
const torch::Tensor occ_binary,
// sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt
);
// std::vector<torch::Tensor> ray_marching(
// // rays
// const torch::Tensor rays_o,
// const torch::Tensor rays_d,
// const torch::Tensor t_min,
// const torch::Tensor t_max,
// // density grid
// const torch::Tensor aabb,
// const pybind11::list resolution,
// const torch::Tensor occ_binary,
// // sampling
// const int max_total_samples,
// const int max_per_ray_samples,
// const float dt
// );
std::vector<torch::Tensor> volumetric_rendering_inference(
torch::Tensor packed_info,
......@@ -69,6 +69,19 @@ torch::Tensor volumetric_weights_backward(
torch::Tensor sigmas
);
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// density grid
const torch::Tensor aabb,
const pybind11::list resolution,
const torch::Tensor occ_binary,
// sampling
const float dt
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
......
......@@ -61,7 +61,8 @@ inline __device__ float advance_to_next_voxel(
}
__global__ void kernel_raymarching(
__global__ void marching_steps_kernel(
// rays info
const uint32_t n_rays,
const float* rays_o, // shape (n_rays, 3)
......@@ -75,18 +76,9 @@ __global__ void kernel_raymarching(
const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt,
// writable helpers
int* steps_counter,
int* rays_counter,
// frustrum outputs
int* packed_info,
float* frustum_origins,
float* frustum_dirs,
float* frustum_starts,
float* frustum_ends
// outputs
int* num_steps
) {
CUDA_GET_THREAD_ID(i, n_rays);
......@@ -95,23 +87,19 @@ __global__ void kernel_raymarching(
rays_d += i * 3;
t_min += i;
t_max += i;
num_steps += i;
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0];
uint32_t ray_idx, base, marching_samples;
uint32_t j;
float t0, t1, t_mid;
// first pass to compute an accurate number of steps
j = 0;
t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
int j = 0;
float t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
while (t_mid < far && j < max_per_ray_samples) {
while (t_mid < far) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
......@@ -135,10 +123,47 @@ __global__ void kernel_raymarching(
}
if (j == 0) return;
marching_samples = j;
base = atomicAdd(steps_counter, marching_samples);
if (base + marching_samples > max_total_samples) return;
ray_idx = atomicAdd(rays_counter, 1);
num_steps[0] = j;
return;
}
__global__ void marching_forward_kernel(
// rays info
const uint32_t n_rays,
const float* rays_o, // shape (n_rays, 3)
const float* rays_d, // shape (n_rays, 3)
const float* t_min, // shape (n_rays,)
const float* t_max, // shape (n_rays,)
// density grid
const float* aabb, // [min_x, min_y, min_z, max_x, max_y, max_y]
const int resx,
const int resy,
const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// sampling
const float dt,
const int* packed_info,
// frustrum outputs
float* frustum_origins,
float* frustum_dirs,
float* frustum_starts,
float* frustum_ends
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
int base = packed_info[i * 2 + 0];
int steps = packed_info[i * 2 + 1];
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0];
// locate
frustum_origins += base * 3;
......@@ -146,13 +171,12 @@ __global__ void kernel_raymarching(
frustum_starts += base;
frustum_ends += base;
// Second round
j = 0;
t0 = near;
t1 = t0 + dt;
t_mid = (t0 + t1) / 2.;
int j = 0;
float t0 = near;
float t1 = t0 + dt;
float t_mid = (t0 + t1) / 2.;
while (t_mid < far && j < marching_samples) {
while (t_mid < far) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
......@@ -182,43 +206,13 @@ __global__ void kernel_raymarching(
t1 = t_mid + dt * 0.5f;
}
}
packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
packed_info[ray_idx * 3 + 1] = base; // point idx start.
packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
if (j != steps) {
printf("WTF %d v.s. %d\n", j, steps);
}
return;
}
/**
* @brief Sample points by ray marching.
*
* @param rays_o Ray origins Shape of [n_rays, 3].
* @param rays_d Normalized ray directions. Shape of [n_rays, 3].
* @param t_min Near planes of rays. Shape of [n_rays].
* @param t_max Far planes of rays. Shape of [n_rays].
* @param grid_center Density grid center. TODO: support 3-dims.
* @param grid_scale Density grid base level scale. TODO: support 3-dims.
* @param grid_cascades Density grid levels.
* @param grid_size Density grid resolution.
* @param grid_bitfield Density grid uint8 bit field.
* @param marching_steps Marching steps during inference.
* @param max_total_samples Maximum total number of samples in this batch.
* @param max_ray_samples Used to define the minimal step size: SQRT3() / max_ray_samples.
* @param cone_angle 0. for nerf-synthetic and 1./256 for real scenes.
* @param step_scale Scale up the step size by this much. Usually equals to scene scale.
* @return std::vector<torch::Tensor>
* - packed_info: Stores how to index the ray samples from the returned values.
* Shape of [n_rays, 3]. First value is the ray index. Second value is the sample
* start index in the results for this ray. Third value is the number of samples for
* this ray. Note for rays that have zero samples, we simply skip them so the `packed_info`
* has some zero padding in the end.
* - origins: Ray origins for those samples. [max_total_samples, 3]
* - dirs: Ray directions for those samples. [max_total_samples, 3]
* - starts: Where the frustum-shape sample starts along a ray. [max_total_samples, 1]
* - ends: Where the frustum-shape sample ends along a ray. [max_total_samples, 1]
*/
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
......@@ -230,8 +224,6 @@ std::vector<torch::Tensor> ray_marching(
const pybind11::list resolution,
const torch::Tensor occ_binary,
// sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt
) {
DEVICE_GUARD(rays_o);
......@@ -249,20 +241,43 @@ std::vector<torch::Tensor> ray_marching(
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor steps_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32));
torch::Tensor rays_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32));
torch::Tensor num_steps = torch::zeros(
{n_rays}, rays_o.options().dtype(torch::kInt32));
// count number of samples per ray
marching_steps_kernel<<<blocks, threads>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution[0].cast<int>(),
resolution[1].cast<int>(),
resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(),
// sampling
dt,
// writable helpers
num_steps.data_ptr<int>()
);
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
// std::cout << "num_steps" << num_steps.dtype() << std::endl;
// std::cout << "cum_steps" << cum_steps.dtype() << std::endl;
// std::cout << "packed_info" << packed_info.dtype() << std::endl;
// output frustum samples
torch::Tensor packed_info = torch::zeros(
{n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor frustum_origins = torch::zeros({total_steps, 3}, rays_o.options());
torch::Tensor frustum_dirs = torch::zeros({total_steps, 3}, rays_o.options());
torch::Tensor frustum_starts = torch::zeros({total_steps, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({total_steps, 1}, rays_o.options());
kernel_raymarching<<<blocks, threads>>>(
marching_forward_kernel<<<blocks, threads>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
......@@ -276,19 +291,15 @@ std::vector<torch::Tensor> ray_marching(
resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(),
// sampling
max_total_samples,
max_per_ray_samples,
dt,
// writable helpers
steps_counter.data_ptr<int>(), // total samples.
rays_counter.data_ptr<int>(), // total rays.
packed_info.data_ptr<int>(),
packed_info.data_ptr<int>(),
// outputs
frustum_origins.data_ptr<float>(),
frustum_dirs.data_ptr<float>(),
frustum_starts.data_ptr<float>(),
frustum_ends.data_ptr<float>()
);
return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends};
}
......@@ -16,10 +16,9 @@ __global__ void volumetric_rendering_inference_kernel(
CUDA_GET_THREAD_ID(thread_id, n_rays);
// locate
const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
const int base = packed_info[thread_id * 2 + 0]; // point idx start.
const int steps = packed_info[thread_id * 2 + 1]; // point idx shift.
if (steps == 0) return;
starts += base;
ends += base;
......@@ -29,7 +28,7 @@ __global__ void volumetric_rendering_inference_kernel(
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
int j = 0;
for (; j < numsteps; ++j) {
for (; j < steps; ++j) {
if (T < EPSILON) {
break;
}
......@@ -46,10 +45,8 @@ __global__ void volumetric_rendering_inference_kernel(
compact_selector[k] = base + k;
}
compact_packed_info += thread_id * 3;
compact_packed_info[0] = i; // ray idx in {rays_o, rays_d}
compact_packed_info[1] = compact_base; // compact point idx start.
compact_packed_info[2] = j; // compact point idx shift.
compact_packed_info[thread_id * 2 + 0] = compact_base; // compact point idx start.
compact_packed_info[thread_id * 2 + 1] = j; // compact point idx shift.
}
......@@ -201,7 +198,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference(
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 3);
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
......@@ -217,7 +214,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference(
{1}, packed_info.options().dtype(torch::kInt32));
// outputs
torch::Tensor compact_packed_info = torch::zeros({n_rays, 3}, packed_info.options());
torch::Tensor compact_packed_info = torch::zeros({n_rays, 2}, packed_info.options());
torch::Tensor compact_selector = - torch::ones({n_samples}, packed_info.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
......
......@@ -13,13 +13,12 @@ __global__ void volumetric_weights_forward_kernel(
int* samples_ray_ids, // output
bool* mask // output
) {
CUDA_GET_THREAD_ID(thread_id, n_rays);
CUDA_GET_THREAD_ID(i, n_rays);
// locate
const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0) return;
starts += base;
ends += base;
......@@ -28,14 +27,14 @@ __global__ void volumetric_weights_forward_kernel(
samples_ray_ids += base;
mask += i;
for (int j = 0; j < numsteps; ++j) {
for (int j = 0; j < steps; ++j) {
samples_ray_ids[j] = i;
}
// accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
for (int j = 0; j < numsteps; ++j) {
for (int j = 0; j < steps; ++j) {
if (T < EPSILON) {
break;
}
......@@ -60,13 +59,12 @@ __global__ void volumetric_weights_backward_kernel(
const scalar_t* grad_weights, // input
scalar_t* grad_sigmas // output
) {
CUDA_GET_THREAD_ID(thread_id, n_rays);
CUDA_GET_THREAD_ID(i, n_rays);
// locate
// const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
const int base = packed_info[i * 2 + 0]; // point idx start.
const int steps = packed_info[i * 2 + 1]; // point idx shift.
if (steps == 0) return;
starts += base;
ends += base;
......@@ -76,14 +74,14 @@ __global__ void volumetric_weights_backward_kernel(
grad_sigmas += base;
scalar_t accum = 0;
for (int j = 0; j < numsteps; ++j) {
for (int j = 0; j < steps; ++j) {
accum += grad_weights[j] * weights[j];
}
// backward of accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
for (int j = 0; j < numsteps; ++j) {
for (int j = 0; j < steps; ++j) {
if (T < EPSILON) {
break;
}
......@@ -108,7 +106,7 @@ std::vector<torch::Tensor> volumetric_weights_forward(
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 3);
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
......
......@@ -54,7 +54,6 @@ def volumetric_rendering(
frustum_dirs,
frustum_starts,
frustum_ends,
steps_counter,
) = ray_marching(
# rays
rays_o,
......@@ -66,22 +65,13 @@ def volumetric_rendering(
scene_resolution,
scene_occ_binary,
# sampling
render_total_samples,
render_n_samples,
render_step_size,
)
# squeeze valid samples
total_samples = max(packed_info[:, -1].sum(), 1)
total_samples = int(math.ceil(total_samples / 256.0)) * 256
frustum_origins = frustum_origins[:total_samples]
frustum_dirs = frustum_dirs[:total_samples]
frustum_starts = frustum_starts[:total_samples]
frustum_ends = frustum_ends[:total_samples]
frustum_positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
)
steps_counter = packed_info[:, -1].sum(0, keepdim=True)
with torch.no_grad():
densities = query_fn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment