Commit 99562b18 authored by Ruilong Li

compute weight

parent b524287a
......@@ -8,6 +8,8 @@ ray_marching = _C.ray_marching
volumetric_rendering_forward = _C.volumetric_rendering_forward
volumetric_rendering_backward = _C.volumetric_rendering_backward
volumetric_rendering_inference = _C.volumetric_rendering_inference
compute_weights_forward = _C.compute_weights_forward
compute_weights_backward = _C.compute_weights_backward
class VolumeRenderer(torch.autograd.Function):
......@@ -71,3 +73,44 @@ class VolumeRenderer(torch.autograd.Function):
        )
        # corresponds to the input argument list of forward()
        return None, None, None, grad_sigmas, grad_rgbs


class ComputeWeight(torch.autograd.Function):
    """CUDA Compute Weight"""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, packed_info, starts, ends, sigmas):
        (
            weights,
            ray_indices,
            mask,
        ) = compute_weights_forward(packed_info, starts, ends, sigmas)
        ctx.save_for_backward(
            packed_info,
            starts,
            ends,
            sigmas,
            weights,
        )
        return weights, ray_indices, mask

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_weights, _grad_ray_indices, _grad_mask):
        (
            packed_info,
            starts,
            ends,
            sigmas,
            weights,
        ) = ctx.saved_tensors
        grad_sigmas = compute_weights_backward(
            weights,
            grad_weights,
            packed_info,
            starts,
            ends,
            sigmas,
        )
        # corresponds to the input argument list of forward()
        return None, None, None, grad_sigmas
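
A minimal usage sketch for the new op (hypothetical tensor names; shapes follow the TORCH_CHECKs in compute_weights_forward below, and a CUDA device is assumed):

# packed_info: (n_rays, 3) int32 -- per ray: ray index, first-sample offset, sample count
# starts / ends / sigmas: (n_samples, 1) float -- per-sample segment bounds and densities
weights, ray_indices, mask = ComputeWeight.apply(
    packed_info.contiguous(),
    starts.contiguous(),
    ends.contiguous(),
    sigmas.contiguous(),
)
# weights: (n_samples,) rendering weights, differentiable w.r.t. sigmas
# ray_indices: (n_samples,) int32, maps each sample back to its ray
# mask: (n_rays,) bool, True for rays that produced at least one sample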
......@@ -53,6 +53,22 @@ std::vector<torch::Tensor> volumetric_rendering_backward(
    torch::Tensor rgbs
);
std::vector<torch::Tensor> compute_weights_forward(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor sigmas
);

torch::Tensor compute_weights_backward(
    torch::Tensor weights,
    torch::Tensor grad_weights,
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor sigmas
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
......@@ -61,4 +77,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("volumetric_rendering_inference", &volumetric_rendering_inference);
m.def("volumetric_rendering_forward", &volumetric_rendering_forward);
m.def("volumetric_rendering_backward", &volumetric_rendering_backward);
m.def("compute_weights_forward", &compute_weights_forward);
m.def("compute_weights_backward", &compute_weights_backward);
}
\ No newline at end of file
#include "include/helpers_cuda.h"

template <typename scalar_t>
__global__ void compute_weights_forward_kernel(
    const uint32_t n_rays,
    const int* packed_info,   // input ray & point indices.
    const scalar_t* starts,   // input start t
    const scalar_t* ends,     // input end t
    const scalar_t* sigmas,   // input density after activation
    // should be all-zero initialized
    scalar_t* weights,        // output
    int* samples_ray_ids,     // output
    bool* mask                // output
) {
    CUDA_GET_THREAD_ID(thread_id, n_rays);

    // locate
    const int i = packed_info[thread_id * 3 + 0];        // ray idx in {rays_o, rays_d}
    const int base = packed_info[thread_id * 3 + 1];     // point idx start.
    const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
    if (numsteps == 0) return;

    starts += base;
    ends += base;
    sigmas += base;
    weights += base;
    samples_ray_ids += base;
    mask += i;

    for (int j = 0; j < numsteps; ++j) {
        samples_ray_ids[j] = i;
    }

    // accumulated rendering
    scalar_t T = 1.f;
    const scalar_t EPSILON = 1e-4f;
    for (int j = 0; j < numsteps; ++j) {
        if (T < EPSILON) {
            break;
        }
        const scalar_t delta = ends[j] - starts[j];
        const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
        const scalar_t weight = alpha * T;
        weights[j] = weight;
        T *= (1.f - alpha);
    }
    mask[0] = true;
}
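
// Exposition (added note): with delta_j = ends[j] - starts[j] and
// alpha_j = 1 - exp(-sigma_j * delta_j), the loop above computes the standard
// volume-rendering weight
//     w_j = alpha_j * T_j,    T_j = prod_{k < j} (1 - alpha_k),
// where T is the transmittance remaining after the first j samples. Once T
// falls below EPSILON the loop breaks, so the weights of the remaining
// samples stay at zero -- which is why the outputs must be all-zero
// initialized.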

template <typename scalar_t>
__global__ void compute_weights_backward_kernel(
    const uint32_t n_rays,
    const int* packed_info,       // input ray & point indices.
    const scalar_t* starts,       // input start t
    const scalar_t* ends,         // input end t
    const scalar_t* sigmas,       // input density after activation
    const scalar_t* weights,      // forward output
    const scalar_t* grad_weights, // input
    scalar_t* grad_sigmas         // output
) {
    CUDA_GET_THREAD_ID(thread_id, n_rays);

    // locate
    // const int i = packed_info[thread_id * 3 + 0];     // ray idx in {rays_o, rays_d}
    const int base = packed_info[thread_id * 3 + 1];     // point idx start.
    const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
    if (numsteps == 0) return;

    starts += base;
    ends += base;
    sigmas += base;
    weights += base;
    grad_weights += base;
    grad_sigmas += base;

    scalar_t accum = 0;
    for (int j = 0; j < numsteps; ++j) {
        accum += grad_weights[j] * weights[j];
    }

    // backward of accumulated rendering
    scalar_t T = 1.f;
    const scalar_t EPSILON = 1e-4f;
    for (int j = 0; j < numsteps; ++j) {
        if (T < EPSILON) {
            break;
        }
        const scalar_t delta = ends[j] - starts[j];
        const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
        grad_sigmas[j] = delta * (grad_weights[j] * T - accum);
        accum -= grad_weights[j] * weights[j];
        T *= (1.f - alpha);
    }
}
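
// Exposition (added note): the gradient above follows from
//     w_k = alpha_k * prod_{m < k} (1 - alpha_m):
//     d w_j / d sigma_j =  delta_j * (T_j - w_j)    and
//     d w_k / d sigma_j = -delta_j * w_k            for k > j,
// so  dL/dsigma_j = delta_j * (grad_w_j * T_j - sum_{k >= j} grad_w_k * w_k).
// `accum` holds exactly that suffix sum: it starts as the full dot product of
// grad_weights and weights, and grad_w_j * w_j is peeled off after each step.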
template <typename scalar_t>
__global__ void volumetric_rendering_inference_kernel(
    const uint32_t n_rays,
......@@ -368,4 +464,90 @@ std::vector<torch::Tensor> volumetric_rendering_backward(
    }));
    return {grad_sigmas, grad_rgbs};
}
\ No newline at end of file
}

std::vector<torch::Tensor> compute_weights_forward(
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor sigmas
) {
    DEVICE_GUARD(packed_info);

    CHECK_INPUT(packed_info);
    CHECK_INPUT(starts);
    CHECK_INPUT(ends);
    CHECK_INPUT(sigmas);
    TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 3);
    TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
    TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
    TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);

    const uint32_t n_rays = packed_info.size(0);
    const uint32_t n_samples = sigmas.size(0);

    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

    // outputs
    torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
    torch::Tensor ray_indices = torch::zeros({n_samples}, packed_info.options());
    // The rays that are not skipped during sampling.
    torch::Tensor mask = torch::zeros({n_rays}, sigmas.options().dtype(torch::kBool));

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        sigmas.scalar_type(),
        "compute_weights_forward",
        ([&]
         { compute_weights_forward_kernel<scalar_t><<<blocks, threads>>>(
               n_rays,
               packed_info.data_ptr<int>(),
               starts.data_ptr<scalar_t>(),
               ends.data_ptr<scalar_t>(),
               sigmas.data_ptr<scalar_t>(),
               weights.data_ptr<scalar_t>(),
               ray_indices.data_ptr<int>(),
               mask.data_ptr<bool>()
           );
         }));

    return {weights, ray_indices, mask};
}

torch::Tensor compute_weights_backward(
    torch::Tensor weights,
    torch::Tensor grad_weights,
    torch::Tensor packed_info,
    torch::Tensor starts,
    torch::Tensor ends,
    torch::Tensor sigmas
) {
    DEVICE_GUARD(packed_info);

    const uint32_t n_rays = packed_info.size(0);
    const uint32_t n_samples = sigmas.size(0);

    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

    // outputs (zero-initialized, so samples skipped by the kernel's early
    // termination, and rays with zero samples, keep a zero gradient)
    torch::Tensor grad_sigmas = torch::zeros(sigmas.sizes(), sigmas.options());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        sigmas.scalar_type(),
        "compute_weights_backward",
        ([&]
         { compute_weights_backward_kernel<scalar_t><<<blocks, threads>>>(
               n_rays,
               packed_info.data_ptr<int>(),
               starts.data_ptr<scalar_t>(),
               ends.data_ptr<scalar_t>(),
               sigmas.data_ptr<scalar_t>(),
               weights.data_ptr<scalar_t>(),
               grad_weights.data_ptr<scalar_t>(),
               grad_sigmas.data_ptr<scalar_t>()
           );
         }));

    return grad_sigmas;
}
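
// Note (added): unlike compute_weights_forward, the backward wrapper performs
// no CHECK_INPUT / TORCH_CHECK validation; presumably it relies on the tensors
// arriving straight from ctx.save_for_backward, already validated and made
// contiguous by the forward pass.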
......@@ -4,6 +4,7 @@ from typing import Callable, Tuple
import torch
from .cuda import (
    ComputeWeight,
    VolumeRenderer,
    ray_aabb_intersect,
    ray_marching,
......@@ -114,18 +115,47 @@ def volumetric_rendering(
    )
    compact_rgbs, compact_densities = compact_query_results[0], compact_query_results[1]
    (
        accumulated_weight,
        accumulated_depth,
        accumulated_color,
        alive_ray_mask,
        compact_steps_counter,
    ) = VolumeRenderer.apply(
    # (
    #     accumulated_weight,
    #     accumulated_depth,
    #     accumulated_color,
    #     alive_ray_mask,
    #     compact_steps_counter,
    # ) = VolumeRenderer.apply(
    #     compact_packed_info.contiguous(),
    #     compact_frustum_starts.contiguous(),
    #     compact_frustum_ends.contiguous(),
    #     compact_densities.contiguous(),
    #     compact_rgbs.contiguous(),
    # )
    # ComputeWeight.forward() takes only (packed_info, starts, ends, sigmas);
    # compact_rgbs is consumed by the scatter_add_ composition below instead.
    compact_weights, compact_ray_indices, alive_ray_mask = ComputeWeight.apply(
        compact_packed_info.contiguous(),
        compact_frustum_starts.contiguous(),
        compact_frustum_ends.contiguous(),
        compact_densities.contiguous(),
    )
    index = compact_ray_indices[:, None].long()
    accumulated_color = torch.zeros((n_rays, 3), device=device)
    accumulated_color.scatter_add_(
        dim=0,
        index=index.expand(-1, 3),
        src=compact_weights[:, None] * compact_rgbs,
    )
    accumulated_weight = torch.zeros((n_rays, 1), device=device)
    accumulated_weight.scatter_add_(
        dim=0,
        index=index.expand(-1, 1),
        src=compact_weights[:, None],
    )
    accumulated_depth = torch.zeros((n_rays, 1), device=device)
    accumulated_depth.scatter_add_(
        dim=0,
        index=index.expand(-1, 1),
        src=compact_weights[:, None]
        * (compact_frustum_starts + compact_frustum_ends)
        / 2.0,
    )
    # query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
......
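
For cross-checking the custom autograd op end to end, a pure-PyTorch mirror of the forward kernel can be compared against ComputeWeight.apply, for both values and gradients. A minimal sketch, not part of this commit; reference_weights is a hypothetical helper, and a CUDA device is assumed for the custom op:

import torch

def reference_weights(packed_info, starts, ends, sigmas):
    # Pure-PyTorch mirror of compute_weights_forward_kernel (testing only).
    # Returns per-sample weights with shape (n_samples,), like the CUDA op.
    weights = sigmas.new_zeros(sigmas.shape[0])
    for _ray_idx, base, numsteps in packed_info.tolist():
        T = 1.0
        for j in range(base, base + numsteps):
            if T < 1e-4:  # same early-termination threshold as the kernel
                break
            alpha = 1.0 - torch.exp(-sigmas[j, 0] * (ends[j, 0] - starts[j, 0]))
            weights[j] = alpha * T
            T = T * (1.0 - alpha)
    return weights

# Usage: with sigmas.requires_grad_(True), both
#     w_cuda, _, _ = ComputeWeight.apply(packed_info, starts, ends, sigmas)
#     w_ref = reference_weights(packed_info, starts, ends, sigmas)
# should agree, and so should the gradients of (w * g).sum() for a random g.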