Commit 86b90ea6 authored by Ruilong Li

cleanup

parent 99562b18
@@ -3,13 +3,13 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from ._backend import _C
ray_aabb_intersect = _C.ray_aabb_intersect
# ray_aabb_intersect = _C.ray_aabb_intersect
ray_marching = _C.ray_marching
volumetric_rendering_forward = _C.volumetric_rendering_forward
volumetric_rendering_backward = _C.volumetric_rendering_backward
volumetric_rendering_inference = _C.volumetric_rendering_inference
compute_weights_forward = _C.compute_weights_forward
compute_weights_backward = _C.compute_weights_backward
# volumetric_weights_forward = _C.volumetric_weights_forward
# volumetric_weights_forward = _C.volumetric_weights_forward
class VolumeRenderer(torch.autograd.Function):
@@ -73,44 +73,3 @@ class VolumeRenderer(torch.autograd.Function):
)
# corresponds to the input argument list of forward()
return None, None, None, grad_sigmas, grad_rgbs
class ComputeWeight(torch.autograd.Function):
"""CUDA Compute Weight"""
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, packed_info, starts, ends, sigmas):
(
weights,
ray_indices,
mask,
) = compute_weights_forward(packed_info, starts, ends, sigmas)
ctx.save_for_backward(
packed_info,
starts,
ends,
sigmas,
weights,
)
return weights, ray_indices, mask
@staticmethod
@custom_bwd
def backward(ctx, grad_weights, _grad_ray_indices, _grad_mask):
(
packed_info,
starts,
ends,
sigmas,
weights,
) = ctx.saved_tensors
grad_sigmas = compute_weights_backward(
weights,
grad_weights,
packed_info,
starts,
ends,
sigmas,
)
return None, None, None, grad_sigmas
@@ -53,14 +53,14 @@ std::vector<torch::Tensor> volumetric_rendering_backward(
torch::Tensor rgbs
);
std::vector<torch::Tensor> compute_weights_forward(
std::vector<torch::Tensor> volumetric_weights_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas
);
torch::Tensor compute_weights_backward(
torch::Tensor volumetric_weights_backward(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
@@ -77,6 +77,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("volumetric_rendering_inference", &volumetric_rendering_inference);
m.def("volumetric_rendering_forward", &volumetric_rendering_forward);
m.def("volumetric_rendering_backward", &volumetric_rendering_backward);
m.def("compute_weights_forward", &compute_weights_forward);
m.def("compute_weights_backward", &compute_weights_backward);
m.def("volumetric_weights_forward", &volumetric_weights_forward);
m.def("volumetric_weights_backward", &volumetric_weights_backward);
}
\ No newline at end of file
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void compute_weights_forward_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
// should be all-zero initialized
scalar_t* weights, // output
int* samples_ray_ids, // output
bool* mask // output
) {
CUDA_GET_THREAD_ID(thread_id, n_rays);
// locate
const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
starts += base;
ends += base;
sigmas += base;
weights += base;
samples_ray_ids += base;
mask += i;
for (int j = 0; j < numsteps; ++j) {
samples_ray_ids[j] = i;
}
// accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
for (int j = 0; j < numsteps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
const scalar_t weight = alpha * T;
weights[j] = weight;
T *= (1.f - alpha);
}
mask[0] = true;
}
template <typename scalar_t>
__global__ void compute_weights_backward_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
const scalar_t* weights, // forward output
const scalar_t* grad_weights, // input
scalar_t* grad_sigmas // output
) {
CUDA_GET_THREAD_ID(thread_id, n_rays);
// locate
// const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
starts += base;
ends += base;
sigmas += base;
weights += base;
grad_weights += base;
grad_sigmas += base;
scalar_t accum = 0;
for (int j = 0; j < numsteps; ++j) {
accum += grad_weights[j] * weights[j];
}
// backward of accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
for (int j = 0; j < numsteps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
grad_sigmas[j] = delta * (grad_weights[j] * T - accum);
accum -= grad_weights[j] * weights[j];
T *= (1.f - alpha);
}
}
template <typename scalar_t>
__global__ void volumetric_rendering_inference_kernel(
const uint32_t n_rays,
@@ -465,89 +369,3 @@ std::vector<torch::Tensor> volumetric_rendering_backward(
return {grad_sigmas, grad_rgbs};
}
\ No newline at end of file
std::vector<torch::Tensor> compute_weights_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas
) {
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 & packed_info.size(1) == 3);
TORCH_CHECK(starts.ndimension() == 2 & starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 & ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 & sigmas.size(1) == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
torch::Tensor ray_indices = torch::zeros({n_samples}, packed_info.options());
// The rays that are not skipped during sampling.
torch::Tensor mask = torch::zeros({n_rays}, sigmas.options().dtype(torch::kBool));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"compute_weights_forward",
([&]
{ compute_weights_forward_kernel<scalar_t><<<blocks, threads>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
ray_indices.data_ptr<int>(),
mask.data_ptr<bool>()
);
}));
return {weights, ray_indices, mask};
}
torch::Tensor compute_weights_backward(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas
) {
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_sigmas = torch::zeros(sigmas.sizes(), sigmas.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"compute_weights_backward",
([&]
{ compute_weights_backward_kernel<scalar_t><<<blocks, threads>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
grad_sigmas.data_ptr<scalar_t>()
);
}));
return grad_sigmas;
}
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void volumetric_weights_forward_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
// should be all-zero initialized
scalar_t* weights, // output
int* samples_ray_ids, // output
bool* mask // output
) {
CUDA_GET_THREAD_ID(thread_id, n_rays);
// locate
const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
starts += base;
ends += base;
sigmas += base;
weights += base;
samples_ray_ids += base;
mask += i;
for (int j = 0; j < numsteps; ++j) {
samples_ray_ids[j] = i;
}
// accumulated rendering
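// For sample j along a ray:
//   alpha_j = 1 - exp(-sigma_j * (t_end_j - t_start_j))
//   T_j     = prod_{k < j} (1 - alpha_k)   // transmittance up to sample j
//   w_j     = alpha_j * T_j                // volumetric rendering weight
// Marching stops early once T falls below EPSILON; the remaining weights
// keep their zero initialization.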
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
for (int j = 0; j < numsteps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
const scalar_t weight = alpha * T;
weights[j] = weight;
T *= (1.f - alpha);
}
mask[0] = true;
}
template <typename scalar_t>
__global__ void volumetric_weights_backward_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
const scalar_t* weights, // forward output
const scalar_t* grad_weights, // input
scalar_t* grad_sigmas // output
) {
CUDA_GET_THREAD_ID(thread_id, n_rays);
// locate
// const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
starts += base;
ends += base;
sigmas += base;
weights += base;
grad_weights += base;
grad_sigmas += base;
scalar_t accum = 0;
for (int j = 0; j < numsteps; ++j) {
accum += grad_weights[j] * weights[j];
}
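// Gradient sketch for the loop below: with w_j = alpha_j * T_j,
//   d w_j / d sigma_j = delta_j * (T_j - w_j)
//   d w_k / d sigma_j = -delta_j * w_k   for k > j
// so, with accum_j = sum_{k >= j} grad_w_k * w_k,
//   grad_sigma_j = delta_j * (grad_w_j * T_j - accum_j).
// `accum` starts as the full sum and is shrunk as j advances.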
// backward of accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
for (int j = 0; j < numsteps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
grad_sigmas[j] = delta * (grad_weights[j] * T - accum);
accum -= grad_weights[j] * weights[j];
T *= (1.f - alpha);
}
}
std::vector<torch::Tensor> volumetric_weights_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas
) {
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 3);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor weights = torch::zeros({n_samples}, sigmas.options());
torch::Tensor ray_indices = torch::zeros({n_samples}, packed_info.options());
// The rays that are not skipped during sampling.
torch::Tensor mask = torch::zeros({n_rays}, sigmas.options().dtype(torch::kBool));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"volumetric_weights_forward",
([&]
{ volumetric_weights_forward_kernel<scalar_t><<<blocks, threads>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
ray_indices.data_ptr<int>(),
mask.data_ptr<bool>()
);
}));
return {weights, ray_indices, mask};
}
torch::Tensor volumetric_weights_backward(
torch::Tensor weights,
torch::Tensor grad_weights,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas
) {
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
const uint32_t n_samples = sigmas.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_sigmas = torch::zeros(sigmas.sizes(), sigmas.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"volumetric_weights_backward",
([&]
{ volumetric_weights_backward_kernel<scalar_t><<<blocks, threads>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
weights.data_ptr<scalar_t>(),
grad_weights.data_ptr<scalar_t>(),
grad_sigmas.data_ptr<scalar_t>()
);
}));
return grad_sigmas;
}
from typing import Tuple
import torch
from .cuda import _C
def ray_aabb_intersect(
rays_o: torch.Tensor, rays_d: torch.Tensor, aabb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Ray AABB Test.
Note: this function is not differentiable with respect to its inputs.
Args:
rays_o: Ray origins. Tensor with shape (n_rays, 3).
rays_d: Normalized ray directions. Tensor with shape (n_rays, 3).
aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}.
Tensor with shape (6)
Returns:
Ray-AABB intersection distances {t_min, t_max}, each with shape (n_rays).
t_min is clamped to a minimum of zero; a value of 1e10 means no intersection.
"""
if rays_o.is_cuda and rays_d.is_cuda and aabb.is_cuda:
rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous()
aabb = aabb.contiguous()
t_min, t_max = _C.ray_aabb_intersect(rays_o, rays_d, aabb)
else:
raise NotImplementedError("Only support cuda inputs.")
return t_min, t_max
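# A minimal usage sketch (shapes and values below are illustrative only):
#   rays_o = torch.rand((128, 3), device="cuda")
#   rays_d = torch.nn.functional.normalize(torch.randn((128, 3), device="cuda"), dim=-1)
#   aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device="cuda")
#   t_min, t_max = ray_aabb_intersect(rays_o, rays_d, aabb)  # each of shape (128,)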
def volumetric_weights(
packed_info: torch.Tensor,
t_starts: torch.Tensor,
t_ends: torch.Tensor,
sigmas: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute weights for volumetric rendering.
Note: this function is only differentiable with respect to `sigmas`.
Args:
packed_info: Stores information about which samples belong to the same ray.
See `volumetric_sampling` for details. Tensor with shape (n_rays, 3).
t_starts: Where the frustum-shaped sample starts along a ray. Tensor with
shape (n_samples, 1).
t_ends: Where the frustum-shaped sample ends along a ray. Tensor with
shape (n_samples, 1).
sigmas: Densities at those samples. Tensor with
shape (n_samples, 1).
Returns:
weights: Volumetric rendering weights for those samples. Tensor with shape
(n_samples).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples).
ray_alive_masks: Whether each ray is alive (i.e., was not skipped during
sampling). BoolTensor with shape (n_rays).
"""
if packed_info.is_cuda and t_starts.is_cuda and t_ends.is_cuda and sigmas.is_cuda:
packed_info = packed_info.contiguous()
t_starts = t_starts.contiguous()
t_ends = t_ends.contiguous()
sigmas = sigmas.contiguous()
weights, ray_indices, ray_alive_masks = _volumetric_weights.apply(
packed_info, t_starts, t_ends, sigmas
)
else:
raise NotImplementedError("Only support cuda inputs.")
return weights, ray_indices, ray_alive_masks
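# A minimal usage sketch. `packed_info` stores, per ray, the ray index, the
# offset of its first sample, and its number of samples (the layout read by the
# CUDA kernel). The numbers below are illustrative: two rays with 3 and 2 samples.
#   packed_info = torch.tensor([[0, 0, 3], [1, 3, 2]], dtype=torch.int32, device="cuda")
#   t_starts, t_ends, sigmas: CUDA tensors, each of shape (5, 1)
#   weights, ray_indices, alive = volumetric_weights(packed_info, t_starts, t_ends, sigmas)
#   # weights: (5,), ray_indices: (5,) == [0, 0, 0, 1, 1], alive: (2,)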
def volumetric_accumulate(
weights: torch.Tensor,
ray_indices: torch.Tensor,
values: torch.Tensor = None,
n_rays: int = None,
) -> torch.Tensor:
"""Accumulate values along the ray.
Note: this function is only differentiable with respect to `weights` and `values`.
Args:
weights: Volumetric rendering weights for those samples. Tensor with shape
(n_samples).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples).
values: The values to be accumulated. Tensor with shape (n_samples, D). If
None, only the weights themselves are accumulated. Default is None.
n_rays: Total number of rays. This determines the shape of the outputs. If
None, it is inferred as `ray_indices.max() + 1`. If specified, it must be
larger than `ray_indices.max()`. Default is None.
Returns:
Accumulated values with shape (n_rays, D). If `values` is not given then
we return the accumulated weights, in which case D == 1.
"""
assert ray_indices.dim() == 1 and weights.dim() == 1
if not weights.is_cuda:
raise NotImplementedError("Only support cuda inputs.")
if values is not None:
assert values.dim() == 2 and values.shape[0] == weights.shape[0]
src = weights[:, None] * values
else:
src = weights[:, None]
if ray_indices.numel() == 0:
assert n_rays is not None
return torch.zeros((n_rays, src.shape[-1]), dtype=src.dtype, device=weights.device)
if n_rays is None:
n_rays = ray_indices.max() + 1
else:
assert n_rays > ray_indices.max()
index = ray_indices[:, None].long().expand(-1, src.shape[-1])
outputs = torch.zeros((n_rays, src.shape[-1]), dtype=src.dtype, device=weights.device)
outputs.scatter_add_(0, index, src)
return outputs
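# A minimal usage sketch continuing the example above (illustrative shapes):
#   rgbs = torch.rand((5, 3), device="cuda")
#   colors = volumetric_accumulate(weights, ray_indices, rgbs, n_rays=2)   # (2, 3)
#   opacity = volumetric_accumulate(weights, ray_indices, None, n_rays=2)  # (2, 1)
#   depths = volumetric_accumulate(weights, ray_indices, (t_starts + t_ends) / 2.0, n_rays=2)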
class _volumetric_weights(torch.autograd.Function):
@staticmethod
def forward(ctx, packed_info, t_starts, t_ends, sigmas):
(
weights,
ray_indices,
ray_alive_masks,
) = _C.volumetric_weights_forward(packed_info, t_starts, t_ends, sigmas)
ctx.save_for_backward(
packed_info,
t_starts,
t_ends,
sigmas,
weights,
)
return weights, ray_indices, ray_alive_masks
@staticmethod
def backward(ctx, grad_weights, _grad_ray_indices, _grad_ray_alive_masks):
(
packed_info,
t_starts,
t_ends,
sigmas,
weights,
) = ctx.saved_tensors
grad_sigmas = _C.volumetric_weights_backward(
weights,
grad_weights,
packed_info,
t_starts,
t_ends,
sigmas,
)
return None, None, None, grad_sigmas
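# A minimal pure-PyTorch reference sketch of what the CUDA forward kernel
# computes, assuming the packed_info layout (ray index, sample offset, sample
# count) used above. It is slow and intended only for sanity-checking
# `volumetric_weights` on small inputs; the function name is illustrative.
import math


def _volumetric_weights_reference(packed_info, t_starts, t_ends, sigmas):
    n_samples = sigmas.shape[0]
    weights = torch.zeros(n_samples, dtype=sigmas.dtype, device=sigmas.device)
    ray_indices = torch.zeros(n_samples, dtype=torch.int32, device=sigmas.device)
    alive = torch.zeros(packed_info.shape[0], dtype=torch.bool, device=sigmas.device)
    for i, base, numsteps in packed_info.tolist():
        if numsteps == 0:
            continue
        ray_indices[base : base + numsteps] = i
        alive[i] = True
        T = 1.0  # transmittance along this ray
        for j in range(base, base + numsteps):
            if T < 1e-4:
                break
            delta = float(t_ends[j] - t_starts[j])
            alpha = 1.0 - math.exp(-float(sigmas[j]) * delta)
            weights[j] = alpha * T
            T *= 1.0 - alpha
    return weights, ray_indices, alive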
@@ -3,13 +3,11 @@ from typing import Callable, Tuple
import torch
from .cuda import (
ComputeWeight,
VolumeRenderer,
ray_aabb_intersect,
from .cuda import ( # ComputeWeight,; VolumeRenderer,; ray_aabb_intersect,
ray_marching,
volumetric_rendering_inference,
)
from .utils import ray_aabb_intersect, volumetric_accumulate, volumetric_weights
def volumetric_rendering(
@@ -129,35 +127,48 @@ def volumetric_rendering(
# compact_rgbs.contiguous(),
# )
compact_weights, compact_ray_indices, alive_ray_mask = ComputeWeight.apply(
compact_packed_info.contiguous(),
compact_frustum_starts.contiguous(),
compact_frustum_ends.contiguous(),
compact_densities.contiguous(),
compact_weights, compact_ray_indices, alive_ray_mask = volumetric_weights(
compact_packed_info,
compact_frustum_starts,
compact_frustum_ends,
compact_densities,
)
index = compact_ray_indices[:, None].long()
accumulated_color = torch.zeros((n_rays, 3), device=device)
accumulated_color.scatter_add_(
dim=0,
index=index.expand(-1, 3),
src=compact_weights[:, None] * compact_rgbs,
accumulated_color = volumetric_accumulate(
compact_weights, compact_ray_indices, compact_rgbs, n_rays
)
accumulated_weight = torch.zeros((n_rays, 1), device=device)
accumulated_weight.scatter_add_(
dim=0,
index=index.expand(-1, 1),
src=compact_weights[:, None],
accumulated_weight = volumetric_accumulate(
compact_weights, compact_ray_indices, None, n_rays
)
accumulated_depth = torch.zeros((n_rays, 1), device=device)
accumulated_depth.scatter_add_(
dim=0,
index=index.expand(-1, 1),
src=compact_weights[:, None]
* (compact_frustum_starts + compact_frustum_ends)
/ 2.0,
accumulated_depth = volumetric_accumulate(
compact_weights,
compact_ray_indices,
(compact_frustum_starts + compact_frustum_ends) / 2.0,
n_rays,
)
# index = compact_ray_indices[:, None].long()
# accumulated_color = torch.zeros((n_rays, 3), device=device)
# accumulated_color.scatter_add_(
# dim=0,
# index=index.expand(-1, 3),
# src=compact_weights[:, None] * compact_rgbs,
# )
# accumulated_weight = torch.zeros((n_rays, 1), device=device)
# accumulated_weight.scatter_add_(
# dim=0,
# index=index.expand(-1, 1),
# src=compact_weights[:, None],
# )
# accumulated_depth = torch.zeros((n_rays, 1), device=device)
# accumulated_depth.scatter_add_(
# dim=0,
# index=index.expand(-1, 1),
# src=compact_weights[:, None]
# * (compact_frustum_starts + compact_frustum_ends)
# / 2.0,
# )
# query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
# rgbs, densities = query_results[0], query_results[1]