Commit b9ce1c5b authored by Ruilong Li's avatar Ruilong Li
Browse files

cuda version of occ query

parent b6a5694e
......@@ -21,3 +21,4 @@ volumetric_rendering_weights_backward = _make_lazy_cuda(
"volumetric_rendering_weights_backward"
)
unpack_to_ray_indices = _make_lazy_cuda("unpack_to_ray_indices")
query_occ = _make_lazy_cuda("query_occ")
......@@ -44,6 +44,14 @@ std::vector<torch::Tensor> volumetric_marching(
const float dt
);
torch::Tensor query_occ(
const torch::Tensor samples,
// density grid
const torch::Tensor aabb,
const pybind11::list resolution,
const torch::Tensor occ_binary
);
torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
......@@ -54,4 +62,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("volumetric_rendering_weights_forward", &volumetric_rendering_weights_forward);
m.def("volumetric_rendering_weights_backward", &volumetric_rendering_weights_backward);
m.def("unpack_to_ray_indices", &unpack_to_ray_indices);
m.def("query_occ", &query_occ);
}
\ No newline at end of file
......@@ -24,6 +24,9 @@ inline __device__ bool grid_occupied_at(
const int resx, const int resy, const int resz,
const float* aabb, const bool* occ_binary
) {
if (x <= aabb[0] || x >= aabb[3] || y <= aabb[1] || y >= aabb[4] || z <= aabb[2] || z >= aabb[5]) {
return false;
}
int idx = cascaded_grid_idx_at(x, y, z, resx, resy, resz, aabb);
return occ_binary[idx];
}
......@@ -223,6 +226,33 @@ __global__ void ray_indices_kernel(
}
__global__ void occ_query_kernel(
// rays info
const uint32_t n_samples,
const float* samples, // shape (n_samples, 3)
// density grid
const float* aabb, // [min_x, min_y, min_z, max_x, max_y, max_y]
const int resx,
const int resy,
const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// outputs
bool* occs
) {
CUDA_GET_THREAD_ID(i, n_samples);
// locate
samples += i * 3;
occs += i;
occs[0] = grid_occupied_at(
samples[0], samples[1], samples[2],
resx, resy, resz, aabb, occ_binary
);
return;
}
std::vector<torch::Tensor> volumetric_marching(
// rays
const torch::Tensor rays_o,
......@@ -331,3 +361,35 @@ torch::Tensor unpack_to_ray_indices(const torch::Tensor packed_info) {
}
torch::Tensor query_occ(
const torch::Tensor samples,
// density grid
const torch::Tensor aabb,
const pybind11::list resolution,
const torch::Tensor occ_binary
) {
DEVICE_GUARD(samples);
CHECK_INPUT(samples);
const int n_samples = samples.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_samples, threads);
torch::Tensor occs = torch::zeros(
{n_samples}, samples.options().dtype(torch::kBool));
occ_query_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_samples,
samples.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution[0].cast<int>(),
resolution[1].cast<int>(),
resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(),
// outputs
occs.data_ptr<bool>()
);
return occs;
}
......@@ -56,10 +56,11 @@ def volumetric_rendering_pipeline(
if scene_occ_binary is None:
scene_occ_binary = torch.ones(
(1, 1, 1),
(1),
dtype=torch.bool,
device=rays_o.device,
)
scene_resolution = [1, 1, 1]
if scene_resolution is None:
assert scene_occ_binary is not None and scene_occ_binary.dim() == 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment