OpenDAS / nerfacc · Commits

Commit ea07af8e
authored Nov 29, 2023 by Ruilong Li

inclusive sum in one function

parent a0792e88
Showing 7 changed files with 236 additions and 178 deletions (+236 -178):

    nerfacc/__init__.py            +0   -6
    nerfacc/cuda/__init__.py       +7   -2
    nerfacc/cuda/csrc/nerfacc.cpp  +2   -0
    nerfacc/cuda/csrc/scan_cub.cu  +8   -0
    nerfacc/scan.py                +211 -31
    nerfacc/scan_cub.py            +0   -128
    tests/test_scan.py             +8   -11
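In short, the commit folds the separate `nerfacc.scan_cub` wrappers into the existing scan functions in `nerfacc/scan.py`: `inclusive_sum`, `exclusive_sum`, `inclusive_prod`, and `exclusive_prod` now accept an optional `indices` argument and dispatch to the CUB kernels when they are available. A minimal sketch of the resulting call-site change (tensor names below are illustrative, not taken from the diff):

```python
import torch
from nerfacc.scan import inclusive_sum

values = torch.rand(9, device="cuda", requires_grad=True)               # flattened per-sample values
ray_indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda")  # ray id of each sample

# Before this commit (nerfacc/scan_cub.py, now deleted):
#   outputs = inclusive_sum_cub(values, ray_indices)
# After this commit, one function covers both the packed and the indexed layout:
outputs = inclusive_sum(values, indices=ray_indices)
```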
nerfacc/__init__.py  (+0 -6) - drop the `scan_cub` re-exports:

```diff
@@ -19,12 +19,6 @@ from .volrend import (
     render_weight_from_density,
     rendering,
 )
-from .scan_cub import (
-    exclusive_prod_cub,
-    exclusive_sum_cub,
-    inclusive_prod_cub,
-    inclusive_sum_cub,
-)
 
 __all__ = [
     "__version__",
```
nerfacc/cuda/__init__.py  (+7 -2) - register lazy bindings for the CUB entry points:

```diff
@@ -30,12 +30,17 @@ inclusive_prod_backward = _make_lazy_cuda_func("inclusive_prod_backward")
 exclusive_prod_forward = _make_lazy_cuda_func("exclusive_prod_forward")
 exclusive_prod_backward = _make_lazy_cuda_func("exclusive_prod_backward")
 
+is_cub_available = _make_lazy_cuda_func("is_cub_available")
+inclusive_sum_cub = _make_lazy_cuda_func("inclusive_sum_cub")
+exclusive_sum_cub = _make_lazy_cuda_func("exclusive_sum_cub")
+inclusive_prod_cub_forward = _make_lazy_cuda_func("inclusive_prod_cub_forward")
-inclusive_prod_cub_backward = _make_lazy_cuda_func("inclusive_prod_cub_backward")
+inclusive_prod_cub_backward = _make_lazy_cuda_func("inclusive_prod_cub_backward")
+exclusive_prod_cub_forward = _make_lazy_cuda_func("exclusive_prod_cub_forward")
-exclusive_prod_cub_backward = _make_lazy_cuda_func("exclusive_prod_cub_backward")
+exclusive_prod_cub_backward = _make_lazy_cuda_func("exclusive_prod_cub_backward")
 
 # pdf
 importance_sampling = _make_lazy_cuda_func("importance_sampling")
```
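`_make_lazy_cuda_func` itself is not part of this diff. As a rough sketch of what such a helper usually does (an assumption about its shape, not the file's actual code), it defers importing the compiled extension until the first call, so `import nerfacc` works even before the extension is built:

```python
# Hypothetical sketch of a lazy CUDA binding helper; the real _make_lazy_cuda_func
# in nerfacc/cuda/__init__.py may differ in detail.
from typing import Any, Callable


def _make_lazy_cuda_func(name: str) -> Callable[..., Any]:
    def call_cuda(*args: Any, **kwargs: Any) -> Any:
        # Import the compiled extension only when the function is first used.
        from ._backend import _C  # assumed module layout

        return getattr(_C, name)(*args, **kwargs)

    return call_cuda
```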
nerfacc/cuda/csrc/nerfacc.cpp  (+2 -0) - declare and register `is_cub_available`:

```diff
@@ -38,6 +38,7 @@ torch::Tensor exclusive_prod_backward(
     torch::Tensor outputs,
     torch::Tensor grad_outputs);
 
+bool is_cub_available();
 torch::Tensor inclusive_sum_cub(
     torch::Tensor ray_indices,
     torch::Tensor inputs,
```

```diff
@@ -131,6 +132,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     _REG_FUNC(exclusive_prod_forward);
     _REG_FUNC(exclusive_prod_backward);
+    _REG_FUNC(is_cub_available);
     _REG_FUNC(inclusive_sum_cub);
     _REG_FUNC(exclusive_sum_cub);
     _REG_FUNC(inclusive_prod_cub_forward);
```
nerfacc/cuda/csrc/scan_cub.cu  (+8 -0) - expose whether the CUB scan-by-key path was compiled in:

```diff
@@ -56,6 +56,14 @@ inline void inclusive_prod_by_key(
 }
 #endif
 
+bool is_cub_available() {
+#if CUB_SUPPORTS_SCAN_BY_KEY()
+    return true;
+#else
+    return false;
+#endif
+}
+
 torch::Tensor inclusive_sum_cub(
     torch::Tensor indices,
     torch::Tensor inputs,
```
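The new `is_cub_available()` entry point lets the Python layer decide at runtime whether the CUB scan-by-key kernels are usable. A small usage sketch (the import path mirrors what `nerfacc/scan.py` does below; treat it as illustrative):

```python
# Sketch: querying the new capability flag from Python.
from nerfacc import cuda as _C

if _C.is_cub_available():
    print("CUB scan-by-key kernels available: the `indices` path runs directly on the GPU.")
else:
    print("CUB not available: prefer `packed_info`, or expect a slower fallback.")
```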
nerfacc/scan.py  (+211 -31) - the four scan functions gain an `indices` argument and CUB dispatch:

```diff
 """
 Copyright (c) 2022 Ruilong Li, UC Berkeley.
 """
+import warnings
 from typing import Optional
 
 import torch
 from torch import Tensor
 
 from . import cuda as _C
+from .pack import pack_info
 
 
 def inclusive_sum(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Inclusive Sum that supports flattened tensor.
```

```diff
@@ -20,11 +24,12 @@ def inclusive_sum(
     Args:
         inputs: The tensor to be summed. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
             of each chunk in the flattened input tensor, with in total n_rays chunks.
             If None, the input is assumed to be a N-D tensor and the sum is computed
             along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.
 
     Returns:
         The inclusive sum with the same shape as the input tensor.
```

```diff
@@ -39,22 +44,43 @@ def inclusive_sum(
         tensor([ 1.,  3.,  3.,  7., 12.,  6., 13., 21., 30.], device='cuda:0')
 
     """
-    if packed_info is None:
-        # Batched inclusive sum on the last dimension.
-        outputs = torch.cumsum(inputs, dim=-1)
-    else:
-        # Flattened inclusive sum.
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _InclusiveSumCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+    if packed_info is not None:
         assert inputs.dim() == 1, "inputs must be flattened."
         assert (
             packed_info.dim() == 2 and packed_info.shape[-1] == 2
         ), "packed_info must be 2-D with shape (B, 2)."
         chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
         outputs = _InclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False)
+    if indices is None and packed_info is None:
+        # Batched inclusive sum on the last dimension.
+        outputs = torch.cumsum(inputs, dim=-1)
     return outputs
```
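For reference, the segment semantics that both the `packed_info` path and the new `indices` path implement can be spelled out with a plain per-ray cumsum. The sketch below reproduces the docstring example values (it is illustrative, not part of the diff):

```python
# Sketch: per-segment semantics of inclusive_sum, reproducing the docstring example.
import torch

values = torch.arange(1.0, 10.0)                     # [1, 2, ..., 9]
ray_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2])  # 3 rays with 2, 3, 4 samples

expected = torch.cat([values[ray_ids == r].cumsum(0) for r in ray_ids.unique()])
print(expected)  # tensor([ 1.,  3.,  3.,  7., 12.,  6., 13., 21., 30.])
```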
```diff
 def exclusive_sum(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Exclusive Sum that supports flattened tensor.
```

```diff
@@ -62,11 +88,12 @@ def exclusive_sum(
     Args:
         inputs: The tensor to be summed. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
             of each chunk in the flattened input tensor, with in total n_rays chunks.
             If None, the input is assumed to be a N-D tensor and the sum is computed
             along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.
 
     Returns:
         The exclusive sum with the same shape as the input tensor.
```

```diff
@@ -81,27 +108,47 @@ def exclusive_sum(
         tensor([ 0.,  1.,  0.,  3.,  7.,  0.,  6., 13., 21.], device='cuda:0')
 
     """
-    if packed_info is None:
-        # Batched exclusive sum on the last dimension.
-        outputs = torch.cumsum(
-            torch.cat(
-                [torch.zeros_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
-            ),
-            dim=-1,
-        )
-    else:
-        # Flattened exclusive sum.
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _ExclusiveSumCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+    if packed_info is not None:
         assert inputs.dim() == 1, "inputs must be flattened."
         assert (
             packed_info.dim() == 2 and packed_info.shape[-1] == 2
         ), "packed_info must be 2-D with shape (B, 2)."
         chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
         outputs = _ExclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False)
+    if indices is None and packed_info is None:
+        # Batched exclusive sum on the last dimension.
+        outputs = torch.cumsum(
+            torch.cat(
+                [torch.zeros_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
+            ),
+            dim=-1,
+        )
     return outputs
```
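A quick way to see the relationship to `inclusive_sum`: within each segment, the exclusive sum is the inclusive sum shifted right by one position, with a leading zero. The sketch below reproduces the docstring example values:

```python
# Sketch: per-segment exclusive sum as a shifted inclusive sum (docstring example values).
import torch

values = torch.arange(1.0, 10.0)
ray_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2])

expected = torch.cat(
    [
        torch.cat([torch.zeros(1), values[ray_ids == r].cumsum(0)[:-1]])
        for r in ray_ids.unique()
    ]
)
print(expected)  # tensor([ 0.,  1.,  0.,  3.,  7.,  0.,  6., 13., 21.])
```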
```diff
 def inclusive_prod(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Inclusive Product that supports flattened tensor.
```

```diff
@@ -111,11 +158,12 @@ def inclusive_prod(
     Args:
         inputs: The tensor to be producted. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
             of each chunk in the flattened input tensor, with in total n_rays chunks.
             If None, the input is assumed to be a N-D tensor and the product is computed
             along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.
 
     Returns:
         The inclusive product with the same shape as the input tensor.
```

```diff
@@ -130,22 +178,43 @@ def inclusive_prod(
         tensor([1., 2., 3., 12., 60., 6., 42., 336., 3024.], device='cuda:0')
 
     """
-    if packed_info is None:
-        # Batched inclusive product on the last dimension.
-        outputs = torch.cumprod(inputs, dim=-1)
-    else:
-        # Flattened inclusive product.
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _InclusiveProdCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+    if packed_info is not None:
         assert inputs.dim() == 1, "inputs must be flattened."
         assert (
             packed_info.dim() == 2 and packed_info.shape[-1] == 2
         ), "packed_info must be 2-D with shape (B, 2)."
         chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
         outputs = _InclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
+    if indices is None and packed_info is None:
+        # Batched inclusive product on the last dimension.
+        outputs = torch.cumprod(inputs, dim=-1)
     return outputs
```
```diff
 def exclusive_prod(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Exclusive Product that supports flattened tensor.
```

```diff
@@ -153,11 +222,12 @@ def exclusive_prod(
     Args:
         inputs: The tensor to be producted. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
             of each chunk in the flattened input tensor, with in total n_rays chunks.
             If None, the input is assumed to be a N-D tensor and the product is computed
             along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.
 
     Returns:
         The exclusive product with the same shape as the input tensor.
```

```diff
@@ -173,16 +243,42 @@ def exclusive_prod(
         tensor([1., 1., 1., 3., 12., 1., 6., 42., 336.], device='cuda:0')
 
     """
-    if packed_info is None:
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _ExclusiveProdCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+    if packed_info is not None:
+        assert inputs.dim() == 1, "inputs must be flattened."
+        assert (
+            packed_info.dim() == 2 and packed_info.shape[-1] == 2
+        ), "packed_info must be 2-D with shape (B, 2)."
+        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
+        outputs = _ExclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
+    if indices is None and packed_info is None:
         # Batched exclusive product on the last dimension.
         outputs = torch.cumprod(
             torch.cat(
                 [torch.ones_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
             ),
             dim=-1,
         )
-    else:
-        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
-        outputs = _ExclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
     return outputs
```
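These segmented scans are the building blocks nerfacc uses for volume rendering. As a hedged sketch of a typical downstream use (assumed usage with illustrative names; see `nerfacc.volrend` for the library's own rendering helpers), per-ray transmittance is the exclusive product of `1 - alpha` over each ray's samples:

```python
# Sketch: alpha-compositing weights from the new indexed exclusive_prod.
import torch
from nerfacc.scan import exclusive_prod

alphas = torch.rand(9, device="cuda")                                # flattened per-sample alphas
ray_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda")   # ray id of each sample

trans = exclusive_prod(1.0 - alphas, indices=ray_ids)  # transmittance per sample, per ray
weights = trans * alphas                                # standard compositing weights
```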
```diff
@@ -286,3 +382,87 @@ class _ExclusiveProd(torch.autograd.Function):
             chunk_starts, chunk_cnts, inputs, outputs, grad_outputs
         )
         return None, None, grad_inputs
+
+
+class _InclusiveSumCUB(torch.autograd.Function):
+    """Inclusive Sum on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.inclusive_sum_cub(indices, inputs, False)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        (indices,) = ctx.saved_tensors
+        grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
+        return None, grad_inputs
+
+
+class _ExclusiveSumCUB(torch.autograd.Function):
+    """Exclusive Sum on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.exclusive_sum_cub(indices, inputs, False)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        (indices,) = ctx.saved_tensors
+        grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
+        return None, grad_inputs
+
+
+class _InclusiveProdCUB(torch.autograd.Function):
+    """Inclusive Product on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.inclusive_prod_cub_forward(indices, inputs)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices, inputs, outputs)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        indices, inputs, outputs = ctx.saved_tensors
+        grad_inputs = _C.inclusive_prod_cub_backward(
+            indices, inputs, outputs, grad_outputs
+        )
+        return None, grad_inputs
+
+
+class _ExclusiveProdCUB(torch.autograd.Function):
+    """Exclusive Product on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.exclusive_prod_cub_forward(indices, inputs)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices, inputs, outputs)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        indices, inputs, outputs = ctx.saved_tensors
+        grad_inputs = _C.exclusive_prod_cub_backward(
+            indices, inputs, outputs, grad_outputs
+        )
+        return None, grad_inputs
```
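The `False`/`True` flag passed to `_C.inclusive_sum_cub` and `_C.exclusive_sum_cub` apparently selects the reversed scan used for the backward pass: the gradient of a per-segment inclusive sum at position j collects the upstream gradients of every position at or after j in the same segment, which is exactly an inclusive sum taken in reverse order. A small sketch of that identity in plain PyTorch (not the kernel itself):

```python
# Sketch: gradient of a per-segment inclusive sum is a reversed per-segment inclusive sum.
import torch

grad_out = torch.tensor([1., 1., 1., 1., 1.])
ray_ids = torch.tensor([0, 0, 0, 1, 1])

grad_in = torch.cat(
    [grad_out[ray_ids == r].flip(0).cumsum(0).flip(0) for r in ray_ids.unique()]
)
print(grad_in)  # tensor([3., 2., 1., 2., 1.])
```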
nerfacc/scan_cub.py  (deleted, 100644 → 0; +0 -128) - the standalone CUB wrappers are removed. The deleted module contained:

```python
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import torch
from torch import Tensor

from . import cuda as _C


def inclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Inclusive Sum that supports flattened tensor with CUB."""
    # Flattened inclusive sum.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _InclusiveSum.apply(indices, inputs)
    return outputs


def exclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Exclusive Sum that supports flattened tensor with CUB."""
    # Flattened inclusive sum.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _ExclusiveSum.apply(indices, inputs)
    return outputs


def inclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Inclusive Prod that supports flattened tensor with CUB."""
    # Flattened inclusive prod.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _InclusiveProd.apply(indices, inputs)
    return outputs


def exclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Exclusive Prod that supports flattened tensor with CUB."""
    # Flattened inclusive prod.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _ExclusiveProd.apply(indices, inputs)
    return outputs
```

The remainder of the deleted file defined `_InclusiveSum`, `_ExclusiveSum`, `_InclusiveProd`, and `_ExclusiveProd` autograd classes whose bodies are identical to the `_InclusiveSumCUB`, `_ExclusiveSumCUB`, `_InclusiveProdCUB`, and `_ExclusiveProdCUB` classes added to nerfacc/scan.py above; only the class names change.
tests/test_scan.py  (+8 -11) - the tests exercise the unified functions instead of the `*_cub` wrappers:

```diff
@@ -7,7 +7,6 @@ device = "cuda:0"
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_inclusive_sum():
     from nerfacc.scan import inclusive_sum
-    from nerfacc.scan_cub import inclusive_sum_cub
 
     torch.manual_seed(42)
```

```diff
@@ -34,7 +33,7 @@ def test_inclusive_sum():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = inclusive_sum_cub(flatten_data, indices)
+    outputs3 = inclusive_sum(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
```

```diff
@@ -49,7 +48,6 @@ def test_inclusive_sum():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_exclusive_sum():
     from nerfacc.scan import exclusive_sum
-    from nerfacc.scan_cub import exclusive_sum_cub
 
     torch.manual_seed(42)
```

```diff
@@ -76,7 +74,7 @@ def test_exclusive_sum():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = exclusive_sum_cub(flatten_data, indices)
+    outputs3 = exclusive_sum(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
```

```diff
@@ -93,7 +91,6 @@ def test_exclusive_sum():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_inclusive_prod():
     from nerfacc.scan import inclusive_prod
-    from nerfacc.scan_cub import inclusive_prod_cub
 
     torch.manual_seed(42)
```

```diff
@@ -120,7 +117,7 @@ def test_inclusive_prod():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = inclusive_prod_cub(flatten_data, indices)
+    outputs3 = inclusive_prod(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
```

```diff
@@ -135,7 +132,6 @@ def test_inclusive_prod():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_exclusive_prod():
     from nerfacc.scan import exclusive_prod
-    from nerfacc.scan_cub import exclusive_prod_cub
 
     torch.manual_seed(42)
```

```diff
@@ -162,7 +158,7 @@ def test_exclusive_prod():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = exclusive_prod_cub(flatten_data, indices)
+    outputs3 = exclusive_prod(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
```

```diff
@@ -175,10 +171,11 @@ def test_exclusive_prod():
     assert torch.allclose(outputs1, outputs3)
     assert torch.allclose(grad1, grad3)
 
 
 def profile():
     import tqdm
 
     from nerfacc.scan import inclusive_sum
-    from nerfacc.scan_cub import inclusive_sum_cub
 
     torch.manual_seed(42)
```

```diff
@@ -202,7 +199,7 @@ def profile():
     indices = indices.flatten()
 
     torch.cuda.synchronize()
     for _ in tqdm.trange(2000):
-        outputs3 = inclusive_sum_cub(flatten_data, indices)
+        outputs3 = inclusive_sum(flatten_data, indices=indices)
         outputs3.sum().backward()
```

```diff
@@ -211,4 +208,4 @@ if __name__ == "__main__":
     test_exclusive_sum()
     test_inclusive_prod()
     test_exclusive_prod()
-    # profile()
\ No newline at end of file
+    profile()
```
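The equivalence the tests exercise can be summarized in a few lines: flatten a dense (n_rays, n_samples) tensor, build a per-sample ray index with `repeat_interleave`, and the indexed scan should match a plain `torch.cumsum` along the last dimension. A sketch of that setup (mirroring the test code, with illustrative sizes):

```python
# Sketch: indexed inclusive_sum over a flattened dense tensor matches cumsum(dim=-1).
import torch
from nerfacc.scan import inclusive_sum

data = torch.rand(64, 32, device="cuda")
flatten_data = data.flatten()

indices = torch.arange(data.shape[0], device="cuda", dtype=torch.long)
indices = indices.repeat_interleave(data.shape[1]).flatten()

assert torch.allclose(
    inclusive_sum(flatten_data, indices=indices),
    torch.cumsum(data, dim=-1).flatten(),
)
```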