"tests/python/vscode:/vscode.git/clone" did not exist on "19096c6a8e7f1fb6f97bd2b43d1e9bde80a7a47f"
Commit 75a7b021 authored by Ruilong Li

cub scan added, passed test

parent 6591dd38
@@ -19,6 +19,12 @@ from .volrend import (
    render_weight_from_density,
    rendering,
)
from .scan_cub import (
    exclusive_prod_cub,
    exclusive_sum_cub,
    inclusive_prod_cub,
    inclusive_sum_cub,
)
__all__ = [
    "__version__",
@@ -26,6 +32,10 @@ __all__ = [
    "exclusive_prod",
    "inclusive_sum",
    "exclusive_sum",
"inclusive_prod_cub",
"exclusive_prod_cub",
"inclusive_sum_cub",
"exclusive_sum_cub",
"pack_info", "pack_info",
"render_visibility_from_alpha", "render_visibility_from_alpha",
"render_visibility_from_density", "render_visibility_from_density",
......
@@ -30,6 +30,13 @@ inclusive_prod_backward = _make_lazy_cuda_func("inclusive_prod_backward")
exclusive_prod_forward = _make_lazy_cuda_func("exclusive_prod_forward")
exclusive_prod_backward = _make_lazy_cuda_func("exclusive_prod_backward")
inclusive_sum_cub = _make_lazy_cuda_func("inclusive_sum_cub")
exclusive_sum_cub = _make_lazy_cuda_func("exclusive_sum_cub")
inclusive_prod_cub_forward = _make_lazy_cuda_func("inclusive_prod_cub_forward")
inclusive_prod_cub_backward = _make_lazy_cuda_func("inclusive_prod_cub_backward")
exclusive_prod_cub_forward = _make_lazy_cuda_func("exclusive_prod_cub_forward")
exclusive_prod_cub_backward = _make_lazy_cuda_func("exclusive_prod_cub_backward")
# pdf
importance_sampling = _make_lazy_cuda_func("importance_sampling")
searchsorted = _make_lazy_cuda_func("searchsorted")
...
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
* Modified from aten/src/ATen/cuda/cub_definitions.cuh in PyTorch.
*/
#pragma once
#include <cuda.h> // for CUDA_VERSION
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#include <cub/version.cuh>
#else
#define CUB_VERSION 0
#endif
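// CUB encodes its version as MAJOR * 100000 + MINOR * 100 + SUBMINOR,
// so the 101500 checked below corresponds to CUB 1.15.0.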
// CUB support for scan by key was added in CUB 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
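// CUB device-wide algorithms use a two-phase calling convention: a first call
// with a null temp-storage pointer only reports the required temp_storage_bytes,
// and a second call with real scratch space runs the algorithm. CUB_WRAPPER
// wraps that pattern, drawing the scratch buffer from PyTorch's CUDA caching
// allocator so allocations are recycled across calls.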
#define CUB_WRAPPER(func, ...) do {                                       \
    size_t temp_storage_bytes = 0;                                        \
    func(nullptr, temp_storage_bytes, __VA_ARGS__);                       \
    auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get();  \
    auto temp_storage = caching_allocator.allocate(temp_storage_bytes);   \
    func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);            \
    AT_CUDA_CHECK(cudaGetLastError());                                    \
} while (false)
\ No newline at end of file
@@ -38,6 +38,31 @@ torch::Tensor exclusive_prod_backward(
    torch::Tensor outputs,
    torch::Tensor grad_outputs);
torch::Tensor inclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward);
torch::Tensor exclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward);
torch::Tensor inclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs);
torch::Tensor inclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs);
torch::Tensor exclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs);
torch::Tensor exclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs);
// grid
std::vector<torch::Tensor> ray_aabb_intersect(
    const torch::Tensor rays_o, // [n_rays, 3]
@@ -106,6 +131,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    _REG_FUNC(exclusive_prod_forward);
    _REG_FUNC(exclusive_prod_backward);
    _REG_FUNC(inclusive_sum_cub);
    _REG_FUNC(exclusive_sum_cub);
    _REG_FUNC(inclusive_prod_cub_forward);
    _REG_FUNC(inclusive_prod_cub_backward);
    _REG_FUNC(exclusive_prod_cub_forward);
    _REG_FUNC(exclusive_prod_cub_backward);
    _REG_FUNC(ray_aabb_intersect);
    _REG_FUNC(traverse_grids);
    _REG_FUNC(searchsorted);
...
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
*/
#include <limits>
#include <stdexcept>

#include <thrust/iterator/reverse_iterator.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include "include/utils_cuda.cuh"
#include "include/utils.cub.cuh"
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <cub/cub.cuh>
struct Product
{
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a * b; }
};
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_sum_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
                "cub ExclusiveSumByKey does not support more than LONG_MAX elements");
    CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
                num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
                "cub InclusiveSumByKey does not support more than LONG_MAX elements");
    CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey, keys, input, output,
                num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_prod_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
                "cub ExclusiveScanByKey does not support more than LONG_MAX elements");
    // ExclusiveScanByKey takes an explicit initial value: 1.0f, the identity of the product.
    CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
                num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_prod_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
                "cub InclusiveScanByKey does not support more than LONG_MAX elements");
    CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey, keys, input, output, Product(),
                num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif
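For reference, the *ByKey scans above are segmented scans: the running aggregate restarts whenever the key changes between consecutive items (cub::Equality() compares neighbors only, so equal but non-adjacent keys still start new segments). A minimal CPU sketch of the inclusive-sum-by-key semantics in plain PyTorch; the helper name here is made up for illustration:

import torch

def inclusive_sum_by_key_ref(keys: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Sequential reference: restart the running sum at every key boundary.
    out = torch.empty_like(values)
    running = 0.0
    for i in range(values.numel()):
        if i > 0 and bool(keys[i] != keys[i - 1]):
            running = 0.0
        running += float(values[i])
        out[i] = running
    return out

keys = torch.tensor([0, 0, 0, 1, 1])
vals = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
print(inclusive_sum_by_key_ref(keys, vals))  # tensor([1., 3., 6., 4., 9.])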
torch::Tensor inclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward)
{
    DEVICE_GUARD(inputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor outputs = torch::empty_like(inputs);
    if (n_edges == 0) {
        return outputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    if (backward) {
        // Reversed keys and values turn the segmented prefix sum into a
        // segmented suffix sum, which is the adjoint of the forward pass.
        inclusive_sum_by_key(
            thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
            thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
            thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
            n_edges);
    } else {
        inclusive_sum_by_key(
            indices.data_ptr<long>(),
            inputs.data_ptr<float>(),
            outputs.data_ptr<float>(),
            n_edges);
    }
#else
    throw std::runtime_error("CUB functions are only supported with CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return outputs;
}
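The backward branch above relies on a flip identity: a per-segment suffix sum is a prefix sum of the reversed sequence, read back in reverse. A one-segment sketch in plain PyTorch, for illustration only:

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
suffix = torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0])
print(suffix)  # tensor([10., 9., 7., 4.]): entry i sums x[i:], the adjoint of cumsum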
torch::Tensor exclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward)
{
    DEVICE_GUARD(inputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor outputs = torch::empty_like(inputs);
    if (n_edges == 0) {
        return outputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    if (backward) {
        exclusive_sum_by_key(
            thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
            thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
            thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
            n_edges);
    } else {
        exclusive_sum_by_key(
            indices.data_ptr<long>(),
            inputs.data_ptr<float>(),
            outputs.data_ptr<float>(),
            n_edges);
    }
#else
    throw std::runtime_error("CUB functions are only supported with CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return outputs;
}
torch::Tensor inclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs)
{
    DEVICE_GUARD(inputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor outputs = torch::empty_like(inputs);
    if (n_edges == 0) {
        return outputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    inclusive_prod_by_key(
        indices.data_ptr<long>(),
        inputs.data_ptr<float>(),
        outputs.data_ptr<float>(),
        n_edges);
#else
    throw std::runtime_error("CUB functions are only supported with CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return outputs;
}
torch::Tensor inclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs)
{
    DEVICE_GUARD(grad_outputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    CHECK_INPUT(outputs);
    CHECK_INPUT(grad_outputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
    if (n_edges == 0) {
        return grad_inputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    inclusive_sum_by_key(
        thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
        thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
        thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
        n_edges);
    // FIXME: the grad is not correct when inputs are zero!!
    grad_inputs = grad_inputs / inputs.clamp_min(1e-10f);
#else
    throw std::runtime_error("CUB functions are only supported with CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return grad_inputs;
}
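A sketch of where this backward formula comes from, and why the FIXME is warranted. Within one segment, the inclusive products are $y_i = \prod_{j \le i} x_j$, so for $k \le i$ and $x_k \neq 0$:

$$\frac{\partial y_i}{\partial x_k} = \prod_{\substack{j \le i \\ j \neq k}} x_j = \frac{y_i}{x_k},
\qquad
\frac{\partial L}{\partial x_k} = \frac{1}{x_k} \sum_{i \ge k} \frac{\partial L}{\partial y_i}\, y_i .$$

The sum over $i \ge k$ is exactly the reverse segmented inclusive sum of grad_outputs * outputs computed above; the trailing division by inputs is undefined when an input is zero, and the clamp to 1e-10 only bounds the blow-up rather than recovering the true gradient.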
torch::Tensor exclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs)
{
    DEVICE_GUARD(inputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor outputs = torch::empty_like(inputs);
    if (n_edges == 0) {
        return outputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    exclusive_prod_by_key(
        indices.data_ptr<long>(),
        inputs.data_ptr<float>(),
        outputs.data_ptr<float>(),
        n_edges);
#else
    throw std::runtime_error("CUB functions are only supported with CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return outputs;
}
torch::Tensor exclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs)
{
    DEVICE_GUARD(grad_outputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    CHECK_INPUT(outputs);
    CHECK_INPUT(grad_outputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
    if (n_edges == 0) {
        return grad_inputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    exclusive_sum_by_key(
        thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
        thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
        thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
        n_edges);
    // FIXME: the grad is not correct when inputs are zero!!
    grad_inputs = grad_inputs / inputs.clamp_min(1e-10f);
#else
    throw std::runtime_error("CUB functions are only supported with CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return grad_inputs;
}
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import torch
from torch import Tensor
from . import cuda as _C
def inclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Inclusive sum on a flattened tensor with CUB."""
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _InclusiveSum.apply(indices, inputs)
    return outputs

def exclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Exclusive sum on a flattened tensor with CUB."""
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _ExclusiveSum.apply(indices, inputs)
    return outputs

def inclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Inclusive product on a flattened tensor with CUB."""
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _InclusiveProd.apply(indices, inputs)
    return outputs

def exclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Exclusive product on a flattened tensor with CUB."""
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _ExclusiveProd.apply(indices, inputs)
    return outputs

class _InclusiveSum(torch.autograd.Function):
    """Inclusive sum on a flattened tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.inclusive_sum_cub(indices, inputs, False)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        (indices,) = ctx.saved_tensors
        grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
        return None, grad_inputs

class _ExclusiveSum(torch.autograd.Function):
    """Exclusive sum on a flattened tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.exclusive_sum_cub(indices, inputs, False)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        (indices,) = ctx.saved_tensors
        grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
        return None, grad_inputs

class _InclusiveProd(torch.autograd.Function):
    """Inclusive product on a flattened tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.inclusive_prod_cub_forward(indices, inputs)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices, inputs, outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        indices, inputs, outputs = ctx.saved_tensors
        grad_inputs = _C.inclusive_prod_cub_backward(indices, inputs, outputs, grad_outputs)
        return None, grad_inputs

class _ExclusiveProd(torch.autograd.Function):
    """Exclusive product on a flattened tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.exclusive_prod_cub_forward(indices, inputs)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices, inputs, outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        indices, inputs, outputs = ctx.saved_tensors
        grad_inputs = _C.exclusive_prod_cub_backward(indices, inputs, outputs, grad_outputs)
        return None, grad_inputs
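A minimal usage sketch of the wrappers above, assuming a CUDA build with CUB >= 1.15; the ray layout is made up for illustration:

import torch
from nerfacc.scan_cub import exclusive_prod_cub, inclusive_sum_cub

# Hypothetical layout: three samples on ray 0, two samples on ray 1.
indices = torch.tensor([0, 0, 0, 1, 1], device="cuda", dtype=torch.long)
values = torch.rand(5, device="cuda", requires_grad=True)

sums = inclusive_sum_cub(values, indices)    # the scan restarts at each new ray
prods = exclusive_prod_cub(values, indices)  # e.g. transmittance-style products
(sums.sum() + prods.sum()).backward()        # gradients flow through the CUB scans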
@@ -7,6 +7,7 @@ device = "cuda:0"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_sum():
    from nerfacc.scan import inclusive_sum
    from nerfacc.scan_cub import inclusive_sum_cub

    torch.manual_seed(42)
@@ -28,14 +29,27 @@ def test_inclusive_sum():
    outputs2 = inclusive_sum(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = inclusive_sum_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)
    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_sum():
    from nerfacc.scan import exclusive_sum
    from nerfacc.scan_cub import exclusive_sum_cub

    torch.manual_seed(42)
@@ -57,16 +71,29 @@ def test_exclusive_sum():
    outputs2 = exclusive_sum(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = exclusive_sum_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    # TODO: check exclusive sum. numeric error?
    # print((outputs1 - outputs2).abs().max())  # 0.0002
    assert torch.allclose(outputs1, outputs2, atol=3e-4)
    assert torch.allclose(grad1, grad2)
    assert torch.allclose(outputs1, outputs3, atol=3e-4)
    assert torch.allclose(grad1, grad3)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_prod():
    from nerfacc.scan import inclusive_prod
    from nerfacc.scan_cub import inclusive_prod_cub

    torch.manual_seed(42)
@@ -88,14 +115,27 @@ def test_inclusive_prod():
    outputs2 = inclusive_prod(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = inclusive_prod_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)
    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_prod():
    from nerfacc.scan import exclusive_prod
    from nerfacc.scan_cub import exclusive_prod_cub

    torch.manual_seed(42)
@@ -117,15 +157,27 @@ def test_exclusive_prod():
    outputs2 = exclusive_prod(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = exclusive_prod_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    # TODO: check exclusive prod. numeric error?
    # print((outputs1 - outputs2).abs().max())
    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)
    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)

if __name__ == "__main__":
    test_inclusive_sum()
    test_exclusive_sum()
    test_inclusive_prod()
    test_exclusive_prod()
\ No newline at end of file