Commit ea07af8e authored by Ruilong Li

inclusive sum in one function

parent a0792e88
@@ -19,12 +19,6 @@ from .volrend import (
     render_weight_from_density,
     rendering,
 )
-from .scan_cub import (
-    exclusive_prod_cub,
-    exclusive_sum_cub,
-    inclusive_prod_cub,
-    inclusive_sum_cub,
-)

 __all__ = [
     "__version__",
...
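With `scan_cub` folded into `scan`, the `*_cub` helpers are no longer exported from the package; callers reach the same CUB-backed kernels by passing `indices=` to the unified functions. A minimal illustration of the new call site (tensor values are placeholders and a CUDA device is assumed):

```python
import torch

from nerfacc.scan import inclusive_sum

data = torch.rand((6,), device="cuda")
ray_indices = torch.tensor([0, 0, 0, 1, 1, 1], device="cuda")

# Previously: nerfacc.scan_cub.inclusive_sum_cub(data, ray_indices).
# Now the unified entry point dispatches to the CUB kernel internally.
out = inclusive_sum(data, indices=ray_indices)
```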
@@ -30,12 +30,17 @@ inclusive_prod_backward = _make_lazy_cuda_func("inclusive_prod_backward")
 exclusive_prod_forward = _make_lazy_cuda_func("exclusive_prod_forward")
 exclusive_prod_backward = _make_lazy_cuda_func("exclusive_prod_backward")

+is_cub_available = _make_lazy_cuda_func("is_cub_available")
 inclusive_sum_cub = _make_lazy_cuda_func("inclusive_sum_cub")
 exclusive_sum_cub = _make_lazy_cuda_func("exclusive_sum_cub")
 inclusive_prod_cub_forward = _make_lazy_cuda_func("inclusive_prod_cub_forward")
-inclusive_prod_cub_backward = _make_lazy_cuda_func("inclusive_prod_cub_backward")
+inclusive_prod_cub_backward = _make_lazy_cuda_func(
+    "inclusive_prod_cub_backward"
+)
 exclusive_prod_cub_forward = _make_lazy_cuda_func("exclusive_prod_cub_forward")
-exclusive_prod_cub_backward = _make_lazy_cuda_func("exclusive_prod_cub_backward")
+exclusive_prod_cub_backward = _make_lazy_cuda_func(
+    "exclusive_prod_cub_backward"
+)

 # pdf
 importance_sampling = _make_lazy_cuda_func("importance_sampling")
...
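For context, `_make_lazy_cuda_func` wraps each compiled symbol so the extension is only resolved on first use. The helper itself is not shown in this diff; the following is a rough sketch of the idea only, with names and module layout that are assumptions rather than the actual implementation:

```python
def _make_lazy_cuda_func(name: str):
    """Return a callable that resolves the CUDA extension lazily (sketch)."""

    def call_cuda(*args, **kwargs):
        # Hypothetical: import the compiled backend only when the wrapper is
        # first called, so importing nerfacc does not require a CUDA build.
        from ._backend import _C  # assumed module name

        return getattr(_C, name)(*args, **kwargs)

    return call_cuda
```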
@@ -38,6 +38,7 @@ torch::Tensor exclusive_prod_backward(
     torch::Tensor outputs,
     torch::Tensor grad_outputs);

+bool is_cub_available();
 torch::Tensor inclusive_sum_cub(
     torch::Tensor ray_indices,
     torch::Tensor inputs,
@@ -131,6 +132,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   _REG_FUNC(exclusive_prod_forward);
   _REG_FUNC(exclusive_prod_backward);
+  _REG_FUNC(is_cub_available);
   _REG_FUNC(inclusive_sum_cub);
   _REG_FUNC(exclusive_sum_cub);
   _REG_FUNC(inclusive_prod_cub_forward);
...
@@ -56,6 +56,14 @@ inline void inclusive_prod_by_key(
 }
 #endif

+bool is_cub_available() {
+#if CUB_SUPPORTS_SCAN_BY_KEY()
+  return true;
+#else
+  return false;
+#endif
+}
+
 torch::Tensor inclusive_sum_cub(
     torch::Tensor indices,
     torch::Tensor inputs,
...
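`CUB_SUPPORTS_SCAN_BY_KEY()` gates whether the by-key scan primitives exist in the bundled CUB version, and `is_cub_available()` simply exposes that compile-time answer to Python. What the by-key kernels compute is a segmented scan: the running sum (or product) restarts whenever the key in `ray_indices` changes. A slow pure-PyTorch reference of that semantics, for illustration only (the real kernels operate on the flattened GPU tensor directly):

```python
import torch


def inclusive_sum_by_key_reference(indices: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Naive segmented inclusive sum: restart the accumulator when the key changes."""
    outputs = torch.empty_like(inputs)
    running = inputs.new_zeros(())
    for i in range(inputs.shape[0]):
        if i > 0 and bool(indices[i] != indices[i - 1]):
            running = inputs.new_zeros(())
        running = running + inputs[i]
        outputs[i] = running
    return outputs


indices = torch.tensor([0, 0, 0, 1, 1, 1])
inputs = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
print(inclusive_sum_by_key_reference(indices, inputs))
# -> tensor([ 1.,  3.,  6.,  4.,  9., 15.])
```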
""" """
Copyright (c) 2022 Ruilong Li, UC Berkeley. Copyright (c) 2022 Ruilong Li, UC Berkeley.
""" """
import warnings
from typing import Optional from typing import Optional
import torch import torch
from torch import Tensor from torch import Tensor
from . import cuda as _C from . import cuda as _C
from .pack import pack_info
def inclusive_sum( def inclusive_sum(
inputs: Tensor, packed_info: Optional[Tensor] = None inputs: Tensor,
packed_info: Optional[Tensor] = None,
indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
"""Inclusive Sum that supports flattened tensor. """Inclusive Sum that supports flattened tensor.
...@@ -20,11 +24,12 @@ def inclusive_sum( ...@@ -20,11 +24,12 @@ def inclusive_sum(
Args: Args:
inputs: The tensor to be summed. Can be either a N-D tensor, or a flattened inputs: The tensor to be summed. Can be either a N-D tensor, or a flattened
tensor with `packed_info` specified. tensor with either `packed_info` or `indices` specified.
packed_info: A tensor of shape (n_rays, 2) that specifies the start and count packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
of each chunk in the flattened input tensor, with in total n_rays chunks. of each chunk in the flattened input tensor, with in total n_rays chunks.
If None, the input is assumed to be a N-D tensor and the sum is computed If None, the input is assumed to be a N-D tensor and the sum is computed
along the last dimension. Default is None. along the last dimension. Default is None.
indices: A flattened tensor with the same shape as `inputs`.
Returns: Returns:
The inclusive sum with the same shape as the input tensor. The inclusive sum with the same shape as the input tensor.
...@@ -39,22 +44,43 @@ def inclusive_sum( ...@@ -39,22 +44,43 @@ def inclusive_sum(
tensor([ 1., 3., 3., 7., 12., 6., 13., 21., 30.], device='cuda:0') tensor([ 1., 3., 3., 7., 12., 6., 13., 21., 30.], device='cuda:0')
""" """
if packed_info is None: if indices is not None and packed_info is not None:
# Batched inclusive sum on the last dimension. raise ValueError(
outputs = torch.cumsum(inputs, dim=-1) "Only one of `indices` and `packed_info` can be specified."
)
if indices is not None:
assert (
indices.dim() == 1 and indices.shape == inputs.shape
), "indices must be 1-D with the same shape as inputs."
if _C.is_cub_available():
# Use CUB if available
outputs = _InclusiveSumCUB.apply(indices, inputs)
else: else:
# Flattened inclusive sum. warnings.warn(
"Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
)
packed_info = pack_info(ray_indices=indices)
if packed_info is not None:
assert inputs.dim() == 1, "inputs must be flattened." assert inputs.dim() == 1, "inputs must be flattened."
assert ( assert (
packed_info.dim() == 2 and packed_info.shape[-1] == 2 packed_info.dim() == 2 and packed_info.shape[-1] == 2
), "packed_info must be 2-D with shape (B, 2)." ), "packed_info must be 2-D with shape (B, 2)."
chunk_starts, chunk_cnts = packed_info.unbind(dim=-1) chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
outputs = _InclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False) outputs = _InclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False)
if indices is None and packed_info is None:
# Batched inclusive sum on the last dimension.
outputs = torch.cumsum(inputs, dim=-1)
return outputs return outputs


 def exclusive_sum(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Exclusive Sum that supports flattened tensor.
@@ -62,11 +88,12 @@ def exclusive_sum(

     Args:
         inputs: The tensor to be summed. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
            of each chunk in the flattened input tensor, with in total n_rays chunks.
            If None, the input is assumed to be a N-D tensor and the sum is computed
            along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.

     Returns:
         The exclusive sum with the same shape as the input tensor.
@@ -81,27 +108,47 @@ def exclusive_sum(
        tensor([ 0., 1., 0., 3., 7., 0., 6., 13., 21.], device='cuda:0')
     """
-    if packed_info is None:
-        # Batched exclusive sum on the last dimension.
-        outputs = torch.cumsum(
-            torch.cat(
-                [torch.zeros_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
-            ),
-            dim=-1,
-        )
-    else:
-        # Flattened exclusive sum.
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _ExclusiveSumCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Consider passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+
+    if packed_info is not None:
         assert inputs.dim() == 1, "inputs must be flattened."
         assert (
             packed_info.dim() == 2 and packed_info.shape[-1] == 2
         ), "packed_info must be 2-D with shape (B, 2)."
         chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
         outputs = _ExclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False)
+
+    if indices is None and packed_info is None:
+        # Batched exclusive sum on the last dimension.
+        outputs = torch.cumsum(
+            torch.cat(
+                [torch.zeros_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
+            ),
+            dim=-1,
+        )
+
     return outputs


 def inclusive_prod(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Inclusive Product that supports flattened tensor.
@@ -111,11 +158,12 @@ def inclusive_prod(

     Args:
         inputs: The tensor to be producted. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
            of each chunk in the flattened input tensor, with in total n_rays chunks.
            If None, the input is assumed to be a N-D tensor and the product is computed
            along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.

     Returns:
         The inclusive product with the same shape as the input tensor.
@@ -130,22 +178,43 @@ def inclusive_prod(
        tensor([1., 2., 3., 12., 60., 6., 42., 336., 3024.], device='cuda:0')
     """
-    if packed_info is None:
-        # Batched inclusive product on the last dimension.
-        outputs = torch.cumprod(inputs, dim=-1)
-    else:
-        # Flattened inclusive product.
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _InclusiveProdCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Consider passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+
+    if packed_info is not None:
         assert inputs.dim() == 1, "inputs must be flattened."
         assert (
             packed_info.dim() == 2 and packed_info.shape[-1] == 2
         ), "packed_info must be 2-D with shape (B, 2)."
         chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
         outputs = _InclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
+
+    if indices is None and packed_info is None:
+        # Batched inclusive product on the last dimension.
+        outputs = torch.cumprod(inputs, dim=-1)
+
     return outputs


 def exclusive_prod(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Exclusive Product that supports flattened tensor.
@@ -153,11 +222,12 @@ def exclusive_prod(

     Args:
         inputs: The tensor to be producted. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
            of each chunk in the flattened input tensor, with in total n_rays chunks.
            If None, the input is assumed to be a N-D tensor and the product is computed
            along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.

     Returns:
         The exclusive product with the same shape as the input tensor.
@@ -173,16 +243,42 @@ def exclusive_prod(
        tensor([1., 1., 1., 3., 12., 1., 6., 42., 336.], device='cuda:0')
     """
-    if packed_info is None:
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _ExclusiveProdCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Consider passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+
+    if packed_info is not None:
+        assert inputs.dim() == 1, "inputs must be flattened."
+        assert (
+            packed_info.dim() == 2 and packed_info.shape[-1] == 2
+        ), "packed_info must be 2-D with shape (B, 2)."
+        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
+        outputs = _ExclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
+
+    if indices is None and packed_info is None:
         # Batched exclusive product on the last dimension.
         outputs = torch.cumprod(
             torch.cat(
                 [torch.ones_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
             ),
             dim=-1,
         )
-    else:
-        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
-        outputs = _ExclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
+
     return outputs
@@ -286,3 +382,87 @@ class _ExclusiveProd(torch.autograd.Function):
             chunk_starts, chunk_cnts, inputs, outputs, grad_outputs
         )
         return None, None, grad_inputs
+
+
+class _InclusiveSumCUB(torch.autograd.Function):
+    """Inclusive Sum on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.inclusive_sum_cub(indices, inputs, False)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        (indices,) = ctx.saved_tensors
+        grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
+        return None, grad_inputs
+
+
+class _ExclusiveSumCUB(torch.autograd.Function):
+    """Exclusive Sum on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.exclusive_sum_cub(indices, inputs, False)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        (indices,) = ctx.saved_tensors
+        grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
+        return None, grad_inputs
+
+
+class _InclusiveProdCUB(torch.autograd.Function):
+    """Inclusive Product on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.inclusive_prod_cub_forward(indices, inputs)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices, inputs, outputs)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        indices, inputs, outputs = ctx.saved_tensors
+        grad_inputs = _C.inclusive_prod_cub_backward(
+            indices, inputs, outputs, grad_outputs
+        )
+        return None, grad_inputs
+
+
+class _ExclusiveProdCUB(torch.autograd.Function):
+    """Exclusive Product on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.exclusive_prod_cub_forward(indices, inputs)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices, inputs, outputs)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        indices, inputs, outputs = ctx.saved_tensors
+        grad_inputs = _C.exclusive_prod_cub_backward(
+            indices, inputs, outputs, grad_outputs
+        )
+        return None, grad_inputs
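The backward passes above call the forward sum kernels again with the trailing flag flipped to `True`, which (assuming that flag selects a reversed, back-to-front scan, as the usage suggests) exploits the fact that the gradient of an inclusive segmented sum is the reversed inclusive segmented sum of the incoming gradient. A quick single-segment check of that identity with plain autograd:

```python
import torch

x = torch.arange(1.0, 6.0, requires_grad=True)
grad_out = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0])

# Gradient of an inclusive cumulative sum, computed by autograd ...
torch.cumsum(x, dim=0).backward(grad_out)

# ... equals the reversed cumulative sum of the output gradient.
expected = torch.flip(torch.cumsum(torch.flip(grad_out, dims=[0]), dim=0), dims=[0])
assert torch.allclose(x.grad, expected)
```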
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import torch
from torch import Tensor
from . import cuda as _C
def inclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
"""Inclusive Sum that supports flattened tensor with CUB."""
# Flattened inclusive sum.
assert inputs.dim() == 1, "inputs must be flattened."
assert (
indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
), "indices must be 1-D with the same shape as inputs."
outputs = _InclusiveSum.apply(indices, inputs)
return outputs
def exclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
"""Exclusive Sum that supports flattened tensor with CUB."""
# Flattened inclusive sum.
assert inputs.dim() == 1, "inputs must be flattened."
assert (
indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
), "indices must be 1-D with the same shape as inputs."
outputs = _ExclusiveSum.apply(indices, inputs)
return outputs
def inclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
"""Inclusive Prod that supports flattened tensor with CUB."""
# Flattened inclusive prod.
assert inputs.dim() == 1, "inputs must be flattened."
assert (
indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
), "indices must be 1-D with the same shape as inputs."
outputs = _InclusiveProd.apply(indices, inputs)
return outputs
def exclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
"""Exclusive Prod that supports flattened tensor with CUB."""
# Flattened inclusive prod.
assert inputs.dim() == 1, "inputs must be flattened."
assert (
indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
), "indices must be 1-D with the same shape as inputs."
outputs = _ExclusiveProd.apply(indices, inputs)
return outputs
class _InclusiveSum(torch.autograd.Function):
"""Inclusive Sum on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.inclusive_sum_cub(indices, inputs, False)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
(indices,) = ctx.saved_tensors
grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
return None, grad_inputs
class _ExclusiveSum(torch.autograd.Function):
"""Exclusive Sum on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.exclusive_sum_cub(indices, inputs, False)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
(indices,) = ctx.saved_tensors
grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
return None, grad_inputs
class _InclusiveProd(torch.autograd.Function):
"""Inclusive Product on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.inclusive_prod_cub_forward(indices, inputs)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices, inputs, outputs)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
indices, inputs, outputs = ctx.saved_tensors
grad_inputs = _C.inclusive_prod_cub_backward(indices, inputs, outputs, grad_outputs)
return None, grad_inputs
class _ExclusiveProd(torch.autograd.Function):
"""Exclusive Product on a Flattened Tensor with CUB."""
@staticmethod
def forward(ctx, indices, inputs):
indices = indices.contiguous()
inputs = inputs.contiguous()
outputs = _C.exclusive_prod_cub_forward(indices, inputs)
if ctx.needs_input_grad[1]:
ctx.save_for_backward(indices, inputs, outputs)
return outputs
@staticmethod
def backward(ctx, grad_outputs):
grad_outputs = grad_outputs.contiguous()
indices, inputs, outputs = ctx.saved_tensors
grad_inputs = _C.exclusive_prod_cub_backward(indices, inputs, outputs, grad_outputs)
return None, grad_inputs
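With everything routed through `nerfacc.scan`, the per-sample `indices` layout and the packed `(start, count)` layout are interchangeable at the call site; `pack_info` (used by the fallback path above) converts the former into the latter. A short usage sketch, assuming a CUDA device:

```python
import torch

from nerfacc.pack import pack_info
from nerfacc.scan import exclusive_sum, inclusive_sum

inputs = torch.rand((9,), device="cuda")
ray_indices = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], device="cuda")

# Per-sample ray indices: uses the CUB kernels when available, otherwise
# warns and falls back to the packed_info code path.
out_a = inclusive_sum(inputs, indices=ray_indices)

# Equivalent packed (start, count) description of the same three chunks.
packed_info = pack_info(ray_indices=ray_indices)
out_b = inclusive_sum(inputs, packed_info=packed_info)
assert torch.allclose(out_a, out_b)

# The exclusive variant shifts each chunk by one, starting from zero.
out_c = exclusive_sum(inputs, indices=ray_indices)
```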
@@ -7,7 +7,6 @@ device = "cuda:0"
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_inclusive_sum():
     from nerfacc.scan import inclusive_sum
-    from nerfacc.scan_cub import inclusive_sum_cub

     torch.manual_seed(42)
@@ -34,7 +33,7 @@ def test_inclusive_sum():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = inclusive_sum_cub(flatten_data, indices)
+    outputs3 = inclusive_sum(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
@@ -49,7 +48,6 @@ def test_inclusive_sum():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_exclusive_sum():
     from nerfacc.scan import exclusive_sum
-    from nerfacc.scan_cub import exclusive_sum_cub

     torch.manual_seed(42)
@@ -76,7 +74,7 @@ def test_exclusive_sum():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = exclusive_sum_cub(flatten_data, indices)
+    outputs3 = exclusive_sum(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
@@ -93,7 +91,6 @@ def test_exclusive_sum():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_inclusive_prod():
     from nerfacc.scan import inclusive_prod
-    from nerfacc.scan_cub import inclusive_prod_cub

     torch.manual_seed(42)
@@ -120,7 +117,7 @@ def test_inclusive_prod():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = inclusive_prod_cub(flatten_data, indices)
+    outputs3 = inclusive_prod(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
@@ -135,7 +132,6 @@ def test_inclusive_prod():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_exclusive_prod():
     from nerfacc.scan import exclusive_prod
-    from nerfacc.scan_cub import exclusive_prod_cub

     torch.manual_seed(42)
@@ -162,7 +158,7 @@ def test_exclusive_prod():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = exclusive_prod_cub(flatten_data, indices)
+    outputs3 = exclusive_prod(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
@@ -175,10 +171,11 @@ def test_exclusive_prod():
     assert torch.allclose(outputs1, outputs3)
     assert torch.allclose(grad1, grad3)


 def profile():
     import tqdm

     from nerfacc.scan import inclusive_sum
-    from nerfacc.scan_cub import inclusive_sum_cub

     torch.manual_seed(42)
@@ -202,7 +199,7 @@ def profile():
     indices = indices.flatten()
     torch.cuda.synchronize()
     for _ in tqdm.trange(2000):
-        outputs3 = inclusive_sum_cub(flatten_data, indices)
+        outputs3 = inclusive_sum(flatten_data, indices=indices)
         outputs3.sum().backward()
@@ -211,4 +208,4 @@ if __name__ == "__main__":
     test_exclusive_sum()
     test_inclusive_prod()
     test_exclusive_prod()
-    # profile()
+    profile()
\ No newline at end of file
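For a rough wall-clock comparison of the two code paths (in the spirit of `profile()` above), CUDA events are sufficient. A hedged sketch, assuming a CUDA device; the sizes are arbitrary:

```python
import torch

from nerfacc.pack import pack_info
from nerfacc.scan import inclusive_sum


def time_ms(fn, iters: int = 200) -> float:
    """Average milliseconds per call, measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


data = torch.rand((8192, 64), device="cuda").flatten()
indices = torch.arange(8192, device="cuda").repeat_interleave(64)
packed_info = pack_info(ray_indices=indices)

print("indices path    :", time_ms(lambda: inclusive_sum(data, indices=indices)), "ms")
print("packed_info path:", time_ms(lambda: inclusive_sum(data, packed_info=packed_info)), "ms")
```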