Commit ea07af8e authored by Ruilong Li

inclusive sum in one function

parent a0792e88
@@ -19,12 +19,6 @@ from .volrend import (
render_weight_from_density,
rendering,
)
__all__ = [
"__version__",
......
@@ -30,12 +30,17 @@ inclusive_prod_backward = _make_lazy_cuda_func("inclusive_prod_backward")
exclusive_prod_forward = _make_lazy_cuda_func("exclusive_prod_forward")
exclusive_prod_backward = _make_lazy_cuda_func("exclusive_prod_backward")
is_cub_available = _make_lazy_cuda_func("is_cub_available")
inclusive_sum_cub = _make_lazy_cuda_func("inclusive_sum_cub")
exclusive_sum_cub = _make_lazy_cuda_func("exclusive_sum_cub")
inclusive_prod_cub_forward = _make_lazy_cuda_func("inclusive_prod_cub_forward")
inclusive_prod_cub_backward = _make_lazy_cuda_func(
    "inclusive_prod_cub_backward"
)
exclusive_prod_cub_forward = _make_lazy_cuda_func("exclusive_prod_cub_forward")
exclusive_prod_cub_backward = _make_lazy_cuda_func(
    "exclusive_prod_cub_backward"
)
# pdf
importance_sampling = _make_lazy_cuda_func("importance_sampling")
......
@@ -38,6 +38,7 @@ torch::Tensor exclusive_prod_backward(
torch::Tensor outputs,
torch::Tensor grad_outputs);
bool is_cub_available();
torch::Tensor inclusive_sum_cub(
torch::Tensor ray_indices,
torch::Tensor inputs,
@@ -131,6 +132,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
_REG_FUNC(exclusive_prod_forward);
_REG_FUNC(exclusive_prod_backward);
_REG_FUNC(is_cub_available);
_REG_FUNC(inclusive_sum_cub);
_REG_FUNC(exclusive_sum_cub);
_REG_FUNC(inclusive_prod_cub_forward);
......
@@ -56,6 +56,14 @@ inline void inclusive_prod_by_key(
}
#endif
bool is_cub_available() {
#if CUB_SUPPORTS_SCAN_BY_KEY()
return true;
#else
return false;
#endif
}
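For reference, `CUB_SUPPORTS_SCAN_BY_KEY()` presumably reflects whether the bundled CUB ships the `DeviceScan` by-key primitives, which only appear in recent CUDA toolkits (roughly 11.6 and later). Exposing the flag lets the Python layer pick a fallback at runtime; a minimal sketch of that check, assuming a built CUDA extension:

import nerfacc.cuda as _C

# True only if the extension was compiled with CUB scan-by-key support;
# this gates the fast `indices=` path in nerfacc.scan.
print(_C.is_cub_available())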
torch::Tensor inclusive_sum_cub(
torch::Tensor indices,
torch::Tensor inputs,
......
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import warnings
from typing import Optional
import torch
from torch import Tensor
from . import cuda as _C
from .pack import pack_info
def inclusive_sum(
inputs: Tensor,
packed_info: Optional[Tensor] = None,
indices: Optional[Tensor] = None,
) -> Tensor:
"""Inclusive Sum that supports flattened tensor.
@@ -20,11 +24,12 @@ def inclusive_sum(
    Args:
        inputs: The tensor to be summed. Can be either an N-D tensor, or a flattened
            tensor with either `packed_info` or `indices` specified.
        packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
            of each chunk in the flattened input tensor, with n_rays chunks in total.
            If None, the input is assumed to be an N-D tensor and the sum is computed
            along the last dimension. Default is None.
        indices: A flattened tensor with the same shape as `inputs`, where each
            element holds the chunk index of the corresponding input element.
            Default is None.

    Returns:
        The inclusive sum with the same shape as the input tensor.
@@ -39,22 +44,43 @@ def inclusive_sum(
tensor([ 1., 3., 3., 7., 12., 6., 13., 21., 30.], device='cuda:0')
"""
    if indices is not None and packed_info is not None:
        raise ValueError(
            "Only one of `indices` and `packed_info` can be specified."
        )
    if indices is not None:
        assert (
            indices.dim() == 1 and indices.shape == inputs.shape
        ), "indices must be 1-D with the same shape as inputs."
        if _C.is_cub_available():
            # Use CUB if available
            outputs = _InclusiveSumCUB.apply(indices, inputs)
        else:
            warnings.warn(
                "Passing in `indices` without CUB available is slow. Consider passing in `packed_info` instead."
            )
            packed_info = pack_info(ray_indices=indices)
    if packed_info is not None:
        # Flattened inclusive sum.
        assert inputs.dim() == 1, "inputs must be flattened."
        assert (
            packed_info.dim() == 2 and packed_info.shape[-1] == 2
        ), "packed_info must be 2-D with shape (B, 2)."
        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
        outputs = _InclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False)
    if indices is None and packed_info is None:
        # Batched inclusive sum on the last dimension.
        outputs = torch.cumsum(inputs, dim=-1)
    return outputs
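To make the two flattened calling conventions concrete, here is a minimal usage sketch; the chunk layout and expected output are taken from the docstring example above, and a CUDA build is assumed:

import torch
from nerfacc.scan import inclusive_sum

# Three chunks of lengths 2, 3, 4 flattened into one 1-D tensor.
inputs = torch.arange(1, 10, dtype=torch.float32, device="cuda:0")
# Two equivalent descriptions of the same chunk layout:
packed_info = torch.tensor([[0, 2], [2, 3], [5, 4]], device="cuda:0")
indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda:0")

out_packed = inclusive_sum(inputs, packed_info=packed_info)
out_indexed = inclusive_sum(inputs, indices=indices)
# Both yield: [1., 3., 3., 7., 12., 6., 13., 21., 30.]
assert torch.allclose(out_packed, out_indexed)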
def exclusive_sum(
inputs: Tensor,
packed_info: Optional[Tensor] = None,
indices: Optional[Tensor] = None,
) -> Tensor:
"""Exclusive Sum that supports flattened tensor.
@@ -62,11 +88,12 @@ def exclusive_sum(
    Args:
        inputs: The tensor to be summed. Can be either an N-D tensor, or a flattened
            tensor with either `packed_info` or `indices` specified.
        packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
            of each chunk in the flattened input tensor, with n_rays chunks in total.
            If None, the input is assumed to be an N-D tensor and the sum is computed
            along the last dimension. Default is None.
        indices: A flattened tensor with the same shape as `inputs`, where each
            element holds the chunk index of the corresponding input element.
            Default is None.

    Returns:
        The exclusive sum with the same shape as the input tensor.
@@ -81,27 +108,47 @@ def exclusive_sum(
tensor([ 0., 1., 0., 3., 7., 0., 6., 13., 21.], device='cuda:0')
"""
    if indices is not None and packed_info is not None:
        raise ValueError(
            "Only one of `indices` and `packed_info` can be specified."
        )
    if indices is not None:
        assert (
            indices.dim() == 1 and indices.shape == inputs.shape
        ), "indices must be 1-D with the same shape as inputs."
        if _C.is_cub_available():
            # Use CUB if available
            outputs = _ExclusiveSumCUB.apply(indices, inputs)
        else:
            warnings.warn(
                "Passing in `indices` without CUB available is slow. Consider passing in `packed_info` instead."
            )
            packed_info = pack_info(ray_indices=indices)
    if packed_info is not None:
        # Flattened exclusive sum.
        assert inputs.dim() == 1, "inputs must be flattened."
        assert (
            packed_info.dim() == 2 and packed_info.shape[-1] == 2
        ), "packed_info must be 2-D with shape (B, 2)."
        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
        outputs = _ExclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False)
    if indices is None and packed_info is None:
        # Batched exclusive sum on the last dimension.
        outputs = torch.cumsum(
            torch.cat(
                [torch.zeros_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
            ),
            dim=-1,
        )
    return outputs
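Within each chunk the exclusive scan is just the inclusive scan shifted by one element, so the two sums can cross-check each other. A small sketch reusing the layout above:

import torch
from nerfacc.scan import exclusive_sum, inclusive_sum

inputs = torch.arange(1, 10, dtype=torch.float32, device="cuda:0")
indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda:0")

# Per chunk: exclusive[i] == inclusive[i] - inputs[i].
excl = exclusive_sum(inputs, indices=indices)
incl = inclusive_sum(inputs, indices=indices)
assert torch.allclose(excl, incl - inputs)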
def inclusive_prod(
inputs: Tensor,
packed_info: Optional[Tensor] = None,
indices: Optional[Tensor] = None,
) -> Tensor:
"""Inclusive Product that supports flattened tensor.
@@ -111,11 +158,12 @@ def inclusive_prod(
    Args:
        inputs: The tensor to take the product of. Can be either an N-D tensor, or a
            flattened tensor with either `packed_info` or `indices` specified.
        packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
            of each chunk in the flattened input tensor, with n_rays chunks in total.
            If None, the input is assumed to be an N-D tensor and the product is
            computed along the last dimension. Default is None.
        indices: A flattened tensor with the same shape as `inputs`, where each
            element holds the chunk index of the corresponding input element.
            Default is None.

    Returns:
        The inclusive product with the same shape as the input tensor.
@@ -130,22 +178,43 @@ def inclusive_prod(
tensor([1., 2., 3., 12., 60., 6., 42., 336., 3024.], device='cuda:0')
"""
    if indices is not None and packed_info is not None:
        raise ValueError(
            "Only one of `indices` and `packed_info` can be specified."
        )
    if indices is not None:
        assert (
            indices.dim() == 1 and indices.shape == inputs.shape
        ), "indices must be 1-D with the same shape as inputs."
        if _C.is_cub_available():
            # Use CUB if available
            outputs = _InclusiveProdCUB.apply(indices, inputs)
        else:
            warnings.warn(
                "Passing in `indices` without CUB available is slow. Consider passing in `packed_info` instead."
            )
            packed_info = pack_info(ray_indices=indices)
    if packed_info is not None:
        # Flattened inclusive product.
        assert inputs.dim() == 1, "inputs must be flattened."
        assert (
            packed_info.dim() == 2 and packed_info.shape[-1] == 2
        ), "packed_info must be 2-D with shape (B, 2)."
        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
        outputs = _InclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
    if indices is None and packed_info is None:
        # Batched inclusive product on the last dimension.
        outputs = torch.cumprod(inputs, dim=-1)
    return outputs
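The flattened product path can be validated against an independent per-chunk torch.cumprod; a short sketch:

import torch
from nerfacc.scan import inclusive_prod

inputs = torch.rand(9, device="cuda:0")
indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda:0")

out = inclusive_prod(inputs, indices=indices)
# Reference: run torch.cumprod separately over each chunk.
ref = torch.cat(
    [torch.cumprod(inputs[indices == i], dim=0) for i in range(3)]
)
assert torch.allclose(out, ref)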
def exclusive_prod(
inputs: Tensor,
packed_info: Optional[Tensor] = None,
indices: Optional[Tensor] = None,
) -> Tensor:
"""Exclusive Product that supports flattened tensor.
@@ -153,11 +222,12 @@ def exclusive_prod(
    Args:
        inputs: The tensor to take the product of. Can be either an N-D tensor, or a
            flattened tensor with either `packed_info` or `indices` specified.
        packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
            of each chunk in the flattened input tensor, with n_rays chunks in total.
            If None, the input is assumed to be an N-D tensor and the product is
            computed along the last dimension. Default is None.
        indices: A flattened tensor with the same shape as `inputs`, where each
            element holds the chunk index of the corresponding input element.
            Default is None.

    Returns:
        The exclusive product with the same shape as the input tensor.
@@ -173,16 +243,42 @@ def exclusive_prod(
tensor([1., 1., 1., 3., 12., 1., 6., 42., 336.], device='cuda:0')
"""
    if indices is not None and packed_info is not None:
        raise ValueError(
            "Only one of `indices` and `packed_info` can be specified."
        )
    if indices is not None:
        assert (
            indices.dim() == 1 and indices.shape == inputs.shape
        ), "indices must be 1-D with the same shape as inputs."
        if _C.is_cub_available():
            # Use CUB if available
            outputs = _ExclusiveProdCUB.apply(indices, inputs)
        else:
            warnings.warn(
                "Passing in `indices` without CUB available is slow. Consider passing in `packed_info` instead."
            )
            packed_info = pack_info(ray_indices=indices)
    if packed_info is not None:
        # Flattened exclusive product.
        assert inputs.dim() == 1, "inputs must be flattened."
        assert (
            packed_info.dim() == 2 and packed_info.shape[-1] == 2
        ), "packed_info must be 2-D with shape (B, 2)."
        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
        outputs = _ExclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
    if indices is None and packed_info is None:
        # Batched exclusive product on the last dimension.
        outputs = torch.cumprod(
            torch.cat(
                [torch.ones_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
            ),
            dim=-1,
        )
    return outputs
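In volume rendering (compare the volrend imports at the top of the package), this exclusive product is the scan behind per-ray transmittance: T_i = prod_{j < i} (1 - alpha_j) over the samples of one ray. A hedged sketch with hypothetical flattened samples from three rays:

import torch
from nerfacc.scan import exclusive_prod

alphas = torch.rand(9, device="cuda:0")  # hypothetical per-sample opacities
ray_indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda:0")

# The first sample of every ray gets T = 1, the identity of the exclusive scan.
trans = exclusive_prod(1.0 - alphas, indices=ray_indices)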
@@ -286,3 +382,87 @@ class _ExclusiveProd(torch.autograd.Function):
chunk_starts, chunk_cnts, inputs, outputs, grad_outputs
)
return None, None, grad_inputs
class _InclusiveSumCUB(torch.autograd.Function):
"""Inclusive Sum on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.inclusive_sum_cub(indices, inputs, False)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
(indices,) = ctx.saved_tensors
grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
return None, grad_inputs
class _ExclusiveSumCUB(torch.autograd.Function):
"""Exclusive Sum on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.exclusive_sum_cub(indices, inputs, False)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
(indices,) = ctx.saved_tensors
grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
return None, grad_inputs
class _InclusiveProdCUB(torch.autograd.Function):
"""Inclusive Product on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.inclusive_prod_cub_forward(indices, inputs)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices, inputs, outputs)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
indices, inputs, outputs = ctx.saved_tensors
grad_inputs = _C.inclusive_prod_cub_backward(
indices, inputs, outputs, grad_outputs
)
return None, grad_inputs
class _ExclusiveProdCUB(torch.autograd.Function):
"""Exclusive Product on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.exclusive_prod_cub_forward(indices, inputs)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices, inputs, outputs)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
indices, inputs, outputs = ctx.saved_tensors
grad_inputs = _C.exclusive_prod_cub_backward(
indices, inputs, outputs, grad_outputs
)
return None, grad_inputs
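Note that the sum classes reuse one kernel for forward and backward, with the trailing boolean presumably selecting a reversed scan: the gradient of an inclusive sum over a chunk is the reversed inclusive sum of the incoming gradients (and symmetrically for the exclusive case). A quick end-to-end check is to compare the CUB path against the packed path, mirroring the updated tests below:

import torch
from nerfacc.pack import pack_info
from nerfacc.scan import inclusive_sum

inputs = torch.rand(9, device="cuda:0", requires_grad=True)
indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda:0")

# CUB path (taken when _C.is_cub_available() returns True).
out_cub = inclusive_sum(inputs, indices=indices)
out_cub.sum().backward()
grad_cub, inputs.grad = inputs.grad.clone(), None

# Packed path over the same layout.
out_packed = inclusive_sum(inputs, packed_info=pack_info(ray_indices=indices))
out_packed.sum().backward()

assert torch.allclose(out_cub, out_packed)
assert torch.allclose(grad_cub, inputs.grad)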
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import torch
from torch import Tensor
from . import cuda as _C
def inclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
"""Inclusive Sum that supports flattened tensor with CUB."""
# Flattened inclusive sum.
assert inputs.dim() == 1, "inputs must be flattened."
assert (
indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
), "indices must be 1-D with the same shape as inputs."
outputs = _InclusiveSum.apply(indices, inputs)
return outputs
def exclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
"""Exclusive Sum that supports flattened tensor with CUB."""
    # Flattened exclusive sum.
assert inputs.dim() == 1, "inputs must be flattened."
assert (
indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
), "indices must be 1-D with the same shape as inputs."
outputs = _ExclusiveSum.apply(indices, inputs)
return outputs
def inclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
"""Inclusive Prod that supports flattened tensor with CUB."""
# Flattened inclusive prod.
assert inputs.dim() == 1, "inputs must be flattened."
assert (
indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
), "indices must be 1-D with the same shape as inputs."
outputs = _InclusiveProd.apply(indices, inputs)
return outputs
def exclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
"""Exclusive Prod that supports flattened tensor with CUB."""
    # Flattened exclusive prod.
assert inputs.dim() == 1, "inputs must be flattened."
assert (
indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
), "indices must be 1-D with the same shape as inputs."
outputs = _ExclusiveProd.apply(indices, inputs)
return outputs
class _InclusiveSum(torch.autograd.Function):
"""Inclusive Sum on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.inclusive_sum_cub(indices, inputs, False)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
(indices,) = ctx.saved_tensors
grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
return None, grad_inputs
class _ExclusiveSum(torch.autograd.Function):
"""Exclusive Sum on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.exclusive_sum_cub(indices, inputs, False)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
(indices,) = ctx.saved_tensors
grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
return None, grad_inputs
class _InclusiveProd(torch.autograd.Function):
"""Inclusive Product on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.inclusive_prod_cub_forward(indices, inputs)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices, inputs, outputs)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
indices, inputs, outputs = ctx.saved_tensors
grad_inputs = _C.inclusive_prod_cub_backward(indices, inputs, outputs, grad_outputs)
return None, grad_inputs
class _ExclusiveProd(torch.autograd.Function):
"""Exclusive Product on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.exclusive_prod_cub_forward(indices, inputs)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices, inputs, outputs)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
indices, inputs, outputs = ctx.saved_tensors
grad_inputs = _C.exclusive_prod_cub_backward(indices, inputs, outputs, grad_outputs)
return None, grad_inputs
@@ -7,7 +7,6 @@ device = "cuda:0"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_sum():
from nerfacc.scan import inclusive_sum
torch.manual_seed(42)
@@ -34,7 +33,7 @@ def test_inclusive_sum():
indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
indices = indices.repeat_interleave(data.shape[1])
indices = indices.flatten()
outputs3 = inclusive_sum(flatten_data, indices=indices)
outputs3.sum().backward()
grad3 = data.grad.clone()
data.grad.zero_()
@@ -49,7 +48,6 @@ def test_inclusive_sum():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_sum():
from nerfacc.scan import exclusive_sum
torch.manual_seed(42)
@@ -76,7 +74,7 @@ def test_exclusive_sum():
indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
indices = indices.repeat_interleave(data.shape[1])
indices = indices.flatten()
outputs3 = exclusive_sum(flatten_data, indices=indices)
outputs3.sum().backward()
grad3 = data.grad.clone()
data.grad.zero_()
@@ -93,7 +91,6 @@ def test_exclusive_sum():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_prod():
from nerfacc.scan import inclusive_prod
torch.manual_seed(42)
@@ -120,7 +117,7 @@ def test_inclusive_prod():
indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
indices = indices.repeat_interleave(data.shape[1])
indices = indices.flatten()
outputs3 = inclusive_prod(flatten_data, indices=indices)
outputs3.sum().backward()
grad3 = data.grad.clone()
data.grad.zero_()
@@ -135,7 +132,6 @@ def test_inclusive_prod():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_prod():
from nerfacc.scan import exclusive_prod
torch.manual_seed(42)
@@ -162,7 +158,7 @@ def test_exclusive_prod():
indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
indices = indices.repeat_interleave(data.shape[1])
indices = indices.flatten()
outputs3 = exclusive_prod(flatten_data, indices=indices)
outputs3.sum().backward()
grad3 = data.grad.clone()
data.grad.zero_()
@@ -175,10 +171,11 @@ def test_exclusive_prod():
assert torch.allclose(outputs1, outputs3)
assert torch.allclose(grad1, grad3)
def profile():
import tqdm
from nerfacc.scan import inclusive_sum
torch.manual_seed(42)
@@ -202,7 +199,7 @@ def profile():
indices = indices.flatten()
torch.cuda.synchronize()
for _ in tqdm.trange(2000):
outputs3 = inclusive_sum(flatten_data, indices=indices)
outputs3.sum().backward()
@@ -211,4 +208,4 @@ if __name__ == "__main__":
test_exclusive_sum()
test_inclusive_prod()
test_exclusive_prod()
    profile()