"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "562a1c878e925a562824ded477942d10572f3a29"
Unverified commit eecfd44b authored by Thomas Wang, committed by GitHub

Reduce precision conversion when packing (#124)

* Reduce conversion. Assume that ray_indices should always be long and pack_info always int

* Remove some more casting

* More casting fixes

* Remove unnecessary long casting

* Fix changes in precision at the Python level; the ray_pdf_query test is already broken since it was reverted

* Support long when rendering transmittance

* Support long

* Woops
parent fe75cb86
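
The convention this commit settles on: `ray_indices` stays int64 (`torch.long`) end to end, so it can drive tensor indexing and `scatter_add_` without casts, while `packed_info` stays int32, matching the `const int *packed_info` pointers in the CUDA kernels below. A minimal CPU-only sketch of the two packed layouts, mirroring the updated `pack_info` logic with illustrative toy values (3 rays with 1, 0, and 4 samples):

```python
import torch

# Illustrative values only: 3 rays with 1, 0, and 4 samples respectively.
# ray_indices has one entry per sample and is kept as int64 (torch.long), so it
# can be used for advanced indexing and scatter_add_ without any casting.
ray_indices = torch.tensor([0, 2, 2, 2, 2], dtype=torch.long)
n_rays = 3

# packed_info has one [start, count] row per ray and is kept as int32.
src = torch.ones_like(ray_indices, dtype=torch.int32)
num_steps = torch.zeros((n_rays,), dtype=torch.int32)
num_steps.scatter_add_(0, ray_indices, src)
cum_steps = num_steps.cumsum(dim=0, dtype=torch.int32)
packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
print(packed_info)  # tensor([[0, 1], [1, 0], [1, 4]], dtype=torch.int32)
```

Each row is `[start, count]` for one ray, so the toy input above yields `[[0, 1], [1, 0], [1, 4]]`, the same packed layout used in the tests at the bottom of this diff.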
@@ -58,7 +58,6 @@ def render_image(
         num_rays, _ = rays_shape
     def sigma_fn(t_starts, t_ends, ray_indices, net=None):
-        ray_indices = ray_indices.long()
         t_origins = chunk_rays.origins[ray_indices]
         t_dirs = chunk_rays.viewdirs[ray_indices]
         positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
@@ -68,7 +67,6 @@ def render_image(
         return radiance_field.query_density(positions)
     def rgb_sigma_fn(t_starts, t_ends, ray_indices):
-        ray_indices = ray_indices.long()
         t_origins = chunk_rays.origins[ray_indices]
         t_dirs = chunk_rays.viewdirs[ray_indices]
         positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
...
@@ -48,7 +48,6 @@ def render_image(
         num_rays, _ = rays_shape
     def sigma_fn(t_starts, t_ends, ray_indices):
-        ray_indices = ray_indices.long()
         t_origins = chunk_rays.origins[ray_indices]
         t_dirs = chunk_rays.viewdirs[ray_indices]
         positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
@@ -63,7 +62,6 @@ def render_image(
         return radiance_field.query_density(positions)
     def rgb_sigma_fn(t_starts, t_ends, ray_indices):
-        ray_indices = ray_indices.long()
         t_origins = chunk_rays.origins[ray_indices]
         t_dirs = chunk_rays.viewdirs[ray_indices]
         positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
...
@@ -9,7 +9,7 @@ __global__ void unpack_info_kernel(
     const int n_rays,
     const int *packed_info,
     // output
-    int *ray_indices)
+    long *ray_indices)
 {
     CUDA_GET_THREAD_ID(i, n_rays);
@@ -92,12 +92,12 @@ torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples)
     // int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
     torch::Tensor ray_indices = torch::empty(
-        {n_samples}, packed_info.options().dtype(torch::kInt32));
+        {n_samples}, packed_info.options().dtype(torch::kLong));
     unpack_info_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
         n_rays,
         packed_info.data_ptr<int>(),
-        ray_indices.data_ptr<int>());
+        ray_indices.data_ptr<long>());
     return ray_indices;
 }
...
@@ -95,7 +95,7 @@ __global__ void ray_marching_kernel(
     // first round outputs
     int *num_steps,
     // second round outputs
-    int *ray_indices,
+    long *ray_indices,
     float *t_starts,
     float *t_ends)
 {
@@ -259,7 +259,7 @@ std::vector<torch::Tensor> ray_marching(
     int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
     torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options());
     torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
-    torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());
+    torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options().dtype(torch::kLong));
     ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
         // rays
@@ -279,7 +279,7 @@ std::vector<torch::Tensor> ray_marching(
         packed_info.data_ptr<int>(),
         // outputs
         nullptr, /* num_steps */
-        ray_indices.data_ptr<int>(),
+        ray_indices.data_ptr<long>(),
         t_starts.data_ptr<float>(),
         t_ends.data_ptr<float>());
...
@@ -20,8 +20,8 @@ template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename V
 inline void exclusive_sum_by_key(
     KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
 {
-    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
-                "cub ExclusiveSumByKey does not support more than INT_MAX elements");
+    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
+                "cub ExclusiveSumByKey does not support more than LONG_MAX elements");
     CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
                 num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
 }
@@ -30,8 +30,8 @@ template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename V
 inline void exclusive_prod_by_key(
     KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
 {
-    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
-                "cub ExclusiveScanByKey does not support more than INT_MAX elements");
+    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
+                "cub ExclusiveScanByKey does not support more than LONG_MAX elements");
     CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
                 num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
 }
@@ -60,7 +60,7 @@ torch::Tensor transmittance_from_sigma_forward_cub(
     torch::Tensor sigmas_dt_cumsum = torch::empty_like(sigmas);
 #if CUB_SUPPORTS_SCAN_BY_KEY()
     exclusive_sum_by_key(
-        ray_indices.data_ptr<int>(),
+        ray_indices.data_ptr<long>(),
         sigmas_dt.data_ptr<float>(),
         sigmas_dt_cumsum.data_ptr<float>(),
         n_samples);
@@ -97,7 +97,7 @@ torch::Tensor transmittance_from_sigma_backward_cub(
     torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
 #if CUB_SUPPORTS_SCAN_BY_KEY()
     exclusive_sum_by_key(
-        thrust::make_reverse_iterator(ray_indices.data_ptr<int>() + n_samples),
+        thrust::make_reverse_iterator(ray_indices.data_ptr<long>() + n_samples),
         thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
         thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
         n_samples);
@@ -123,7 +123,7 @@ torch::Tensor transmittance_from_alpha_forward_cub(
     torch::Tensor transmittance = torch::empty_like(alphas);
 #if CUB_SUPPORTS_SCAN_BY_KEY()
     exclusive_prod_by_key(
-        ray_indices.data_ptr<int>(),
+        ray_indices.data_ptr<long>(),
         (1.0f - alphas).data_ptr<float>(),
         transmittance.data_ptr<float>(),
         n_samples);
@@ -154,7 +154,7 @@ torch::Tensor transmittance_from_alpha_backward_cub(
     torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
 #if CUB_SUPPORTS_SCAN_BY_KEY()
     exclusive_sum_by_key(
-        thrust::make_reverse_iterator(ray_indices.data_ptr<int>() + n_samples),
+        thrust::make_reverse_iterator(ray_indices.data_ptr<long>() + n_samples),
         thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
         thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
         n_samples);
...
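
For reference, `exclusive_sum_by_key` and `exclusive_prod_by_key` wrap cub's keyed exclusive scans with `ray_indices` as the keys: the scan restarts whenever the key changes, so each ray gets its own exclusive prefix, which is what the transmittance functions above rely on. A naive, loop-based PyTorch sketch of the same semantics (names and values are illustrative; it assumes samples of one ray are stored contiguously, as ray marching produces them):

```python
import torch

def exclusive_scan_by_key(keys, values, op="sum"):
    """Naive reference: exclusive prefix scan over `values`, restarting
    whenever `keys` changes (samples of one ray must be contiguous)."""
    out = torch.empty_like(values)
    identity = 0.0 if op == "sum" else 1.0
    running = identity
    for i in range(values.numel()):
        if i == 0 or keys[i] != keys[i - 1]:
            running = identity  # new ray: restart the scan
        out[i] = running
        if op == "sum":
            running = running + values[i].item()
        else:
            running = running * values[i].item()
    return out

ray_indices = torch.tensor([0, 2, 2, 2, 2])          # keys, int64
sigmas_dt = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])  # sigma * (t_end - t_start)
cumsum = exclusive_scan_by_key(ray_indices, sigmas_dt)
print(cumsum)              # approximately [0.0, 0.0, 0.2, 0.5, 0.9]
print(torch.exp(-cumsum))  # per-sample transmittance from density
```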
@@ -300,18 +300,14 @@ def _meshgrid3d(
     """Create 3D grid coordinates."""
     assert len(res) == 3
     res = res.tolist()
-    return (
-        torch.stack(
-            torch.meshgrid(
-                [
-                    torch.arange(res[0]),
-                    torch.arange(res[1]),
-                    torch.arange(res[2]),
-                ],
-                indexing="ij",
-            ),
-            dim=-1,
-        )
-        .long()
-        .to(device)
-    )
+    return torch.stack(
+        torch.meshgrid(
+            [
+                torch.arange(res[0], dtype=torch.long),
+                torch.arange(res[1], dtype=torch.long),
+                torch.arange(res[2], dtype=torch.long),
+            ],
+            indexing="ij",
+        ),
+        dim=-1,
+    ).to(device)
@@ -37,8 +37,8 @@ def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
     ), "mask must be with shape of (n_rays, n_samples)."
     assert mask.dtype == torch.bool, "mask must be a boolean tensor."
     packed_data = data[mask]
-    num_steps = mask.long().sum(dim=-1)
-    cum_steps = num_steps.cumsum(dim=0, dtype=torch.long)
+    num_steps = mask.sum(dim=-1, dtype=torch.int32)
+    cum_steps = num_steps.cumsum(dim=0, dtype=torch.int32)
     packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
     return packed_data, packed_info
@@ -55,26 +55,26 @@ def pack_info(ray_indices: Tensor, n_rays: int = None) -> Tensor:
     Returns:
         packed_info: Stores information on which samples belong to the same ray. \
-            See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
+            See :func:`nerfacc.ray_marching` for details. IntTensor with shape (n_rays, 2).
     """
     assert (
         ray_indices.dim() == 1
     ), "ray_indices must be a 1D tensor with shape (n_samples)."
     if ray_indices.is_cuda:
-        ray_indices = ray_indices.contiguous().int()
+        ray_indices = ray_indices
         device = ray_indices.device
         if n_rays is None:
             n_rays = int(ray_indices.max()) + 1
         # else:
         #     assert n_rays > ray_indices.max()
-        src = torch.ones_like(ray_indices)
+        src = torch.ones_like(ray_indices, dtype=torch.int)
         num_steps = torch.zeros((n_rays,), device=device, dtype=torch.int)
-        num_steps.scatter_add_(0, ray_indices.long(), src)
+        num_steps.scatter_add_(0, ray_indices, src)
         cum_steps = num_steps.cumsum(dim=0, dtype=torch.int)
         packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
     else:
         raise NotImplementedError("Only support cuda inputs.")
-    return packed_info.int()
+    return packed_info
 @torch.no_grad()
@@ -86,7 +86,7 @@ def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
     Args:
         packed_info: Stores information on which samples belong to the same ray. \
-            See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
+            See :func:`nerfacc.ray_marching` for details. IntTensor with shape (n_rays, 2).
         n_samples: Total number of samples.
     Returns:
@@ -115,10 +115,10 @@ def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
         packed_info.dim() == 2 and packed_info.shape[-1] == 2
     ), "packed_info must be a 2D tensor with shape (n_rays, 2)."
     if packed_info.is_cuda:
-        ray_indices = _C.unpack_info(packed_info.contiguous().int(), n_samples)
+        ray_indices = _C.unpack_info(packed_info.contiguous(), n_samples)
     else:
         raise NotImplementedError("Only support cuda inputs.")
-    return ray_indices.long()
+    return ray_indices
 def unpack_data(
@@ -173,7 +173,7 @@ class _UnpackData(torch.autograd.Function):
     @staticmethod
     def forward(ctx, packed_info: Tensor, data: Tensor, n_samples: int):
         # shape of the data should be (all_samples, D)
-        packed_info = packed_info.contiguous().int()
+        packed_info = packed_info.contiguous()
         data = data.contiguous()
         if ctx.needs_input_grad[1]:
             ctx.save_for_backward(packed_info)
...
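
With these changes, `pack_info` returns the int32 `packed_info` as-is and `unpack_info` gets int64 `ray_indices` straight from the CUDA kernel, so the `.int()` / `.long()` round trips disappear. A CPU-only sketch of the equivalent unpack step via `repeat_interleave` (illustrative; the library's actual path is the CUDA kernel shown earlier):

```python
import torch

# packed_info rows are [start, count] per ray, int32 (as pack_info returns them).
packed_info = torch.tensor([[0, 1], [1, 0], [1, 4]], dtype=torch.int32)
n_rays = packed_info.shape[0]

# Equivalent of unpack_info: repeat each ray id by its sample count.
# The result is int64 (torch.long), matching what the kernel now writes.
ray_indices = torch.repeat_interleave(
    torch.arange(n_rays, dtype=torch.long),
    packed_info[:, 1].long(),
)
print(ray_indices)        # tensor([0, 2, 2, 2, 2])
print(ray_indices.dtype)  # torch.int64
```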
@@ -193,13 +193,13 @@ def ray_marching(
     if sigma_fn is not None or alpha_fn is not None:
         # Query sigma without gradients
         if sigma_fn is not None:
-            sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
+            sigmas = sigma_fn(t_starts, t_ends, ray_indices)
             assert (
                 sigmas.shape == t_starts.shape
             ), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
             alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
         elif alpha_fn is not None:
-            alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
+            alphas = alpha_fn(t_starts, t_ends, ray_indices)
             assert (
                 alphas.shape == t_starts.shape
             ), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
...
@@ -127,7 +127,7 @@ def proposal_sampling_with_filter(
     for proposal_fn, n_samples in zip(proposal_sigma_fns, proposal_n_samples):
         # compute weights for resampling
-        sigmas = proposal_fn(t_starts, t_ends, ray_indices.long())
+        sigmas = proposal_fn(t_starts, t_ends, ray_indices)
         assert (
             sigmas.shape == t_starts.shape
         ), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
@@ -152,7 +152,7 @@ def proposal_sampling_with_filter(
         # Rerun the proposal function **with** gradients on filtered samples.
         if proposal_require_grads:
             with torch.enable_grad():
-                sigmas = proposal_fn(t_starts, t_ends, ray_indices.long())
+                sigmas = proposal_fn(t_starts, t_ends, ray_indices)
                 weights = render_weight_from_density(
                     t_starts, t_ends, sigmas, ray_indices=ray_indices
                 )
@@ -168,7 +168,7 @@ def proposal_sampling_with_filter(
     # last round filtering with sigma_fn
     if (alpha_thre > 0 or early_stop_eps > 0) and (sigma_fn is not None):
-        sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
+        sigmas = sigma_fn(t_starts, t_ends, ray_indices)
         assert (
             sigmas.shape == t_starts.shape
         ), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
...
@@ -78,7 +78,7 @@ def rendering(
     # Query sigma/alpha and color with gradients
     if rgb_sigma_fn is not None:
-        rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long())
+        rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices)
         assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
             rgbs.shape
         )
@@ -94,7 +94,7 @@ def rendering(
             n_rays=n_rays,
         )
     elif rgb_alpha_fn is not None:
-        rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices.long())
+        rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices)
         assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
             rgbs.shape
         )
@@ -143,7 +143,7 @@ def accumulate_along_rays(
     Args:
         weights: Volumetric rendering weights for those samples. Tensor with shape \
             (n_samples,).
-        ray_indices: Ray index of each sample. IntTensor with shape (n_samples).
+        ray_indices: Ray index of each sample. LongTensor with shape (n_samples).
         values: The values to be accmulated. Tensor with shape (n_samples, D). If \
             None, the accumulated values are just weights. Default is None.
         n_rays: Total number of rays. This will decide the shape of the ouputs. If \
@@ -190,9 +190,10 @@ def accumulate_along_rays(
         n_rays = int(ray_indices.max()) + 1
     # assert n_rays > ray_indices.max()
-    ray_indices = ray_indices.int()
-    index = ray_indices[:, None].long().expand(-1, src.shape[-1])
-    outputs = torch.zeros((n_rays, src.shape[-1]), device=weights.device)
+    index = ray_indices[:, None].expand(-1, src.shape[-1])
+    outputs = torch.zeros(
+        (n_rays, src.shape[-1]), device=src.device, dtype=src.dtype
+    )
     outputs.scatter_add_(0, index, src)
     return outputs
@@ -524,7 +525,7 @@ class _RenderingTransmittanceFromDensityCUB(torch.autograd.Function):
     @staticmethod
     def forward(ctx, ray_indices, t_starts, t_ends, sigmas):
-        ray_indices = ray_indices.contiguous().int()
+        ray_indices = ray_indices.contiguous()
         t_starts = t_starts.contiguous()
         t_ends = t_ends.contiguous()
         sigmas = sigmas.contiguous()
@@ -550,7 +551,7 @@ class _RenderingTransmittanceFromDensityNaive(torch.autograd.Function):
     @staticmethod
     def forward(ctx, packed_info, t_starts, t_ends, sigmas):
-        packed_info = packed_info.contiguous().int()
+        packed_info = packed_info.contiguous()
         t_starts = t_starts.contiguous()
         t_ends = t_ends.contiguous()
         sigmas = sigmas.contiguous()
@@ -576,7 +577,7 @@ class _RenderingTransmittanceFromAlphaCUB(torch.autograd.Function):
     @staticmethod
     def forward(ctx, ray_indices, alphas):
-        ray_indices = ray_indices.contiguous().int()
+        ray_indices = ray_indices.contiguous()
         alphas = alphas.contiguous()
         transmittance = _C.transmittance_from_alpha_forward_cub(
             ray_indices, alphas
@@ -600,7 +601,7 @@ class _RenderingTransmittanceFromAlphaNaive(torch.autograd.Function):
     @staticmethod
     def forward(ctx, packed_info, alphas):
-        packed_info = packed_info.contiguous().int()
+        packed_info = packed_info.contiguous()
         alphas = alphas.contiguous()
         transmittance = _C.transmittance_from_alpha_forward_naive(
             packed_info, alphas
@@ -624,7 +625,7 @@ class _RenderingWeightFromDensityNaive(torch.autograd.Function):
     @staticmethod
     def forward(ctx, packed_info, t_starts, t_ends, sigmas):
-        packed_info = packed_info.contiguous().int()
+        packed_info = packed_info.contiguous()
         t_starts = t_starts.contiguous()
         t_ends = t_ends.contiguous()
         sigmas = sigmas.contiguous()
@@ -652,7 +653,7 @@ class _RenderingWeightFromAlphaNaive(torch.autograd.Function):
     @staticmethod
     def forward(ctx, packed_info, alphas):
-        packed_info = packed_info.contiguous().int()
+        packed_info = packed_info.contiguous()
         alphas = alphas.contiguous()
         weights = _C.weight_from_alpha_forward_naive(packed_info, alphas)
         if ctx.needs_input_grad[1]:
...
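
`scatter_add_` only accepts an int64 index tensor, which is why `accumulate_along_rays` previously had to cast before building `index`; with `ray_indices` already long, the index is usable directly, and the output buffer now also inherits `src`'s device and dtype. A small CPU sketch of that accumulation, reusing the toy values from the tests below (illustrative):

```python
import torch

ray_indices = torch.tensor([0, 2, 2, 2, 2])             # int64 by default
weights = torch.tensor([0.4, 0.3, 0.8, 0.8, 0.5])
n_rays = 3

src = weights[:, None]                                   # (n_samples, 1)
index = ray_indices[:, None].expand(-1, src.shape[-1])   # int64, as scatter_add_ requires
outputs = torch.zeros((n_rays, src.shape[-1]), device=src.device, dtype=src.dtype)
outputs.scatter_add_(0, index, src)
print(outputs.squeeze(-1))  # tensor([0.4000, 0.0000, 2.4000])
```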
@@ -97,7 +97,7 @@ def main():
     cpu_t, cuda_t, cuda_bytes = profiler(fn)
     print(f"{cpu_t:.2f} us, {cuda_t:.2f} us, {cuda_bytes / 1024 / 1024:.2f} MB")
-    packed_info = nerfacc.pack_info(ray_indices, n_rays=batch_size).int()
+    packed_info = nerfacc.pack_info(ray_indices, n_rays=batch_size)
     fn = (
         lambda: nerfacc.vol_rendering._RenderingDensity.apply(
             packed_info, t_starts, t_ends, sigmas, 0
...
@@ -39,7 +39,7 @@ def test_marching_with_grid():
         far_plane=1.0,
         render_step_size=1e-2,
     )
-    ray_indices = ray_indices.long()
+    ray_indices = ray_indices
     samples = (
         rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0
     )
...
@@ -18,7 +18,7 @@ eps = 1e-6
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_render_visibility():
     ray_indices = torch.tensor(
-        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+        [0, 2, 2, 2, 2], dtype=torch.int64, device=device
     )  # (samples,)
     alphas = torch.tensor(
         [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
@@ -48,7 +48,7 @@ def test_render_visibility():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_render_weight_from_alpha():
     ray_indices = torch.tensor(
-        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+        [0, 2, 2, 2, 2], dtype=torch.int64, device=device
     )  # (samples,)
     alphas = torch.tensor(
         [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
@@ -71,7 +71,7 @@ def test_render_weight_from_alpha():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_render_weight_from_density():
     ray_indices = torch.tensor(
-        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+        [0, 2, 2, 2, 2], dtype=torch.int64, device=device
     )  # (samples,)
     sigmas = torch.rand(
         (ray_indices.shape[0], 1), device=device
@@ -92,7 +92,7 @@ def test_render_weight_from_density():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_accumulate_along_rays():
     ray_indices = torch.tensor(
-        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+        [0, 2, 2, 2, 2], dtype=torch.int64, device=device
     )  # (n_rays,)
     weights = torch.tensor(
         [0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
@@ -116,7 +116,7 @@ def test_rendering():
         return torch.hstack([t_starts] * 3), t_starts
     ray_indices = torch.tensor(
-        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+        [0, 2, 2, 2, 2], dtype=torch.int64, device=device
     )  # (samples,)
     sigmas = torch.rand(
         (ray_indices.shape[0], 1), device=device
@@ -136,7 +136,7 @@ def test_rendering():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_grads():
     ray_indices = torch.tensor(
-        [0, 2, 2, 2, 2], dtype=torch.int32, device=device
+        [0, 2, 2, 2, 2], dtype=torch.int64, device=device
     )  # (samples,)
     packed_info = torch.tensor(
         [[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device
...