Commit c3ab153e authored by Ruilong Li's avatar Ruilong Li
Browse files

cumsum for marching

parent 86b90ea6
...@@ -12,9 +12,9 @@ python examples/trainval.py ...@@ -12,9 +12,9 @@ python examples/trainval.py
| trainval (35k, 1<<16) | Lego | Mic | Materials | | trainval (35k, 1<<16) | Lego | Mic | Materials |
| - | - | - | - | | - | - | - | - |
| Time | 377s | 357s | 354s | | Time | 325s | 357s | 354s |
| PSNR | 36.08 | 36.58 | 29.63 | | PSNR | 36.20 | 36.55 | 29.63 |
| FPS | 12.56 | 25.54 | - | | FPS | 12.56 | 25.54 | - |
Tested with the default settings on the Lego test set. Tested with the default settings on the Lego test set.
......
...@@ -90,14 +90,14 @@ if __name__ == "__main__": ...@@ -90,14 +90,14 @@ if __name__ == "__main__":
torch.manual_seed(42) torch.manual_seed(42)
device = "cuda:0" device = "cuda:0"
scene = "lego" scene = "materials"
# setup dataset # setup dataset
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=scene, subject_id=scene,
root_fp="/home/ruilongli/data/nerf_synthetic/", root_fp="/home/ruilongli/data/nerf_synthetic/",
split="trainval", split="trainval",
num_rays=4096, num_rays=1024,
) )
train_dataset.images = train_dataset.images.to(device) train_dataset.images = train_dataset.images.to(device)
......
...@@ -8,21 +8,21 @@ std::vector<torch::Tensor> ray_aabb_intersect( ...@@ -8,21 +8,21 @@ std::vector<torch::Tensor> ray_aabb_intersect(
); );
std::vector<torch::Tensor> ray_marching( // std::vector<torch::Tensor> ray_marching(
// rays // // rays
const torch::Tensor rays_o, // const torch::Tensor rays_o,
const torch::Tensor rays_d, // const torch::Tensor rays_d,
const torch::Tensor t_min, // const torch::Tensor t_min,
const torch::Tensor t_max, // const torch::Tensor t_max,
// density grid // // density grid
const torch::Tensor aabb, // const torch::Tensor aabb,
const pybind11::list resolution, // const pybind11::list resolution,
const torch::Tensor occ_binary, // const torch::Tensor occ_binary,
// sampling // // sampling
const int max_total_samples, // const int max_total_samples,
const int max_per_ray_samples, // const int max_per_ray_samples,
const float dt // const float dt
); // );
std::vector<torch::Tensor> volumetric_rendering_inference( std::vector<torch::Tensor> volumetric_rendering_inference(
torch::Tensor packed_info, torch::Tensor packed_info,
...@@ -69,6 +69,19 @@ torch::Tensor volumetric_weights_backward( ...@@ -69,6 +69,19 @@ torch::Tensor volumetric_weights_backward(
torch::Tensor sigmas torch::Tensor sigmas
); );
// Sample frustum points along rays by marching through a binary occupancy
// grid with a fixed step size. Returns a vector of tensors — per the
// definition's return statement: {packed_info, frustum_origins, frustum_dirs,
// frustum_starts, frustum_ends}, where packed_info[i] = {start index, number
// of samples} for ray i into the flattened sample tensors.
std::vector<torch::Tensor> ray_marching(
// rays: origins and (normalized) directions, shape (n_rays, 3);
// t_min / t_max are the per-ray near / far planes, shape (n_rays,).
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// density grid: aabb is [min_x, min_y, min_z, max_x, max_y, max_z];
// resolution is a 3-element list (x, y, z); occ_binary is a bool tensor
// of shape (res_x, res_y, res_z) marking occupied cells.
const torch::Tensor aabb,
const pybind11::list resolution,
const torch::Tensor occ_binary,
// sampling: dt is the constant marching step size along each ray.
const float dt
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{ {
......
...@@ -61,7 +61,8 @@ inline __device__ float advance_to_next_voxel( ...@@ -61,7 +61,8 @@ inline __device__ float advance_to_next_voxel(
} }
__global__ void kernel_raymarching(
__global__ void marching_steps_kernel(
// rays info // rays info
const uint32_t n_rays, const uint32_t n_rays,
const float* rays_o, // shape (n_rays, 3) const float* rays_o, // shape (n_rays, 3)
...@@ -75,18 +76,9 @@ __global__ void kernel_raymarching( ...@@ -75,18 +76,9 @@ __global__ void kernel_raymarching(
const int resz, const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z) const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// sampling // sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt, const float dt,
// writable helpers // outputs
int* steps_counter, int* num_steps
int* rays_counter,
// frustrum outputs
int* packed_info,
float* frustum_origins,
float* frustum_dirs,
float* frustum_starts,
float* frustum_ends
) { ) {
CUDA_GET_THREAD_ID(i, n_rays); CUDA_GET_THREAD_ID(i, n_rays);
...@@ -95,23 +87,19 @@ __global__ void kernel_raymarching( ...@@ -95,23 +87,19 @@ __global__ void kernel_raymarching(
rays_d += i * 3; rays_d += i * 3;
t_min += i; t_min += i;
t_max += i; t_max += i;
num_steps += i;
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0]; const float near = t_min[0], far = t_max[0];
uint32_t ray_idx, base, marching_samples; int j = 0;
uint32_t j; float t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
float t0, t1, t_mid; float t1 = t0 + dt;
float t_mid = (t0 + t1) * 0.5f;
// first pass to compute an accurate number of steps
j = 0;
t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
while (t_mid < far && j < max_per_ray_samples) { while (t_mid < far) {
// current center // current center
const float x = ox + t_mid * dx; const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy; const float y = oy + t_mid * dy;
...@@ -135,10 +123,47 @@ __global__ void kernel_raymarching( ...@@ -135,10 +123,47 @@ __global__ void kernel_raymarching(
} }
if (j == 0) return; if (j == 0) return;
marching_samples = j; num_steps[0] = j;
base = atomicAdd(steps_counter, marching_samples); return;
if (base + marching_samples > max_total_samples) return; }
ray_idx = atomicAdd(rays_counter, 1);
__global__ void marching_forward_kernel(
// rays info
const uint32_t n_rays,
const float* rays_o, // shape (n_rays, 3)
const float* rays_d, // shape (n_rays, 3)
const float* t_min, // shape (n_rays,)
const float* t_max, // shape (n_rays,)
// density grid
const float* aabb, // [min_x, min_y, min_z, max_x, max_y, max_y]
const int resx,
const int resy,
const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// sampling
const float dt,
const int* packed_info,
// frustrum outputs
float* frustum_origins,
float* frustum_dirs,
float* frustum_starts,
float* frustum_ends
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
int base = packed_info[i * 2 + 0];
int steps = packed_info[i * 2 + 1];
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0];
// locate // locate
frustum_origins += base * 3; frustum_origins += base * 3;
...@@ -146,13 +171,12 @@ __global__ void kernel_raymarching( ...@@ -146,13 +171,12 @@ __global__ void kernel_raymarching(
frustum_starts += base; frustum_starts += base;
frustum_ends += base; frustum_ends += base;
// Second round int j = 0;
j = 0; float t0 = near;
t0 = near; float t1 = t0 + dt;
t1 = t0 + dt; float t_mid = (t0 + t1) / 2.;
t_mid = (t0 + t1) / 2.;
while (t_mid < far && j < marching_samples) { while (t_mid < far) {
// current center // current center
const float x = ox + t_mid * dx; const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy; const float y = oy + t_mid * dy;
...@@ -182,43 +206,13 @@ __global__ void kernel_raymarching( ...@@ -182,43 +206,13 @@ __global__ void kernel_raymarching(
t1 = t_mid + dt * 0.5f; t1 = t_mid + dt * 0.5f;
} }
} }
if (j != steps) {
packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d} printf("WTF %d v.s. %d\n", j, steps);
packed_info[ray_idx * 3 + 1] = base; // point idx start. }
packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
return; return;
} }
/**
* @brief Sample points by ray marching.
*
* @param rays_o Ray origins Shape of [n_rays, 3].
* @param rays_d Normalized ray directions. Shape of [n_rays, 3].
* @param t_min Near planes of rays. Shape of [n_rays].
* @param t_max Far planes of rays. Shape of [n_rays].
* @param grid_center Density grid center. TODO: support 3-dims.
* @param grid_scale Density grid base level scale. TODO: support 3-dims.
* @param grid_cascades Density grid levels.
* @param grid_size Density grid resolution.
* @param grid_bitfield Density grid uint8 bit field.
* @param marching_steps Marching steps during inference.
* @param max_total_samples Maximum total number of samples in this batch.
* @param max_ray_samples Used to define the minimal step size: SQRT3() / max_ray_samples.
* @param cone_angle 0. for nerf-synthetic and 1./256 for real scenes.
* @param step_scale Scale up the step size by this much. Usually equals to scene scale.
* @return std::vector<torch::Tensor>
* - packed_info: Stores how to index the ray samples from the returned values.
* Shape of [n_rays, 3]. First value is the ray index. Second value is the sample
* start index in the results for this ray. Third value is the number of samples for
* this ray. Note for rays that have zero samples, we simply skip them so the `packed_info`
* has some zero padding in the end.
* - origins: Ray origins for those samples. [max_total_samples, 3]
* - dirs: Ray directions for those samples. [max_total_samples, 3]
* - starts: Where the frustum-shape sample starts along a ray. [max_total_samples, 1]
* - ends: Where the frustum-shape sample ends along a ray. [max_total_samples, 1]
*/
std::vector<torch::Tensor> ray_marching( std::vector<torch::Tensor> ray_marching(
// rays // rays
const torch::Tensor rays_o, const torch::Tensor rays_o,
...@@ -230,8 +224,6 @@ std::vector<torch::Tensor> ray_marching( ...@@ -230,8 +224,6 @@ std::vector<torch::Tensor> ray_marching(
const pybind11::list resolution, const pybind11::list resolution,
const torch::Tensor occ_binary, const torch::Tensor occ_binary,
// sampling // sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt const float dt
) { ) {
DEVICE_GUARD(rays_o); DEVICE_GUARD(rays_o);
...@@ -249,20 +241,43 @@ std::vector<torch::Tensor> ray_marching( ...@@ -249,20 +241,43 @@ std::vector<torch::Tensor> ray_marching(
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter // helper counter
torch::Tensor steps_counter = torch::zeros( torch::Tensor num_steps = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32)); {n_rays}, rays_o.options().dtype(torch::kInt32));
torch::Tensor rays_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32)); // count number of samples per ray
marching_steps_kernel<<<blocks, threads>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution[0].cast<int>(),
resolution[1].cast<int>(),
resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(),
// sampling
dt,
// writable helpers
num_steps.data_ptr<int>()
);
torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
// std::cout << "num_steps" << num_steps.dtype() << std::endl;
// std::cout << "cum_steps" << cum_steps.dtype() << std::endl;
// std::cout << "packed_info" << packed_info.dtype() << std::endl;
// output frustum samples // output frustum samples
torch::Tensor packed_info = torch::zeros( int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
{n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples torch::Tensor frustum_origins = torch::zeros({total_steps, 3}, rays_o.options());
torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options()); torch::Tensor frustum_dirs = torch::zeros({total_steps, 3}, rays_o.options());
torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options()); torch::Tensor frustum_starts = torch::zeros({total_steps, 1}, rays_o.options());
torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options()); torch::Tensor frustum_ends = torch::zeros({total_steps, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
kernel_raymarching<<<blocks, threads>>>( marching_forward_kernel<<<blocks, threads>>>(
// rays // rays
n_rays, n_rays,
rays_o.data_ptr<float>(), rays_o.data_ptr<float>(),
...@@ -276,19 +291,15 @@ std::vector<torch::Tensor> ray_marching( ...@@ -276,19 +291,15 @@ std::vector<torch::Tensor> ray_marching(
resolution[2].cast<int>(), resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(), occ_binary.data_ptr<bool>(),
// sampling // sampling
max_total_samples,
max_per_ray_samples,
dt, dt,
// writable helpers packed_info.data_ptr<int>(),
steps_counter.data_ptr<int>(), // total samples. // outputs
rays_counter.data_ptr<int>(), // total rays.
packed_info.data_ptr<int>(),
frustum_origins.data_ptr<float>(), frustum_origins.data_ptr<float>(),
frustum_dirs.data_ptr<float>(), frustum_dirs.data_ptr<float>(),
frustum_starts.data_ptr<float>(), frustum_starts.data_ptr<float>(),
frustum_ends.data_ptr<float>() frustum_ends.data_ptr<float>()
); );
return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter}; return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends};
} }
...@@ -16,10 +16,9 @@ __global__ void volumetric_rendering_inference_kernel( ...@@ -16,10 +16,9 @@ __global__ void volumetric_rendering_inference_kernel(
CUDA_GET_THREAD_ID(thread_id, n_rays); CUDA_GET_THREAD_ID(thread_id, n_rays);
// locate // locate
const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d} const int base = packed_info[thread_id * 2 + 0]; // point idx start.
const int base = packed_info[thread_id * 3 + 1]; // point idx start. const int steps = packed_info[thread_id * 2 + 1]; // point idx shift.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift. if (steps == 0) return;
if (numsteps == 0) return;
starts += base; starts += base;
ends += base; ends += base;
...@@ -29,7 +28,7 @@ __global__ void volumetric_rendering_inference_kernel( ...@@ -29,7 +28,7 @@ __global__ void volumetric_rendering_inference_kernel(
scalar_t T = 1.f; scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f; scalar_t EPSILON = 1e-4f;
int j = 0; int j = 0;
for (; j < numsteps; ++j) { for (; j < steps; ++j) {
if (T < EPSILON) { if (T < EPSILON) {
break; break;
} }
...@@ -46,10 +45,8 @@ __global__ void volumetric_rendering_inference_kernel( ...@@ -46,10 +45,8 @@ __global__ void volumetric_rendering_inference_kernel(
compact_selector[k] = base + k; compact_selector[k] = base + k;
} }
compact_packed_info += thread_id * 3; compact_packed_info[thread_id * 2 + 0] = compact_base; // compact point idx start.
compact_packed_info[0] = i; // ray idx in {rays_o, rays_d} compact_packed_info[thread_id * 2 + 1] = j; // compact point idx shift.
compact_packed_info[1] = compact_base; // compact point idx start.
compact_packed_info[2] = j; // compact point idx shift.
} }
...@@ -201,7 +198,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference( ...@@ -201,7 +198,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference(
CHECK_INPUT(starts); CHECK_INPUT(starts);
CHECK_INPUT(ends); CHECK_INPUT(ends);
CHECK_INPUT(sigmas); CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 3); TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1); TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
...@@ -217,7 +214,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference( ...@@ -217,7 +214,7 @@ std::vector<torch::Tensor> volumetric_rendering_inference(
{1}, packed_info.options().dtype(torch::kInt32)); {1}, packed_info.options().dtype(torch::kInt32));
// outputs // outputs
torch::Tensor compact_packed_info = torch::zeros({n_rays, 3}, packed_info.options()); torch::Tensor compact_packed_info = torch::zeros({n_rays, 2}, packed_info.options());
torch::Tensor compact_selector = - torch::ones({n_samples}, packed_info.options()); torch::Tensor compact_selector = - torch::ones({n_samples}, packed_info.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF( AT_DISPATCH_FLOATING_TYPES_AND_HALF(
......
...@@ -13,13 +13,12 @@ __global__ void volumetric_weights_forward_kernel( ...@@ -13,13 +13,12 @@ __global__ void volumetric_weights_forward_kernel(
int* samples_ray_ids, // output int* samples_ray_ids, // output
bool* mask // output bool* mask // output
) { ) {
CUDA_GET_THREAD_ID(thread_id, n_rays); CUDA_GET_THREAD_ID(i, n_rays);
// locate // locate
const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d} const int base = packed_info[i * 2 + 0]; // point idx start.
const int base = packed_info[thread_id * 3 + 1]; // point idx start. const int steps = packed_info[i * 2 + 1]; // point idx shift.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift. if (steps == 0) return;
if (numsteps == 0) return;
starts += base; starts += base;
ends += base; ends += base;
...@@ -28,14 +27,14 @@ __global__ void volumetric_weights_forward_kernel( ...@@ -28,14 +27,14 @@ __global__ void volumetric_weights_forward_kernel(
samples_ray_ids += base; samples_ray_ids += base;
mask += i; mask += i;
for (int j = 0; j < numsteps; ++j) { for (int j = 0; j < steps; ++j) {
samples_ray_ids[j] = i; samples_ray_ids[j] = i;
} }
// accumulated rendering // accumulated rendering
scalar_t T = 1.f; scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f; scalar_t EPSILON = 1e-4f;
for (int j = 0; j < numsteps; ++j) { for (int j = 0; j < steps; ++j) {
if (T < EPSILON) { if (T < EPSILON) {
break; break;
} }
...@@ -60,13 +59,12 @@ __global__ void volumetric_weights_backward_kernel( ...@@ -60,13 +59,12 @@ __global__ void volumetric_weights_backward_kernel(
const scalar_t* grad_weights, // input const scalar_t* grad_weights, // input
scalar_t* grad_sigmas // output scalar_t* grad_sigmas // output
) { ) {
CUDA_GET_THREAD_ID(thread_id, n_rays); CUDA_GET_THREAD_ID(i, n_rays);
// locate // locate
// const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d} const int base = packed_info[i * 2 + 0]; // point idx start.
const int base = packed_info[thread_id * 3 + 1]; // point idx start. const int steps = packed_info[i * 2 + 1]; // point idx shift.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift. if (steps == 0) return;
if (numsteps == 0) return;
starts += base; starts += base;
ends += base; ends += base;
...@@ -76,14 +74,14 @@ __global__ void volumetric_weights_backward_kernel( ...@@ -76,14 +74,14 @@ __global__ void volumetric_weights_backward_kernel(
grad_sigmas += base; grad_sigmas += base;
scalar_t accum = 0; scalar_t accum = 0;
for (int j = 0; j < numsteps; ++j) { for (int j = 0; j < steps; ++j) {
accum += grad_weights[j] * weights[j]; accum += grad_weights[j] * weights[j];
} }
// backward of accumulated rendering // backward of accumulated rendering
scalar_t T = 1.f; scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f; scalar_t EPSILON = 1e-4f;
for (int j = 0; j < numsteps; ++j) { for (int j = 0; j < steps; ++j) {
if (T < EPSILON) { if (T < EPSILON) {
break; break;
} }
...@@ -108,7 +106,7 @@ std::vector<torch::Tensor> volumetric_weights_forward( ...@@ -108,7 +106,7 @@ std::vector<torch::Tensor> volumetric_weights_forward(
CHECK_INPUT(starts); CHECK_INPUT(starts);
CHECK_INPUT(ends); CHECK_INPUT(ends);
CHECK_INPUT(sigmas); CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 3); TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 2);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1); TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1); TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1); TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
......
...@@ -54,7 +54,6 @@ def volumetric_rendering( ...@@ -54,7 +54,6 @@ def volumetric_rendering(
frustum_dirs, frustum_dirs,
frustum_starts, frustum_starts,
frustum_ends, frustum_ends,
steps_counter,
) = ray_marching( ) = ray_marching(
# rays # rays
rays_o, rays_o,
...@@ -66,22 +65,13 @@ def volumetric_rendering( ...@@ -66,22 +65,13 @@ def volumetric_rendering(
scene_resolution, scene_resolution,
scene_occ_binary, scene_occ_binary,
# sampling # sampling
render_total_samples,
render_n_samples,
render_step_size, render_step_size,
) )
# squeeze valid samples
total_samples = max(packed_info[:, -1].sum(), 1)
total_samples = int(math.ceil(total_samples / 256.0)) * 256
frustum_origins = frustum_origins[:total_samples]
frustum_dirs = frustum_dirs[:total_samples]
frustum_starts = frustum_starts[:total_samples]
frustum_ends = frustum_ends[:total_samples]
frustum_positions = ( frustum_positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0 frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
) )
steps_counter = packed_info[:, -1].sum(0, keepdim=True)
with torch.no_grad(): with torch.no_grad():
densities = query_fn( densities = query_fn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment