Commit 1aeee0a9 authored by Ruilong Li's avatar Ruilong Li
Browse files

cleanup marching; resampling steps limit

parent ad2a0079
...@@ -265,7 +265,8 @@ std::vector<torch::Tensor> ray_resampling( ...@@ -265,7 +265,8 @@ std::vector<torch::Tensor> ray_resampling(
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
torch::Tensor num_steps = torch::split(packed_info, 1, 1)[1]; torch::Tensor num_steps = torch::split(packed_info, 1, 1)[1];
torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps; // torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
torch::Tensor resample_num_steps = torch::clamp(num_steps, 0, steps);
torch::Tensor resample_cum_steps = resample_num_steps.cumsum(0, torch::kInt32); torch::Tensor resample_cum_steps = resample_num_steps.cumsum(0, torch::kInt32);
torch::Tensor resample_packed_info = torch::cat( torch::Tensor resample_packed_info = torch::cat(
{resample_cum_steps - resample_num_steps, resample_num_steps}, 1); {resample_cum_steps - resample_num_steps, resample_num_steps}, 1);
......
...@@ -14,8 +14,6 @@ std::vector<torch::Tensor> ray_aabb_intersect( ...@@ -14,8 +14,6 @@ std::vector<torch::Tensor> ray_aabb_intersect(
std::vector<torch::Tensor> ray_marching( std::vector<torch::Tensor> ray_marching(
// rays // rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min, const torch::Tensor t_min,
const torch::Tensor t_max, const torch::Tensor t_max,
// sampling // sampling
......
...@@ -80,8 +80,6 @@ inline __device__ __host__ float advance_to_next_voxel( ...@@ -80,8 +80,6 @@ inline __device__ __host__ float advance_to_next_voxel(
__global__ void ray_marching_kernel( __global__ void ray_marching_kernel(
// rays info // rays info
const uint32_t n_rays, const uint32_t n_rays,
const float *rays_o, // shape (n_rays, 3)
const float *rays_d, // shape (n_rays, 3)
const float *t_min, // shape (n_rays,) const float *t_min, // shape (n_rays,)
const float *t_max, // shape (n_rays,) const float *t_max, // shape (n_rays,)
// sampling // sampling
...@@ -100,8 +98,6 @@ __global__ void ray_marching_kernel( ...@@ -100,8 +98,6 @@ __global__ void ray_marching_kernel(
bool is_first_round = (packed_info == nullptr); bool is_first_round = (packed_info == nullptr);
// locate // locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i; t_min += i;
t_max += i; t_max += i;
...@@ -118,9 +114,6 @@ __global__ void ray_marching_kernel( ...@@ -118,9 +114,6 @@ __global__ void ray_marching_kernel(
ray_indices += base; ray_indices += base;
} }
const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]);
const float3 dir = make_float3(rays_d[0], rays_d[1], rays_d[2]);
const float3 inv_dir = 1.0f / dir;
const float near = t_min[0], far = t_max[0]; const float near = t_min[0], far = t_max[0];
float dt_min = step_size; float dt_min = step_size;
...@@ -270,39 +263,31 @@ __global__ void ray_marching_with_grid_kernel( ...@@ -270,39 +263,31 @@ __global__ void ray_marching_with_grid_kernel(
std::vector<torch::Tensor> ray_marching( std::vector<torch::Tensor> ray_marching(
// rays // rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min, const torch::Tensor t_min,
const torch::Tensor t_max, const torch::Tensor t_max,
// sampling // sampling
const float step_size, const float step_size,
const float cone_angle) const float cone_angle)
{ {
DEVICE_GUARD(rays_o); DEVICE_GUARD(t_min);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min); CHECK_INPUT(t_min);
CHECK_INPUT(t_max); CHECK_INPUT(t_max);
TORCH_CHECK(rays_o.ndimension() == 2 & rays_o.size(1) == 3)
TORCH_CHECK(rays_d.ndimension() == 2 & rays_d.size(1) == 3)
TORCH_CHECK(t_min.ndimension() == 1) TORCH_CHECK(t_min.ndimension() == 1)
TORCH_CHECK(t_max.ndimension() == 1) TORCH_CHECK(t_max.ndimension() == 1)
const int n_rays = rays_o.size(0); const int n_rays = t_min.size(0);
const int threads = 256; const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter // helper counter
torch::Tensor num_steps = torch::empty( torch::Tensor num_steps = torch::empty(
{n_rays}, rays_o.options().dtype(torch::kInt32)); {n_rays}, t_min.options().dtype(torch::kInt32));
// count number of samples per ray // count number of samples per ray
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays // rays
n_rays, n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(), t_min.data_ptr<float>(),
t_max.data_ptr<float>(), t_max.data_ptr<float>(),
// sampling // sampling
...@@ -320,15 +305,13 @@ std::vector<torch::Tensor> ray_marching( ...@@ -320,15 +305,13 @@ std::vector<torch::Tensor> ray_marching(
// output samples starts and ends // output samples starts and ends
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>(); int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options()); torch::Tensor t_starts = torch::empty({total_steps, 1}, t_min.options());
torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options()); torch::Tensor t_ends = torch::empty({total_steps, 1}, t_min.options());
torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options()); torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays // rays
n_rays, n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(), t_min.data_ptr<float>(),
t_max.data_ptr<float>(), t_max.data_ptr<float>(),
// sampling // sampling
......
...@@ -231,8 +231,6 @@ def ray_marching( ...@@ -231,8 +231,6 @@ def ray_marching(
# marching # marching
packed_info, ray_indices, t_starts, t_ends = _C.ray_marching( packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
# rays # rays
rays_o.contiguous(),
rays_d.contiguous(),
t_min.contiguous(), t_min.contiguous(),
t_max.contiguous(), t_max.contiguous(),
# sampling # sampling
......
...@@ -499,6 +499,11 @@ def render_visibility( ...@@ -499,6 +499,11 @@ def render_visibility(
tensor([True, True, False, True, False, False, True]) tensor([True, True, False, True, False, False, True])
""" """
assert (
alphas.dim() == 2 and alphas.shape[-1] == 1
), "alphas should be a 2D tensor with shape (n_samples, 1)."
visibility = alphas >= alpha_thre
if early_stop_eps > 0:
assert ( assert (
ray_indices is not None or packed_info is not None ray_indices is not None or packed_info is not None
), "Either ray_indices or packed_info should be provided." ), "Either ray_indices or packed_info should be provided."
...@@ -512,9 +517,7 @@ def render_visibility( ...@@ -512,9 +517,7 @@ def render_visibility(
transmittance = _RenderingTransmittanceFromAlphaNaive.apply( transmittance = _RenderingTransmittanceFromAlphaNaive.apply(
packed_info, alphas packed_info, alphas
) )
visibility = transmittance >= early_stop_eps visibility = visibility & (transmittance >= early_stop_eps)
if alpha_thre > 0:
visibility = visibility & (alphas >= alpha_thre)
visibility = visibility.squeeze(-1) visibility = visibility.squeeze(-1)
return visibility return visibility
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment