Commit 301b4dfa authored by Ruilong Li's avatar Ruilong Li
Browse files

unpack_info with n_samples

parent 2d8d6b43
......@@ -81,7 +81,8 @@ __global__ void unpack_data_kernel(
return;
}
torch::Tensor unpack_info(const torch::Tensor packed_info)
torch::Tensor unpack_info(
const torch::Tensor packed_info, const int n_samples)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
......@@ -90,7 +91,7 @@ torch::Tensor unpack_info(const torch::Tensor packed_info)
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
// int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kInt32));
......
......@@ -45,7 +45,7 @@ std::vector<torch::Tensor> ray_marching(
const float cone_angle);
torch::Tensor unpack_info(
const torch::Tensor packed_info);
const torch::Tensor packed_info, const int n_samples);
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples);
......
......@@ -44,7 +44,7 @@ def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
@torch.no_grad()
def unpack_info(packed_info: Tensor) -> Tensor:
def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
"""Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data.
Note:
......@@ -53,6 +53,7 @@ def unpack_info(packed_info: Tensor) -> Tensor:
Args:
packed_info: Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
n_samples: Total number of samples.
Returns:
Ray index of each sample. LongTensor with shape (n_samples).
......@@ -71,7 +72,7 @@ def unpack_info(packed_info: Tensor) -> Tensor:
# torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1])
print(packed_info.shape, t_starts.shape, t_ends.shape)
# Unpack per-ray info to per-sample info.
ray_indices = unpack_info(packed_info)
ray_indices = unpack_info(packed_info, t_starts.shape[0])
# torch.Size([115200]) torch.int64
print(ray_indices.shape, ray_indices.dtype)
......@@ -80,7 +81,7 @@ def unpack_info(packed_info: Tensor) -> Tensor:
packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)."
if packed_info.is_cuda:
ray_indices = _C.unpack_info(packed_info.contiguous().int())
ray_indices = _C.unpack_info(packed_info.contiguous().int(), n_samples)
else:
raise NotImplementedError("Only support cuda inputs.")
return ray_indices.long()
......
......@@ -128,7 +128,7 @@ def ray_marching(
)
# Convert t_starts and t_ends to sample locations.
ray_indices = unpack_info(packed_info)
ray_indices = unpack_info(packed_info, t_starts.shape[0])
t_mid = (t_starts + t_ends) / 2.0
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
......@@ -197,7 +197,7 @@ def ray_marching(
# skip invisible space
if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients
ray_indices = unpack_info(packed_info)
ray_indices = unpack_info(packed_info, t_starts.shape[0])
if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
assert (
......
......@@ -96,7 +96,7 @@ def rendering(
)
n_rays = packed_info.shape[0]
ray_indices = unpack_info(packed_info)
ray_indices = unpack_info(packed_info, t_starts.shape[0])
# Query sigma/alpha and color with gradients
if rgb_sigma_fn is not None:
......@@ -160,7 +160,7 @@ def accumulate_along_rays(
weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples,).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples). \
It can be obtained from `unpack_info(packed_info)`.
It can be obtained from `unpack_info(packed_info, n_samples)`.
values: The values to be accumulated. Tensor with shape (n_samples, D). If \
None, the accumulated values are just weights. Default is None.
n_rays: Total number of rays. This will decide the shape of the outputs. If \
......
......@@ -31,7 +31,7 @@ def test_unpack_info():
ray_indices_tgt = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int64, device=device
)
ray_indices = unpack_info(packed_info)
ray_indices = unpack_info(packed_info, 5)
assert torch.allclose(ray_indices, ray_indices_tgt)
......
......@@ -39,7 +39,7 @@ def test_marching_with_grid():
far_plane=1.0,
render_step_size=1e-2,
)
ray_indices = unpack_info(packed_info).long()
ray_indices = unpack_info(packed_info, t_starts.shape[0]).long()
samples = (
rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment