Commit 62afb4ba authored by Ruilong Li's avatar Ruilong Li
Browse files

Revert "unpack_info with n_samples"

This reverts commit 301b4dfa.
parent 301b4dfa
......@@ -81,8 +81,7 @@ __global__ void unpack_data_kernel(
return;
}
torch::Tensor unpack_info(
const torch::Tensor packed_info, const int n_samples)
torch::Tensor unpack_info(const torch::Tensor packed_info)
{
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
......@@ -91,7 +90,7 @@ torch::Tensor unpack_info(
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kInt32));
......
......@@ -45,7 +45,7 @@ std::vector<torch::Tensor> ray_marching(
const float cone_angle);
torch::Tensor unpack_info(
const torch::Tensor packed_info, const int n_samples);
const torch::Tensor packed_info);
torch::Tensor unpack_info_to_mask(
const torch::Tensor packed_info, const int n_samples);
......
......@@ -44,7 +44,7 @@ def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
@torch.no_grad()
def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
def unpack_info(packed_info: Tensor) -> Tensor:
"""Unpack `packed_info` to `ray_indices`. Useful for converting per ray data to per sample data.
Note:
......@@ -53,7 +53,6 @@ def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
Args:
packed_info: Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
n_samples: Total number of samples.
Returns:
Ray index of each sample. LongTensor with shape (n_sample).
......@@ -72,7 +71,7 @@ def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
# torch.Size([128, 2]) torch.Size([115200, 1]) torch.Size([115200, 1])
print(packed_info.shape, t_starts.shape, t_ends.shape)
# Unpack per-ray info to per-sample info.
ray_indices = unpack_info(packed_info, t_starts.shape[0])
ray_indices = unpack_info(packed_info)
# torch.Size([115200]) torch.int64
print(ray_indices.shape, ray_indices.dtype)
......@@ -81,7 +80,7 @@ def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)."
if packed_info.is_cuda:
ray_indices = _C.unpack_info(packed_info.contiguous().int(), n_samples)
ray_indices = _C.unpack_info(packed_info.contiguous().int())
else:
raise NotImplementedError("Only support cuda inputs.")
return ray_indices.long()
......
......@@ -128,7 +128,7 @@ def ray_marching(
)
# Convert t_starts and t_ends to sample locations.
ray_indices = unpack_info(packed_info, t_starts.shape[0])
ray_indices = unpack_info(packed_info)
t_mid = (t_starts + t_ends) / 2.0
sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]
......@@ -197,7 +197,7 @@ def ray_marching(
# skip invisible space
if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients
ray_indices = unpack_info(packed_info, t_starts.shape[0])
ray_indices = unpack_info(packed_info)
if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
assert (
......
......@@ -96,7 +96,7 @@ def rendering(
)
n_rays = packed_info.shape[0]
ray_indices = unpack_info(packed_info, t_starts.shape[0])
ray_indices = unpack_info(packed_info)
# Query sigma/alpha and color with gradients
if rgb_sigma_fn is not None:
......@@ -160,7 +160,7 @@ def accumulate_along_rays(
weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples,).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples). \
It can be obtained from `unpack_info(packed_info, n_samples)`.
It can be obtained from `unpack_info(packed_info)`.
values: The values to be accmulated. Tensor with shape (n_samples, D). If \
None, the accumulated values are just weights. Default is None.
n_rays: Total number of rays. This will decide the shape of the ouputs. If \
......
......@@ -31,7 +31,7 @@ def test_unpack_info():
ray_indices_tgt = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int64, device=device
)
ray_indices = unpack_info(packed_info, 5)
ray_indices = unpack_info(packed_info)
assert torch.allclose(ray_indices, ray_indices_tgt)
......
......@@ -39,7 +39,7 @@ def test_marching_with_grid():
far_plane=1.0,
render_step_size=1e-2,
)
ray_indices = unpack_info(packed_info, t_starts.shape[0]).long()
ray_indices = unpack_info(packed_info).long()
samples = (
rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment