OpenDAS / nerfacc — Commit 75a7b021

Authored Nov 23, 2023 by Ruilong Li
Commit message: "cub scan added, past test"
Parent: 6591dd38

Showing 7 changed files with 541 additions and 1 deletion:
nerfacc/__init__.py                        +10    -0
nerfacc/cuda/__init__.py                    +7    -0
nerfacc/cuda/csrc/include/utils.cub.cuh    +32    -0
nerfacc/cuda/csrc/nerfacc.cpp              +32    -0
nerfacc/cuda/csrc/scan_cub.cu             +279    -0
nerfacc/scan_cub.py                       +128    -0
tests/test_scan.py                         +53    -1
nerfacc/__init__.py

...
@@ -19,6 +19,12 @@ from .volrend import (
    render_weight_from_density,
    rendering,
)
from .scan_cub import (
    exclusive_prod_cub,
    exclusive_sum_cub,
    inclusive_prod_cub,
    inclusive_sum_cub,
)

__all__ = [
    "__version__",
...
@@ -26,6 +32,10 @@ __all__ = [
    "exclusive_prod",
    "inclusive_sum",
    "exclusive_sum",
    "inclusive_prod_cub",
    "exclusive_prod_cub",
    "inclusive_sum_cub",
    "exclusive_sum_cub",
    "pack_info",
    "render_visibility_from_alpha",
    "render_visibility_from_density",
...
nerfacc/cuda/__init__.py

...
@@ -30,6 +30,13 @@ inclusive_prod_backward = _make_lazy_cuda_func("inclusive_prod_backward")
exclusive_prod_forward = _make_lazy_cuda_func("exclusive_prod_forward")
exclusive_prod_backward = _make_lazy_cuda_func("exclusive_prod_backward")

inclusive_sum_cub = _make_lazy_cuda_func("inclusive_sum_cub")
exclusive_sum_cub = _make_lazy_cuda_func("exclusive_sum_cub")
inclusive_prod_cub_forward = _make_lazy_cuda_func("inclusive_prod_cub_forward")
inclusive_prod_cub_backward = _make_lazy_cuda_func("inclusive_prod_cub_backward")
exclusive_prod_cub_forward = _make_lazy_cuda_func("exclusive_prod_cub_forward")
exclusive_prod_cub_backward = _make_lazy_cuda_func("exclusive_prod_cub_backward")

# pdf
importance_sampling = _make_lazy_cuda_func("importance_sampling")
searchsorted = _make_lazy_cuda_func("searchsorted")
...
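For context: `_make_lazy_cuda_func` defers loading the compiled extension until a wrapped function is first called, so `import nerfacc` works even before the CUDA sources are built. A minimal sketch of the pattern, assuming the extension is exposed through a hypothetical `_backend` module (the actual module layout may differ):

def _make_lazy_cuda_func(name):
    def call_cuda(*args, **kwargs):
        # Deferred import: the compiled extension is loaded on first use.
        from ._backend import _C  # hypothetical backend module
        return getattr(_C, name)(*args, **kwargs)
    return call_cuda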
nerfacc/cuda/csrc/include/utils.cub.cuh (new file)
/*
* Copyright (c) 2022 Ruilong Li, UC Berkeley.
* Modified from aten/src/ATen/cuda/cub_definitions.cuh in PyTorch.
*/
#pragma once
#include <cuda.h> // for CUDA_VERSION
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
#include <cub/version.cuh>
#else
#define CUB_VERSION 0
#endif
// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
#define CUB_WRAPPER(func, ...) do { \
size_t temp_storage_bytes = 0; \
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
auto temp_storage = caching_allocator.allocate(temp_storage_bytes); \
func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__); \
AT_CUDA_CHECK(cudaGetLastError()); \
} while (false)
\ No newline at end of file
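A note on the CUB_WRAPPER macro above: CUB's device-wide algorithms use a two-phase calling convention. The first call passes a null temporary-storage pointer and only reports the scratch-space requirement in temp_storage_bytes; the macro then allocates that many bytes from PyTorch's CUDA caching allocator (so scratch buffers are recycled rather than hitting cudaMalloc on every call) and invokes the algorithm a second time to do the real work, checking for launch errors afterwards.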
nerfacc/cuda/csrc/nerfacc.cpp

...
@@ -38,6 +38,31 @@ torch::Tensor exclusive_prod_backward(
    torch::Tensor outputs,
    torch::Tensor grad_outputs);

torch::Tensor inclusive_sum_cub(
    torch::Tensor ray_indices,
    torch::Tensor inputs,
    bool backward);
torch::Tensor exclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward);
torch::Tensor inclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs);
torch::Tensor inclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs);
torch::Tensor exclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs);
torch::Tensor exclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs);

// grid
std::vector<torch::Tensor> ray_aabb_intersect(
    const torch::Tensor rays_o, // [n_rays, 3]
...
@@ -106,6 +131,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    _REG_FUNC(exclusive_prod_forward);
    _REG_FUNC(exclusive_prod_backward);
    _REG_FUNC(inclusive_sum_cub);
    _REG_FUNC(exclusive_sum_cub);
    _REG_FUNC(inclusive_prod_cub_forward);
    _REG_FUNC(inclusive_prod_cub_backward);
    _REG_FUNC(exclusive_prod_cub_forward);
    _REG_FUNC(exclusive_prod_cub_backward);

    _REG_FUNC(ray_aabb_intersect);
    _REG_FUNC(traverse_grids);
    _REG_FUNC(searchsorted);
...
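(`_REG_FUNC` is presumably a registration macro defined earlier in nerfacc.cpp, expanding to something like m.def(#func, &func) so that each C++ function is exposed under its own name; its definition is outside this diff.)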
nerfacc/cuda/csrc/scan_cub.cu (new file)

/*
 * Copyright (c) 2022 Ruilong Li, UC Berkeley.
 */

#include <thrust/iterator/reverse_iterator.h>

#include "include/utils_cuda.cuh"
#include "include/utils.cub.cuh"

#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <cub/cub.cuh>

struct Product
{
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a * b;
    }
};

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_sum_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
                "cub ExclusiveSumByKey does not support more than LONG_MAX elements");
    CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
                num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_sum_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
                "cub InclusiveSumByKey does not support more than LONG_MAX elements");
    CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey, keys, input, output,
                num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_prod_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
                "cub ExclusiveScanByKey does not support more than LONG_MAX elements");
    CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output,
                Product(), 1.0f, num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void inclusive_prod_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
    TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
                "cub InclusiveScanByKey does not support more than LONG_MAX elements");
    CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey, keys, input, output,
                Product(), num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif

torch::Tensor inclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward)
{
    DEVICE_GUARD(inputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor outputs = torch::empty_like(inputs);
    if (n_edges == 0) {
        return outputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    if (backward) {
        inclusive_sum_by_key(
            thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
            thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
            thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
            n_edges);
    } else {
        inclusive_sum_by_key(
            indices.data_ptr<long>(),
            inputs.data_ptr<float>(),
            outputs.data_ptr<float>(),
            n_edges);
    }
#else
    throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return outputs;
}

torch::Tensor exclusive_sum_cub(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward)
{
    DEVICE_GUARD(inputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor outputs = torch::empty_like(inputs);
    if (n_edges == 0) {
        return outputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    if (backward) {
        exclusive_sum_by_key(
            thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
            thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
            thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
            n_edges);
    } else {
        exclusive_sum_by_key(
            indices.data_ptr<long>(),
            inputs.data_ptr<float>(),
            outputs.data_ptr<float>(),
            n_edges);
    }
#else
    throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return outputs;
}

torch::Tensor inclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs)
{
    DEVICE_GUARD(inputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor outputs = torch::empty_like(inputs);
    if (n_edges == 0) {
        return outputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    inclusive_prod_by_key(
        indices.data_ptr<long>(),
        inputs.data_ptr<float>(),
        outputs.data_ptr<float>(),
        n_edges);
#else
    throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return outputs;
}

torch::Tensor inclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs)
{
    DEVICE_GUARD(grad_outputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(grad_outputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
    if (n_edges == 0) {
        return grad_inputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    inclusive_sum_by_key(
        thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
        thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
        thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
        n_edges);
    // FIXME: the grad is not correct when inputs are zero!!
    grad_inputs = grad_inputs / inputs.clamp_min(1e-10f);
#else
    throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return grad_inputs;
}

torch::Tensor exclusive_prod_cub_forward(
    torch::Tensor indices,
    torch::Tensor inputs)
{
    DEVICE_GUARD(inputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(inputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor outputs = torch::empty_like(inputs);
    if (n_edges == 0) {
        return outputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    exclusive_prod_by_key(
        indices.data_ptr<long>(),
        inputs.data_ptr<float>(),
        outputs.data_ptr<float>(),
        n_edges);
#else
    throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return outputs;
}

torch::Tensor exclusive_prod_cub_backward(
    torch::Tensor indices,
    torch::Tensor inputs,
    torch::Tensor outputs,
    torch::Tensor grad_outputs)
{
    DEVICE_GUARD(grad_outputs);
    CHECK_INPUT(indices);
    CHECK_INPUT(grad_outputs);
    TORCH_CHECK(indices.ndimension() == 1);
    TORCH_CHECK(inputs.ndimension() == 1);
    TORCH_CHECK(indices.size(0) == inputs.size(0));

    int64_t n_edges = inputs.size(0);

    torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
    if (n_edges == 0) {
        return grad_inputs;
    }

#if CUB_SUPPORTS_SCAN_BY_KEY()
    exclusive_sum_by_key(
        thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
        thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
        thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
        n_edges);
    // FIXME: the grad is not correct when inputs are zero!!
    grad_inputs = grad_inputs / inputs.clamp_min(1e-10f);
#else
    throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
    cudaGetLastError();
    return grad_inputs;
}
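Why the reverse iterators: for a segmented inclusive sum y_i = sum_{j <= i} x_j (within one segment), the gradient is a suffix sum, dL/dx_i = sum_{j >= i} dL/dy_j, so the backward pass runs the same scan-by-key over reversed views of the key and value arrays. For the inclusive product, dy_j/dx_i = y_j / x_i for j >= i in the same segment, so the backward pass is the suffix sum of grad_outputs * outputs divided by inputs; the division is undefined at zeros, which is exactly what the FIXME and the clamp_min(1e-10f) flag. A pure-PyTorch reference for a single segment, for exposition only (not the shipped kernel):

import torch

def inclusive_prod_backward_ref(inputs, outputs, grad_outputs):
    # Suffix sum of grad_outputs * outputs, realized as a reversed inclusive
    # sum -- the same trick the kernel plays with thrust::make_reverse_iterator.
    suffix = torch.flip(torch.cumsum(torch.flip(grad_outputs * outputs, [0]), 0), [0])
    # Only valid where inputs != 0; the kernel clamps instead of guarding.
    return suffix / inputs.clamp_min(1e-10)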
nerfacc/scan_cub.py (new file)

"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import torch
from torch import Tensor

from . import cuda as _C


def inclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Inclusive Sum that supports flattened tensor with CUB."""
    # Flattened inclusive sum.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _InclusiveSum.apply(indices, inputs)
    return outputs


def exclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Exclusive Sum that supports flattened tensor with CUB."""
    # Flattened exclusive sum.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _ExclusiveSum.apply(indices, inputs)
    return outputs


def inclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Inclusive Prod that supports flattened tensor with CUB."""
    # Flattened inclusive prod.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _InclusiveProd.apply(indices, inputs)
    return outputs


def exclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Exclusive Prod that supports flattened tensor with CUB."""
    # Flattened exclusive prod.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _ExclusiveProd.apply(indices, inputs)
    return outputs


class _InclusiveSum(torch.autograd.Function):
    """Inclusive Sum on a Flattened Tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.inclusive_sum_cub(indices, inputs, False)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        (indices,) = ctx.saved_tensors
        grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
        return None, grad_inputs


class _ExclusiveSum(torch.autograd.Function):
    """Exclusive Sum on a Flattened Tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.exclusive_sum_cub(indices, inputs, False)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        (indices,) = ctx.saved_tensors
        grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
        return None, grad_inputs


class _InclusiveProd(torch.autograd.Function):
    """Inclusive Product on a Flattened Tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.inclusive_prod_cub_forward(indices, inputs)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices, inputs, outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        indices, inputs, outputs = ctx.saved_tensors
        grad_inputs = _C.inclusive_prod_cub_backward(indices, inputs, outputs, grad_outputs)
        return None, grad_inputs


class _ExclusiveProd(torch.autograd.Function):
    """Exclusive Product on a Flattened Tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.exclusive_prod_cub_forward(indices, inputs)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices, inputs, outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        indices, inputs, outputs = ctx.saved_tensors
        grad_inputs = _C.exclusive_prod_cub_backward(indices, inputs, outputs, grad_outputs)
        return None, grad_inputs
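A usage sketch of the new Python API (values chosen for illustration; `indices` carries the ray id of every flattened sample and should be torch.long, matching the long* keys the kernels read):

import torch
from nerfacc import inclusive_sum_cub

device = "cuda:0"
indices = torch.tensor([0, 0, 0, 1, 1], device=device)  # ray ids; int64 by default
inputs = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device=device, requires_grad=True)

outputs = inclusive_sum_cub(inputs, indices)
# outputs == [1., 3., 6., 4., 9.] -- the scan restarts whenever the ray id changes
outputs.sum().backward()
# inputs.grad == [3., 2., 1., 2., 1.] -- suffix counts within each segment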
tests/test_scan.py

...
@@ -7,6 +7,7 @@ device = "cuda:0"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_sum():
    from nerfacc.scan import inclusive_sum
    from nerfacc.scan_cub import inclusive_sum_cub

    torch.manual_seed(42)
...
@@ -28,14 +29,27 @@ def test_inclusive_sum():
    outputs2 = inclusive_sum(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = inclusive_sum_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)
    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_sum():
    from nerfacc.scan import exclusive_sum
    from nerfacc.scan_cub import exclusive_sum_cub

    torch.manual_seed(42)
...
@@ -57,16 +71,29 @@ def test_exclusive_sum():
    outputs2 = exclusive_sum(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = exclusive_sum_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    # TODO: check exclusive sum. numeric error?
    # print((outputs1 - outputs2).abs().max())  # 0.0002
    assert torch.allclose(outputs1, outputs2, atol=3e-4)
    assert torch.allclose(grad1, grad2)
    assert torch.allclose(outputs1, outputs3, atol=3e-4)
    assert torch.allclose(grad1, grad3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_inclusive_prod():
    from nerfacc.scan import inclusive_prod
    from nerfacc.scan_cub import inclusive_prod_cub

    torch.manual_seed(42)
...
@@ -88,14 +115,27 @@ def test_inclusive_prod():
    outputs2 = inclusive_prod(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = inclusive_prod_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)
    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_exclusive_prod():
    from nerfacc.scan import exclusive_prod
    from nerfacc.scan_cub import exclusive_prod_cub

    torch.manual_seed(42)
...
@@ -117,15 +157,27 @@ def test_exclusive_prod():
    outputs2 = exclusive_prod(flatten_data, packed_info=packed_info)
    outputs2.sum().backward()
    grad2 = data.grad.clone()
    data.grad.zero_()

    indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
    indices = indices.repeat_interleave(data.shape[1])
    indices = indices.flatten()
    outputs3 = exclusive_prod_cub(flatten_data, indices)
    outputs3.sum().backward()
    grad3 = data.grad.clone()
    data.grad.zero_()

    # TODO: check exclusive prod. numeric error?
    # print((outputs1 - outputs2).abs().max())
    assert torch.allclose(outputs1, outputs2)
    assert torch.allclose(grad1, grad2)
    assert torch.allclose(outputs1, outputs3)
    assert torch.allclose(grad1, grad3)


if __name__ == "__main__":
    test_inclusive_sum()
    test_exclusive_sum()
    test_inclusive_prod()
    test_exclusive_prod()
\ No newline at end of file
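Note on the test setup: the reference path (inclusive_sum and friends) describes segments with packed_info, a per-ray table of segment start and length, while the CUB path describes the same segmentation with one ray id per flattened sample, built here with repeat_interleave. For regular [n_rays, n_samples] data the two encodings are interchangeable, which is what lets the tests compare outputs and gradients element-wise.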