Unverified Commit eecfd44b authored by Thomas Wang, committed by GitHub

Reduce precision conversion when packing (#124)

* Reduce conversions. Assume that ray_indices should always be long and pack_info always int

* Remove some more casting

* More casting fixes

* Remove unnecessary long casting

* Fix precision changes at the Python level; the ray_pdf_query test is already broken since it was reverted

* Support long when rendering transmittance

* Support long

* Whoops
parent fe75cb86
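
For orientation (illustration only, not part of the diff): the dtype convention this commit settles on is that per-sample `ray_indices` stay int64 ("long") end to end, while per-ray `packed_info` stays int32, so no `.long()`/`.int()` round-trips are needed at call sites. A minimal standalone sketch:

```python
import torch

# Per-sample ray index, kept as int64; samples are grouped by ray.
ray_indices = torch.tensor([0, 0, 1, 3, 3, 3], dtype=torch.long)
n_rays = 4

# Per-ray (start, count) packing info, kept as int32.
counts = torch.bincount(ray_indices, minlength=n_rays).int()       # (n_rays,)
starts = torch.cumsum(counts, dim=0, dtype=torch.int32) - counts   # exclusive sum
packed_info = torch.stack([starts, counts], dim=-1)                # (n_rays, 2)

assert ray_indices.dtype == torch.long
assert packed_info.dtype == torch.int32
```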
......@@ -58,7 +58,6 @@ def render_image(
num_rays, _ = rays_shape
def sigma_fn(t_starts, t_ends, ray_indices, net=None):
ray_indices = ray_indices.long()
t_origins = chunk_rays.origins[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
......@@ -68,7 +67,6 @@ def render_image(
return radiance_field.query_density(positions)
def rgb_sigma_fn(t_starts, t_ends, ray_indices):
ray_indices = ray_indices.long()
t_origins = chunk_rays.origins[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
......
......@@ -48,7 +48,6 @@ def render_image(
num_rays, _ = rays_shape
def sigma_fn(t_starts, t_ends, ray_indices):
ray_indices = ray_indices.long()
t_origins = chunk_rays.origins[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
......@@ -63,7 +62,6 @@ def render_image(
return radiance_field.query_density(positions)
def rgb_sigma_fn(t_starts, t_ends, ray_indices):
ray_indices = ray_indices.long()
t_origins = chunk_rays.origins[ray_indices]
t_dirs = chunk_rays.viewdirs[ray_indices]
positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
......
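
With `ray_indices` already a LongTensor, the callbacks in both example scripts can index ray data directly; the dropped `.long()` lines above are the whole change. A minimal standalone sketch (the container and field names here are placeholders, not the examples' real classes):

```python
import torch

def make_sigma_fn(origins, viewdirs, query_density):
    # origins, viewdirs: (n_rays, 3); query_density maps (n_samples, 3) -> (n_samples, 1)
    def sigma_fn(t_starts, t_ends, ray_indices):
        # ray_indices arrives as int64, so fancy indexing works with no cast.
        t_origins = origins[ray_indices]
        t_dirs = viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
        return query_density(positions)
    return sigma_fn
```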
......@@ -9,7 +9,7 @@ __global__ void unpack_info_kernel(
const int n_rays,
const int *packed_info,
// output
int *ray_indices)
long *ray_indices)
{
CUDA_GET_THREAD_ID(i, n_rays);
......@@ -92,12 +92,12 @@ torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples)
// int n_samples = packed_info[n_rays - 1].sum(0).item<int>();
torch::Tensor ray_indices = torch::empty(
{n_samples}, packed_info.options().dtype(torch::kInt32));
{n_samples}, packed_info.options().dtype(torch::kLong));
unpack_info_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<int>());
ray_indices.data_ptr<long>());
return ray_indices;
}
......
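
For reference, a hedged pure-PyTorch equivalent of what `unpack_info` computes, which also shows why the natural output dtype is now int64: each per-ray (start, count) pair is expanded into `count` copies of the ray id.

```python
import torch

packed_info = torch.tensor([[0, 1], [1, 0], [1, 4]], dtype=torch.int32)
counts = packed_info[:, 1].long()
ray_indices = torch.repeat_interleave(
    torch.arange(packed_info.shape[0], dtype=torch.long), counts
)
# ray_indices == tensor([0, 2, 2, 2, 2]), dtype torch.int64
```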
......@@ -95,7 +95,7 @@ __global__ void ray_marching_kernel(
// first round outputs
int *num_steps,
// second round outputs
int *ray_indices,
long *ray_indices,
float *t_starts,
float *t_ends)
{
......@@ -259,7 +259,7 @@ std::vector<torch::Tensor> ray_marching(
int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());
torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options().dtype(torch::kLong));
ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays
......@@ -279,7 +279,7 @@ std::vector<torch::Tensor> ray_marching(
packed_info.data_ptr<int>(),
// outputs
nullptr, /* num_steps */
ray_indices.data_ptr<int>(),
ray_indices.data_ptr<long>(),
t_starts.data_ptr<float>(),
t_ends.data_ptr<float>());
......
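
Downstream, `nerfacc.ray_marching` now hands back `ray_indices` as an int64 tensor, so callers can index ray tensors directly. A hedged usage sketch patterned on the repository's tests (requires a CUDA device; the exact keyword arguments are assumed):

```python
import torch
import nerfacc

device = "cuda"
rays_o = torch.rand((128, 3), device=device)
rays_d = torch.nn.functional.normalize(torch.rand((128, 3), device=device), dim=-1)
aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0], device=device)

ray_indices, t_starts, t_ends = nerfacc.ray_marching(
    rays_o, rays_d, scene_aabb=aabb,
    near_plane=0.1, far_plane=1.0, render_step_size=1e-2,
)
# ray_indices.dtype is torch.int64; no .long() before fancy indexing.
samples = rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0
```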
......@@ -20,8 +20,8 @@ template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename V
inline void exclusive_sum_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub ExclusiveSumByKey does not support more than INT_MAX elements");
TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
"cub ExclusiveSumByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
......@@ -30,8 +30,8 @@ template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename V
inline void exclusive_prod_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
"cub ExclusiveScanByKey does not support more than INT_MAX elements");
TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
"cub ExclusiveScanByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
......@@ -60,7 +60,7 @@ torch::Tensor transmittance_from_sigma_forward_cub(
torch::Tensor sigmas_dt_cumsum = torch::empty_like(sigmas);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
ray_indices.data_ptr<int>(),
ray_indices.data_ptr<long>(),
sigmas_dt.data_ptr<float>(),
sigmas_dt_cumsum.data_ptr<float>(),
n_samples);
......@@ -97,7 +97,7 @@ torch::Tensor transmittance_from_sigma_backward_cub(
torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(ray_indices.data_ptr<int>() + n_samples),
thrust::make_reverse_iterator(ray_indices.data_ptr<long>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
n_samples);
......@@ -123,7 +123,7 @@ torch::Tensor transmittance_from_alpha_forward_cub(
torch::Tensor transmittance = torch::empty_like(alphas);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_prod_by_key(
ray_indices.data_ptr<int>(),
ray_indices.data_ptr<long>(),
(1.0f - alphas).data_ptr<float>(),
transmittance.data_ptr<float>(),
n_samples);
......@@ -154,7 +154,7 @@ torch::Tensor transmittance_from_alpha_backward_cub(
torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(ray_indices.data_ptr<int>() + n_samples),
thrust::make_reverse_iterator(ray_indices.data_ptr<long>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
n_samples);
......
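
The CUB helpers above are segmented exclusive scans keyed by `ray_indices`; the only change is that the keys are now read as 64-bit. As a reference for the semantics (not the library API), a slow pure-PyTorch version of the exclusive product by key, which yields transmittance when fed `1 - alpha`:

```python
import torch

def exclusive_prod_by_key_ref(keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # For each sample i: product of values[j] over earlier j with the same key.
    out = torch.empty_like(values)
    for k in keys.unique():
        m = keys == k
        v = values[m]
        cp = torch.cumprod(v, dim=0)
        out[m] = torch.cat([torch.ones_like(v[:1]), cp[:-1]])
    return out

keys = torch.tensor([0, 2, 2, 2, 2])                      # int64 ray indices
alphas = torch.tensor([0.4, 0.3, 0.8, 0.8, 0.5])
transmittance = exclusive_prod_by_key_ref(keys, 1.0 - alphas)
# tensor([1.0000, 1.0000, 0.7000, 0.1400, 0.0280])
```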
......@@ -300,18 +300,14 @@ def _meshgrid3d(
"""Create 3D grid coordinates."""
assert len(res) == 3
res = res.tolist()
return (
torch.stack(
torch.meshgrid(
[
torch.arange(res[0]),
torch.arange(res[1]),
torch.arange(res[2]),
],
indexing="ij",
),
dim=-1,
)
.long()
.to(device)
)
return torch.stack(
torch.meshgrid(
[
torch.arange(res[0], dtype=torch.long),
torch.arange(res[1], dtype=torch.long),
torch.arange(res[2], dtype=torch.long),
],
indexing="ij",
),
dim=-1,
).to(device)
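
Worth noting: `torch.arange` with integer bounds already produces int64, so passing `dtype=torch.long` mainly makes the intent explicit and lets the trailing `.long()` cast go. A quick check:

```python
import torch

assert torch.arange(3).dtype == torch.int64   # integer arange defaults to long

grid = torch.stack(
    torch.meshgrid([torch.arange(2, dtype=torch.long)] * 3, indexing="ij"),
    dim=-1,
)
assert grid.dtype == torch.long and grid.shape == (2, 2, 2, 3)
```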
......@@ -37,8 +37,8 @@ def pack_data(data: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:
), "mask must be with shape of (n_rays, n_samples)."
assert mask.dtype == torch.bool, "mask must be a boolean tensor."
packed_data = data[mask]
num_steps = mask.long().sum(dim=-1)
cum_steps = num_steps.cumsum(dim=0, dtype=torch.long)
num_steps = mask.sum(dim=-1, dtype=torch.int32)
cum_steps = num_steps.cumsum(dim=0, dtype=torch.int32)
packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
return packed_data, packed_info
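
`pack_data` now builds `packed_info` in int32 from the start. A hedged usage sketch (assuming `pack_data` is importable from the package top level):

```python
import torch
from nerfacc import pack_data  # assumed top-level export

data = torch.rand(3, 4, 8)            # (n_rays, n_samples, D)
mask = torch.rand(3, 4) > 0.5         # boolean per-sample mask
packed_data, packed_info = pack_data(data, mask)

assert packed_info.dtype == torch.int32          # (n_rays, 2): (start, count)
assert packed_data.shape == (int(mask.sum()), 8)
```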
......@@ -55,26 +55,26 @@ def pack_info(ray_indices: Tensor, n_rays: int = None) -> Tensor:
Returns:
packed_info: Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
See :func:`nerfacc.ray_marching` for details. IntTensor with shape (n_rays, 2).
"""
assert (
ray_indices.dim() == 1
), "ray_indices must be a 1D tensor with shape (n_samples)."
if ray_indices.is_cuda:
ray_indices = ray_indices.contiguous().int()
ray_indices = ray_indices
device = ray_indices.device
if n_rays is None:
n_rays = int(ray_indices.max()) + 1
# else:
# assert n_rays > ray_indices.max()
src = torch.ones_like(ray_indices)
src = torch.ones_like(ray_indices, dtype=torch.int)
num_steps = torch.zeros((n_rays,), device=device, dtype=torch.int)
num_steps.scatter_add_(0, ray_indices.long(), src)
num_steps.scatter_add_(0, ray_indices, src)
cum_steps = num_steps.cumsum(dim=0, dtype=torch.int)
packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1)
else:
raise NotImplementedError("Only support cuda inputs.")
return packed_info.int()
return packed_info
@torch.no_grad()
......@@ -86,7 +86,7 @@ def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
Args:
packed_info: Stores information on which samples belong to the same ray. \
See :func:`nerfacc.ray_marching` for details. Tensor with shape (n_rays, 2).
See :func:`nerfacc.ray_marching` for details. IntTensor with shape (n_rays, 2).
n_samples: Total number of samples.
Returns:
......@@ -115,10 +115,10 @@ def unpack_info(packed_info: Tensor, n_samples: int) -> Tensor:
packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be a 2D tensor with shape (n_rays, 2)."
if packed_info.is_cuda:
ray_indices = _C.unpack_info(packed_info.contiguous().int(), n_samples)
ray_indices = _C.unpack_info(packed_info.contiguous(), n_samples)
else:
raise NotImplementedError("Only support cuda inputs.")
return ray_indices.long()
return ray_indices
def unpack_data(
......@@ -173,7 +173,7 @@ class _UnpackData(torch.autograd.Function):
@staticmethod
def forward(ctx, packed_info: Tensor, data: Tensor, n_samples: int):
# shape of the data should be (all_samples, D)
packed_info = packed_info.contiguous().int()
packed_info = packed_info.contiguous()
data = data.contiguous()
if ctx.needs_input_grad[1]:
ctx.save_for_backward(packed_info)
......
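
Net effect on the pack/unpack pair: long indices in, int32 `packed_info` out, long indices back, with no casts in between. A hedged round-trip sketch (CUDA is required, per the `NotImplementedError` branch above; `unpack_info` is assumed to be a top-level export like `pack_info`):

```python
import torch
import nerfacc

ray_indices = torch.tensor([0, 2, 2, 2, 2], dtype=torch.long, device="cuda")

packed_info = nerfacc.pack_info(ray_indices, n_rays=3)       # IntTensor, (3, 2)
assert packed_info.dtype == torch.int32

recovered = nerfacc.unpack_info(packed_info, n_samples=ray_indices.numel())
assert recovered.dtype == torch.long
assert torch.equal(recovered, ray_indices)
```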
......@@ -193,13 +193,13 @@ def ray_marching(
if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients
if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
sigmas = sigma_fn(t_starts, t_ends, ray_indices)
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
elif alpha_fn is not None:
alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
alphas = alpha_fn(t_starts, t_ends, ray_indices)
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
......
......@@ -127,7 +127,7 @@ def proposal_sampling_with_filter(
for proposal_fn, n_samples in zip(proposal_sigma_fns, proposal_n_samples):
# compute weights for resampling
sigmas = proposal_fn(t_starts, t_ends, ray_indices.long())
sigmas = proposal_fn(t_starts, t_ends, ray_indices)
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
......@@ -152,7 +152,7 @@ def proposal_sampling_with_filter(
# Rerun the proposal function **with** gradients on filtered samples.
if proposal_require_grads:
with torch.enable_grad():
sigmas = proposal_fn(t_starts, t_ends, ray_indices.long())
sigmas = proposal_fn(t_starts, t_ends, ray_indices)
weights = render_weight_from_density(
t_starts, t_ends, sigmas, ray_indices=ray_indices
)
......@@ -168,7 +168,7 @@ def proposal_sampling_with_filter(
# last round filtering with sigma_fn
if (alpha_thre > 0 or early_stop_eps > 0) and (sigma_fn is not None):
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
sigmas = sigma_fn(t_starts, t_ends, ray_indices)
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
......
......@@ -78,7 +78,7 @@ def rendering(
# Query sigma/alpha and color with gradients
if rgb_sigma_fn is not None:
rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices.long())
rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices)
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
rgbs.shape
)
......@@ -94,7 +94,7 @@ def rendering(
n_rays=n_rays,
)
elif rgb_alpha_fn is not None:
rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices.long())
rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices)
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
rgbs.shape
)
......@@ -143,7 +143,7 @@ def accumulate_along_rays(
Args:
weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples,).
ray_indices: Ray index of each sample. IntTensor with shape (n_samples).
ray_indices: Ray index of each sample. LongTensor with shape (n_samples).
values: The values to be accmulated. Tensor with shape (n_samples, D). If \
None, the accumulated values are just weights. Default is None.
n_rays: Total number of rays. This will decide the shape of the ouputs. If \
......@@ -190,9 +190,10 @@ def accumulate_along_rays(
n_rays = int(ray_indices.max()) + 1
# assert n_rays > ray_indices.max()
ray_indices = ray_indices.int()
index = ray_indices[:, None].long().expand(-1, src.shape[-1])
outputs = torch.zeros((n_rays, src.shape[-1]), device=weights.device)
index = ray_indices[:, None].expand(-1, src.shape[-1])
outputs = torch.zeros(
(n_rays, src.shape[-1]), device=src.device, dtype=src.dtype
)
outputs.scatter_add_(0, index, src)
return outputs
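
The accumulation above is a per-ray scatter-add keyed by the (now int64) sample indices, with the output buffer following `src`'s dtype and device. A self-contained sketch of the same pattern:

```python
import torch

ray_indices = torch.tensor([0, 2, 2, 2, 2], dtype=torch.long)
weights = torch.tensor([0.4, 0.3, 0.8, 0.8, 0.5])
values = torch.rand(5, 3)

src = weights[:, None] * values                                   # (n_samples, D)
n_rays = int(ray_indices.max()) + 1
index = ray_indices[:, None].expand(-1, src.shape[-1])            # int64 index
outputs = torch.zeros((n_rays, src.shape[-1]), dtype=src.dtype, device=src.device)
outputs.scatter_add_(0, index, src)                               # (n_rays, D)
```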
......@@ -524,7 +525,7 @@ class _RenderingTransmittanceFromDensityCUB(torch.autograd.Function):
@staticmethod
def forward(ctx, ray_indices, t_starts, t_ends, sigmas):
ray_indices = ray_indices.contiguous().int()
ray_indices = ray_indices.contiguous()
t_starts = t_starts.contiguous()
t_ends = t_ends.contiguous()
sigmas = sigmas.contiguous()
......@@ -550,7 +551,7 @@ class _RenderingTransmittanceFromDensityNaive(torch.autograd.Function):
@staticmethod
def forward(ctx, packed_info, t_starts, t_ends, sigmas):
packed_info = packed_info.contiguous().int()
packed_info = packed_info.contiguous()
t_starts = t_starts.contiguous()
t_ends = t_ends.contiguous()
sigmas = sigmas.contiguous()
......@@ -576,7 +577,7 @@ class _RenderingTransmittanceFromAlphaCUB(torch.autograd.Function):
@staticmethod
def forward(ctx, ray_indices, alphas):
ray_indices = ray_indices.contiguous().int()
ray_indices = ray_indices.contiguous()
alphas = alphas.contiguous()
transmittance = _C.transmittance_from_alpha_forward_cub(
ray_indices, alphas
......@@ -600,7 +601,7 @@ class _RenderingTransmittanceFromAlphaNaive(torch.autograd.Function):
@staticmethod
def forward(ctx, packed_info, alphas):
packed_info = packed_info.contiguous().int()
packed_info = packed_info.contiguous()
alphas = alphas.contiguous()
transmittance = _C.transmittance_from_alpha_forward_naive(
packed_info, alphas
......@@ -624,7 +625,7 @@ class _RenderingWeightFromDensityNaive(torch.autograd.Function):
@staticmethod
def forward(ctx, packed_info, t_starts, t_ends, sigmas):
packed_info = packed_info.contiguous().int()
packed_info = packed_info.contiguous()
t_starts = t_starts.contiguous()
t_ends = t_ends.contiguous()
sigmas = sigmas.contiguous()
......@@ -652,7 +653,7 @@ class _RenderingWeightFromAlphaNaive(torch.autograd.Function):
@staticmethod
def forward(ctx, packed_info, alphas):
packed_info = packed_info.contiguous().int()
packed_info = packed_info.contiguous()
alphas = alphas.contiguous()
weights = _C.weight_from_alpha_forward_naive(packed_info, alphas)
if ctx.needs_input_grad[1]:
......
......@@ -97,7 +97,7 @@ def main():
cpu_t, cuda_t, cuda_bytes = profiler(fn)
print(f"{cpu_t:.2f} us, {cuda_t:.2f} us, {cuda_bytes / 1024 / 1024:.2f} MB")
packed_info = nerfacc.pack_info(ray_indices, n_rays=batch_size).int()
packed_info = nerfacc.pack_info(ray_indices, n_rays=batch_size)
fn = (
lambda: nerfacc.vol_rendering._RenderingDensity.apply(
packed_info, t_starts, t_ends, sigmas, 0
......
......@@ -39,7 +39,7 @@ def test_marching_with_grid():
far_plane=1.0,
render_step_size=1e-2,
)
ray_indices = ray_indices.long()
ray_indices = ray_indices
samples = (
rays_o[ray_indices] + rays_d[ray_indices] * (t_starts + t_ends) / 2.0
)
......
......@@ -18,7 +18,7 @@ eps = 1e-6
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_render_visibility():
ray_indices = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int32, device=device
[0, 2, 2, 2, 2], dtype=torch.int64, device=device
) # (samples,)
alphas = torch.tensor(
[0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
......@@ -48,7 +48,7 @@ def test_render_visibility():
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_render_weight_from_alpha():
ray_indices = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int32, device=device
[0, 2, 2, 2, 2], dtype=torch.int64, device=device
) # (samples,)
alphas = torch.tensor(
[0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
......@@ -71,7 +71,7 @@ def test_render_weight_from_alpha():
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_render_weight_from_density():
ray_indices = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int32, device=device
[0, 2, 2, 2, 2], dtype=torch.int64, device=device
) # (samples,)
sigmas = torch.rand(
(ray_indices.shape[0], 1), device=device
......@@ -92,7 +92,7 @@ def test_render_weight_from_density():
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_accumulate_along_rays():
ray_indices = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int32, device=device
[0, 2, 2, 2, 2], dtype=torch.int64, device=device
) # (n_rays,)
weights = torch.tensor(
[0.4, 0.3, 0.8, 0.8, 0.5], dtype=torch.float32, device=device
......@@ -116,7 +116,7 @@ def test_rendering():
return torch.hstack([t_starts] * 3), t_starts
ray_indices = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int32, device=device
[0, 2, 2, 2, 2], dtype=torch.int64, device=device
) # (samples,)
sigmas = torch.rand(
(ray_indices.shape[0], 1), device=device
......@@ -136,7 +136,7 @@ def test_rendering():
@pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
def test_grads():
ray_indices = torch.tensor(
[0, 2, 2, 2, 2], dtype=torch.int32, device=device
[0, 2, 2, 2, 2], dtype=torch.int64, device=device
) # (samples,)
packed_info = torch.tensor(
[[0, 1], [1, 0], [1, 4]], dtype=torch.int32, device=device
......