Commit 1aeee0a9 authored by Ruilong Li

cleanup marching; resampling steps limit

parent ad2a0079
@@ -265,7 +265,8 @@ std::vector<torch::Tensor> ray_resampling(
     const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
     torch::Tensor num_steps = torch::split(packed_info, 1, 1)[1];
-    torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
+    // torch::Tensor resample_num_steps = (num_steps > 0).to(num_steps.options()) * steps;
+    torch::Tensor resample_num_steps = torch::clamp(num_steps, 0, steps);
     torch::Tensor resample_cum_steps = resample_num_steps.cumsum(0, torch::kInt32);
     torch::Tensor resample_packed_info = torch::cat(
         {resample_cum_steps - resample_num_steps, resample_num_steps}, 1);
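The change above replaces the boolean-mask formula, which gave every non-empty ray exactly steps resampled intervals, with a clamp that treats steps as an upper limit, so rays that already have fewer samples keep their own count. A minimal PyTorch sketch of the two behaviors (example values hypothetical):

import torch

# Per-ray sample counts from a previous marching pass (hypothetical values).
num_steps = torch.tensor([0, 3, 70, 128], dtype=torch.int32)
steps = 64  # resampling budget per ray

# Old: every non-empty ray gets exactly `steps` resampled intervals.
old = (num_steps > 0).to(num_steps.dtype) * steps  # tensor([ 0, 64, 64, 64])

# New: `steps` caps the count; sparse rays keep what they have.
new = torch.clamp(num_steps, 0, steps)             # tensor([ 0,  3, 64, 64])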
@@ -14,8 +14,6 @@ std::vector<torch::Tensor> ray_aabb_intersect(
 std::vector<torch::Tensor> ray_marching(
     // rays
-    const torch::Tensor rays_o,
-    const torch::Tensor rays_d,
     const torch::Tensor t_min,
     const torch::Tensor t_max,
     // sampling
@@ -80,8 +80,6 @@ inline __device__ __host__ float advance_to_next_voxel(
 __global__ void ray_marching_kernel(
     // rays info
     const uint32_t n_rays,
-    const float *rays_o, // shape (n_rays, 3)
-    const float *rays_d, // shape (n_rays, 3)
     const float *t_min,  // shape (n_rays,)
     const float *t_max,  // shape (n_rays,)
     // sampling
@@ -100,8 +98,6 @@ __global__ void ray_marching_kernel(
     bool is_first_round = (packed_info == nullptr);
     // locate
-    rays_o += i * 3;
-    rays_d += i * 3;
     t_min += i;
     t_max += i;
@@ -118,9 +114,6 @@ __global__ void ray_marching_kernel(
         ray_indices += base;
     }

-    const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]);
-    const float3 dir = make_float3(rays_d[0], rays_d[1], rays_d[2]);
-    const float3 inv_dir = 1.0f / dir;
     const float near = t_min[0], far = t_max[0];
     float dt_min = step_size;
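With rays_o/rays_d gone, this kernel walks each ray purely in t-space between its precomputed near and far bounds. A rough single-ray Python sketch of that loop; the linear step growth under cone_angle is an assumption inferred from the dt_min/cone_angle parameters, not copied from the elided kernel body:

def march_ray_sketch(near, far, step_size, cone_angle=0.0):
    # Yield (t_start, t_end) intervals from near to far.
    # cone_angle == 0 gives uniform steps of step_size; a positive value
    # grows the step with distance (assumed: dt = max(step_size, t * cone_angle)).
    t = near
    while t < far:
        dt = max(step_size, t * cone_angle)
        yield t, min(t + dt, far)
        t += dt

intervals = list(march_ray_sketch(near=0.1, far=0.5, step_size=0.1))
# approximately [(0.1, 0.2), (0.2, 0.3), (0.3, 0.4), (0.4, 0.5)]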
@@ -270,39 +263,31 @@ __global__ void ray_marching_with_grid_kernel(
 std::vector<torch::Tensor> ray_marching(
     // rays
-    const torch::Tensor rays_o,
-    const torch::Tensor rays_d,
     const torch::Tensor t_min,
     const torch::Tensor t_max,
     // sampling
     const float step_size,
     const float cone_angle)
 {
-    DEVICE_GUARD(rays_o);
+    DEVICE_GUARD(t_min);
-    CHECK_INPUT(rays_o);
-    CHECK_INPUT(rays_d);
     CHECK_INPUT(t_min);
     CHECK_INPUT(t_max);
-    TORCH_CHECK(rays_o.ndimension() == 2 & rays_o.size(1) == 3)
-    TORCH_CHECK(rays_d.ndimension() == 2 & rays_d.size(1) == 3)
     TORCH_CHECK(t_min.ndimension() == 1)
     TORCH_CHECK(t_max.ndimension() == 1)
-    const int n_rays = rays_o.size(0);
+    const int n_rays = t_min.size(0);
     const int threads = 256;
     const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
     // helper counter
     torch::Tensor num_steps = torch::empty(
-        {n_rays}, rays_o.options().dtype(torch::kInt32));
+        {n_rays}, t_min.options().dtype(torch::kInt32));
     // count number of samples per ray
     ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
         // rays
         n_rays,
-        rays_o.data_ptr<float>(),
-        rays_d.data_ptr<float>(),
         t_min.data_ptr<float>(),
         t_max.data_ptr<float>(),
         // sampling
@@ -320,15 +305,13 @@ std::vector<torch::Tensor> ray_marching(
     // output samples starts and ends
     int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
-    torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options());
-    torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
+    torch::Tensor t_starts = torch::empty({total_steps, 1}, t_min.options());
+    torch::Tensor t_ends = torch::empty({total_steps, 1}, t_min.options());
     torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());
     ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
         // rays
         n_rays,
-        rays_o.data_ptr<float>(),
-        rays_d.data_ptr<float>(),
         t_min.data_ptr<float>(),
         t_max.data_ptr<float>(),
         // sampling
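Both launches follow a two-pass pattern: the first pass only counts samples per ray, a cumulative sum turns the counts into packed offsets, and the second pass fills flat buffers of total_steps entries. A PyTorch illustration of the packing arithmetic (counts hypothetical):

import torch

# Per-ray sample counts from the counting pass (hypothetical).
num_steps = torch.tensor([2, 0, 3], dtype=torch.int32)
cum_steps = num_steps.cumsum(0, dtype=torch.int32)  # tensor([2, 2, 5])
total_steps = int(cum_steps[-1])                    # 5 samples over all rays

# packed_info[i] = (start offset, count) of ray i in the flat sample buffers.
packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
# tensor([[0, 2], [2, 0], [2, 3]], dtype=torch.int32)

# The second pass writes t_starts/t_ends of shape (total_steps, 1);
# ray_indices maps every flat sample back to its ray.
ray_indices = torch.repeat_interleave(torch.arange(len(num_steps)), num_steps.long())
# tensor([0, 0, 2, 2, 2])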
@@ -231,8 +231,6 @@ def ray_marching(
     # marching
     packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
         # rays
-        rays_o.contiguous(),
-        rays_d.contiguous(),
         t_min.contiguous(),
         t_max.contiguous(),
         # sampling
@@ -499,6 +499,11 @@ def render_visibility(
         tensor([True, True, False, True, False, False, True])
     """
+    assert (
+        alphas.dim() == 2 and alphas.shape[-1] == 1
+    ), "alphas should be a 2D tensor with shape (n_samples, 1)."
+    visibility = alphas >= alpha_thre
+
     if early_stop_eps > 0:
         assert (
             ray_indices is not None or packed_info is not None
         ), "Either ray_indices or packed_info should be provided."
@@ -512,9 +517,7 @@ def render_visibility(
         else:
             transmittance = _RenderingTransmittanceFromAlphaNaive.apply(
                 packed_info, alphas
             )
-        visibility = transmittance >= early_stop_eps
-        if alpha_thre > 0:
-            visibility = visibility & (alphas >= alpha_thre)
+        visibility = visibility & (transmittance >= early_stop_eps)
     visibility = visibility.squeeze(-1)
     return visibility
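The render_visibility change computes the alpha-threshold mask unconditionally and then intersects it with the transmittance early-stop mask, rather than gating the alpha test behind alpha_thre > 0. A simplified single-ray sketch of the combined logic; the cumulative product below is a stand-in assumption for the _RenderingTransmittanceFromAlphaNaive autograd path, which operates on packed multi-ray input:

import torch

def render_visibility_sketch(alphas, early_stop_eps=1e-4, alpha_thre=0.0):
    assert alphas.dim() == 2 and alphas.shape[-1] == 1
    visibility = alphas >= alpha_thre
    if early_stop_eps > 0:
        # Transmittance before sample i: T_i = prod_{j < i} (1 - alpha_j).
        transmittance = torch.cumprod(
            torch.cat([torch.ones_like(alphas[:1]), 1.0 - alphas[:-1]]), dim=0
        )
        visibility = visibility & (transmittance >= early_stop_eps)
    return visibility.squeeze(-1)

alphas = torch.tensor([[0.9], [0.8], [0.5]])
print(render_visibility_sketch(alphas, early_stop_eps=0.05, alpha_thre=0.6))
# tensor([ True,  True, False]) -- the last sample fails both tests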