Commit 301b4dfa authored by Ruilong Li's avatar Ruilong Li
Browse files

unpack_info with n_samples

parent 2d8d6b43
...@@ -81,7 +81,8 @@ __global__ void unpack_data_kernel( ...@@ -81,7 +81,8 @@ __global__ void unpack_data_kernel(
return; return;
} }
torch::Tensor unpack_info(const torch::Tensor packed_info) torch::Tensor unpack_info(
const torch::Tensor packed_info, const int n_samples)
{ {
DEVICE_GUARD(packed_info); DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info); CHECK_INPUT(packed_info);
...@@ -90,7 +91,7 @@ torch::Tensor unpack_info(const torch::Tensor packed_info) ...@@ -90,7 +91,7 @@ torch::Tensor unpack_info(const torch::Tensor packed_info)
const int threads = 256; const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads); const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
int n_samples = packed_info[n_rays - 1].sum(0).item<int>(); // int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty( torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kInt32)); {n_samples}, packed_info.options().dtype(torch::kInt32));
......
...@@ -45,7 +45,7 @@ std::vector<torch::Tensor> ray_marching( ...@@ -45,7 +45,7 @@ std::vector<torch::Tensor> ray_marching(
const float cone_angle); const float cone_angle);
torch::Tensor unpack_info( torch::Tensor unpack_info(
const torch::Tensor packed_info); const torch::Tensor packed_info, const int n_samples);
torch::Tensor unpack_info_to_mask( torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples); const torch::Tensor packed_info, const int n_samples);
......
...@@ -44,7 +44,7 @@ def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: ...@@ -44,7 +44,7 @@ def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
@torch.no_grad() @torch.no_grad()
def unpack_info(packed_info: Tensor) -> Tensor: def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
"""Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data. """Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data.
Note: Note:
...@@ -53,6 +53,7 @@ def unpack_info(packed_info: Tensor) -> Tensor: ...@@ -53,6 +53,7 @@ def unpack_info(packed_info: Tensor) -> Tensor:
Args: Args:
packed_info: Stores information on which samples belong to the same ray. \ packed_info: Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2). See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
n_samples: Total number of samples.
Returns: Returns:
Ray index of each sample. LongTensor with shape (n_sample). Ray index of each sample. LongTensor with shape (n_sample).
...@@ -71,7 +72,7 @@ def unpack_info(packed_info: Tensor) -> Tensor: ...@@ -71,7 +72,7 @@ def unpack_info(packed_info: Tensor) -> Tensor:
# torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1]) # torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1])
print(packed_info.shape, t_starts.shape, t_ends.shape) print(packed_info.shape, t_starts.shape, t_ends.shape)
# Unpack per-ray info to per-sample info. # Unpack per-ray info to per-sample info.
ray_indices = unpack_info(packed_info) ray_indices = unpack_info(packed_info, t_starts.shape[0])
# torch.Size([115200]) torch.int64 # torch.Size([115200]) torch.int64
print(ray_indices.shape, ray_indices.dtype) print(ray_indices.shape, ray_indices.dtype)
...@@ -80,7 +81,7 @@ def unpack_info(packed_info: Tensor) -> Tensor: ...@@ -80,7 +81,7 @@ def unpack_info(packed_info: Tensor) -> Tensor:
packed_info.dim() == 2 and packed_info.shape[-1] == 2 packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)." ), "packed_info must be a 2D tensor with shape (n_rays, 2)."
if packed_info.is_cuda: if packed_info.is_cuda:
ray_indices = _C.unpack_info(packed_info.contiguous().int()) ray_indices = _C.unpack_info(packed_info.contiguous().int(), n_samples)
else: else:
raise NotImplementedError("Only support cuda inputs.") raise NotImplementedError("Only support cuda inputs.")
return ray_indices.long() return ray_indices.long()
......
...@@ -128,7 +128,7 @@ def ray_marching( ...@@ -128,7 +128,7 @@ def ray_marching(
) )
# Convert t_starts and t_ends to sample locations. # Convert t_starts and t_ends to sample locations.
ray_indices = unpack_info(packed_info) ray_indices = unpack_info(packed_info, t_starts.shape[0])
t_mid = (t_starts + t_ends) / 2.0 t_mid = (t_starts + t_ends) / 2.0
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices] sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
...@@ -197,7 +197,7 @@ def ray_marching( ...@@ -197,7 +197,7 @@ def ray_marching(
# skip invisible space # skip invisible space
if sigma_fn is not None or alpha_fn is not None: if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients # Query sigma without gradients
ray_indices = unpack_info(packed_info) ray_indices = unpack_info(packed_info, t_starts.shape[0])
if sigma_fn is not None: if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long()) sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
assert ( assert (
......
...@@ -96,7 +96,7 @@ def rendering( ...@@ -96,7 +96,7 @@ def rendering(
) )
n_rays = packed_info.shape[0] n_rays = packed_info.shape[0]
ray_indices = unpack_info(packed_info) ray_indices = unpack_info(packed_info, t_starts.shape[0])
# Query sigma/alpha and color with gradients # Query sigma/alpha and color with gradients
if rgb_sigma_fn is not None: if rgb_sigma_fn is not None:
...@@ -160,7 +160,7 @@ def accumulate_along_rays( ...@@ -160,7 +160,7 @@ def accumulate_along_rays(
weights: Volumetric rendering weights for those samples. Tensor with shape \ weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples,). (n_samples,).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples). \ ray_indices: Ray index of each sample. IntTensor with shape (n_samples). \
It can be obtained from `unpack_info(packed_info)`. It can be obtained from `unpack_info(packed_info, n_samples)`.
values: The values to be accmulated. Tensor with shape (n_samples, D). If \ values: The values to be accmulated. Tensor with shape (n_samples, D). If \
None, the accumulated values are just weights. Default is None. None, the accumulated values are just weights. Default is None.
n_rays: Total number of rays. This will decide the shape of the ouputs. If \ n_rays: Total number of rays. This will decide the shape of the ouputs. If \
......
...@@ -31,7 +31,7 @@ def test_unpack_info(): ...@@ -31,7 +31,7 @@ def test_unpack_info():
ray_indices_tgt = torch.tensor( ray_indices_tgt = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int64, device=device [0, 2, 2, 2, 2], dtype=torch.int64, device=device
) )
ray_indices = unpack_info(packed_info) ray_indices = unpack_info(packed_info, 5)
assert torch.allclose(ray_indices, ray_indices_tgt) assert torch.allclose(ray_indices, ray_indices_tgt)
......
...@@ -39,7 +39,7 @@ def test_marching_with_grid(): ...@@ -39,7 +39,7 @@ def test_marching_with_grid():
far_plane=1.0, far_plane=1.0,
render_step_size=1e-2, render_step_size=1e-2,
) )
ray_indices = unpack_info(packed_info).long() ray_indices = unpack_info(packed_info, t_starts.shape[0]).long()
samples = ( samples = (
rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0 rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment