OpenDAS / nerfacc · Commits

Commit ea07af8e
Authored Nov 29, 2023 by Ruilong Li

inclusive sum in one function

Parent: a0792e88
Showing 7 changed files with 236 additions and 178 deletions (+236 -178):

    nerfacc/__init__.py            +0    -6
    nerfacc/cuda/__init__.py       +7    -2
    nerfacc/cuda/csrc/nerfacc.cpp  +2    -0
    nerfacc/cuda/csrc/scan_cub.cu  +8    -0
    nerfacc/scan.py                +211  -31
    nerfacc/scan_cub.py            +0    -128
    tests/test_scan.py             +8    -11
nerfacc/__init__.py

...
@@ -19,12 +19,6 @@ from .volrend import (
     render_weight_from_density,
     rendering,
 )
-from .scan_cub import (
-    exclusive_prod_cub,
-    exclusive_sum_cub,
-    inclusive_prod_cub,
-    inclusive_sum_cub,
-)
 
 __all__ = [
     "__version__",
...
nerfacc/cuda/__init__.py

...
@@ -30,12 +30,17 @@ inclusive_prod_backward = _make_lazy_cuda_func("inclusive_prod_backward")
 exclusive_prod_forward = _make_lazy_cuda_func("exclusive_prod_forward")
 exclusive_prod_backward = _make_lazy_cuda_func("exclusive_prod_backward")
 
+is_cub_available = _make_lazy_cuda_func("is_cub_available")
 inclusive_sum_cub = _make_lazy_cuda_func("inclusive_sum_cub")
 exclusive_sum_cub = _make_lazy_cuda_func("exclusive_sum_cub")
 inclusive_prod_cub_forward = _make_lazy_cuda_func("inclusive_prod_cub_forward")
 inclusive_prod_cub_backward = _make_lazy_cuda_func("inclusive_prod_cub_backward")
 exclusive_prod_cub_forward = _make_lazy_cuda_func("exclusive_prod_cub_forward")
 exclusive_prod_cub_backward = _make_lazy_cuda_func("exclusive_prod_cub_backward")
 
 # pdf
 importance_sampling = _make_lazy_cuda_func("importance_sampling")
...
nerfacc/cuda/csrc/nerfacc.cpp

...
@@ -38,6 +38,7 @@ torch::Tensor exclusive_prod_backward(
     torch::Tensor outputs,
     torch::Tensor grad_outputs);
 
+bool is_cub_available();
 torch::Tensor inclusive_sum_cub(
     torch::Tensor ray_indices,
     torch::Tensor inputs,
...
@@ -131,6 +132,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     _REG_FUNC(exclusive_prod_forward);
     _REG_FUNC(exclusive_prod_backward);
 
+    _REG_FUNC(is_cub_available);
     _REG_FUNC(inclusive_sum_cub);
     _REG_FUNC(exclusive_sum_cub);
     _REG_FUNC(inclusive_prod_cub_forward);
...
nerfacc/cuda/csrc/scan_cub.cu

...
@@ -56,6 +56,14 @@ inline void inclusive_prod_by_key(
 }
 #endif
 
+bool is_cub_available() {
+#if CUB_SUPPORTS_SCAN_BY_KEY()
+    return true;
+#else
+    return false;
+#endif
+}
+
 torch::Tensor inclusive_sum_cub(
     torch::Tensor indices,
     torch::Tensor inputs,
...
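The new `is_cub_available()` binding lets the Python layer decide at runtime whether the CUB scan-by-key kernels can be used. A minimal sketch of querying it from Python (assuming the lazily built CUDA extension compiles on the current machine):

    import nerfacc.cuda as _C

    # True only when the extension was compiled with CUB_SUPPORTS_SCAN_BY_KEY(),
    # i.e. the toolkit's CUB provides the scan-by-key primitives used above.
    print(_C.is_cub_available())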
nerfacc/scan.py

 """
 Copyright (c) 2022 Ruilong Li, UC Berkeley.
 """
+import warnings
 from typing import Optional
 
 import torch
 from torch import Tensor
 
 from . import cuda as _C
+from .pack import pack_info
 
 
 def inclusive_sum(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Inclusive Sum that supports flattened tensor.
...
@@ -20,11 +24,12 @@ def inclusive_sum(
     Args:
         inputs: The tensor to be summed. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
             of each chunk in the flattened input tensor, with in total n_rays chunks.
             If None, the input is assumed to be a N-D tensor and the sum is computed
             along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.
 
     Returns:
         The inclusive sum with the same shape as the input tensor.
...
@@ -39,22 +44,43 @@ def inclusive_sum(
         tensor([ 1.,  3.,  3.,  7., 12.,  6., 13., 21., 30.], device='cuda:0')
 
     """
-    if packed_info is None:
-        # Batched inclusive sum on the last dimension.
-        outputs = torch.cumsum(inputs, dim=-1)
-    else:
-        # Flattened inclusive sum.
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _InclusiveSumCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+    if packed_info is not None:
         assert inputs.dim() == 1, "inputs must be flattened."
         assert (
             packed_info.dim() == 2 and packed_info.shape[-1] == 2
         ), "packed_info must be 2-D with shape (B, 2)."
         chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
         outputs = _InclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False)
+    if indices is None and packed_info is None:
+        # Batched inclusive sum on the last dimension.
+        outputs = torch.cumsum(inputs, dim=-1)
     return outputs
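With the dispatch above, one entry point now covers all three calling conventions. A short usage sketch adapted from the function's docstring example (the chunking [2, 3, 4] matches the output `tensor([ 1., 3., 3., 7., 12., 6., 13., 21., 30.])` shown there); dtypes and devices follow the repo's tests:

    import torch
    from nerfacc.scan import inclusive_sum

    inputs = torch.arange(1, 10, dtype=torch.float32, device="cuda")
    # (a) chunks given as (start, count) pairs
    packed_info = torch.tensor([[0, 2], [2, 3], [5, 4]], device="cuda")
    out_a = inclusive_sum(inputs, packed_info=packed_info)
    # (b) the same chunks given as a per-element chunk id; uses CUB when available
    indices = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 2], device="cuda", dtype=torch.long)
    out_b = inclusive_sum(inputs, indices=indices)
    # (c) no packing at all: plain torch.cumsum along the last dimension
    out_c = inclusive_sum(inputs.reshape(3, 3))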
 def exclusive_sum(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Exclusive Sum that supports flattened tensor.
...
@@ -62,11 +88,12 @@ def exclusive_sum(
     Args:
         inputs: The tensor to be summed. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
             of each chunk in the flattened input tensor, with in total n_rays chunks.
             If None, the input is assumed to be a N-D tensor and the sum is computed
             along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.
 
     Returns:
         The exclusive sum with the same shape as the input tensor.
...
@@ -81,27 +108,47 @@ def exclusive_sum(
         tensor([ 0.,  1.,  0.,  3.,  7.,  0.,  6., 13., 21.], device='cuda:0')
 
     """
-    if packed_info is None:
-        # Batched exclusive sum on the last dimension.
-        outputs = torch.cumsum(
-            torch.cat(
-                [torch.zeros_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
-            ),
-            dim=-1,
-        )
-    else:
-        # Flattened exclusive sum.
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _ExclusiveSumCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+    if packed_info is not None:
         assert inputs.dim() == 1, "inputs must be flattened."
         assert (
             packed_info.dim() == 2 and packed_info.shape[-1] == 2
         ), "packed_info must be 2-D with shape (B, 2)."
         chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
         outputs = _ExclusiveSum.apply(chunk_starts, chunk_cnts, inputs, False)
+    if indices is None and packed_info is None:
+        # Batched exclusive sum on the last dimension.
+        outputs = torch.cumsum(
+            torch.cat(
+                [torch.zeros_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
+            ),
+            dim=-1,
+        )
     return outputs
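Because the exclusive variant simply drops each element's own term, for any chunking `exclusive_sum(x, ...)` equals `inclusive_sum(x, ...) - x` elementwise; subtracting the inputs 1..9 from the inclusive output above reproduces the `tensor([ 0., 1., 0., 3., 7., 0., 6., 13., 21.])` shown in this docstring. A quick sanity check, reusing the inputs and indices from the earlier sketch:

    import torch
    from nerfacc.scan import exclusive_sum, inclusive_sum

    assert torch.allclose(
        exclusive_sum(inputs, indices=indices),
        inclusive_sum(inputs, indices=indices) - inputs,
    )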
 def inclusive_prod(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Inclusive Product that supports flattened tensor.
...
@@ -111,11 +158,12 @@ def inclusive_prod(
     Args:
         inputs: The tensor to be producted. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
             of each chunk in the flattened input tensor, with in total n_rays chunks.
             If None, the input is assumed to be a N-D tensor and the product is computed
             along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.
 
     Returns:
         The inclusive product with the same shape as the input tensor.
...
@@ -130,22 +178,43 @@ def inclusive_prod(
         tensor([1., 2., 3., 12., 60., 6., 42., 336., 3024.], device='cuda:0')
 
     """
-    if packed_info is None:
-        # Batched inclusive product on the last dimension.
-        outputs = torch.cumprod(inputs, dim=-1)
-    else:
-        # Flattened inclusive product.
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _InclusiveProdCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+    if packed_info is not None:
         assert inputs.dim() == 1, "inputs must be flattened."
         assert (
             packed_info.dim() == 2 and packed_info.shape[-1] == 2
         ), "packed_info must be 2-D with shape (B, 2)."
         chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
         outputs = _InclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
+    if indices is None and packed_info is None:
+        # Batched inclusive product on the last dimension.
+        outputs = torch.cumprod(inputs, dim=-1)
     return outputs
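When CUB is unavailable the `indices` path falls back through `pack_info`, so a tiny pure-PyTorch reference is handy for sanity checks; this loop (a sketch, not part of the commit) computes the same segment-wise inclusive product as the flattened call:

    import torch

    def inclusive_prod_ref(inputs: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # Cumulative product restarted at every chunk id in `indices`.
        out = torch.empty_like(inputs)
        for key in indices.unique():
            mask = indices == key
            out[mask] = torch.cumprod(inputs[mask], dim=0)
        return out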
 def exclusive_prod(
-    inputs: Tensor, packed_info: Optional[Tensor] = None
+    inputs: Tensor,
+    packed_info: Optional[Tensor] = None,
+    indices: Optional[Tensor] = None,
 ) -> Tensor:
     """Exclusive Product that supports flattened tensor.
...
@@ -153,11 +222,12 @@ def exclusive_prod(
     Args:
         inputs: The tensor to be producted. Can be either a N-D tensor, or a flattened
-            tensor with `packed_info` specified.
+            tensor with either `packed_info` or `indices` specified.
         packed_info: A tensor of shape (n_rays, 2) that specifies the start and count
             of each chunk in the flattened input tensor, with in total n_rays chunks.
             If None, the input is assumed to be a N-D tensor and the product is computed
             along the last dimension. Default is None.
+        indices: A flattened tensor with the same shape as `inputs`.
 
     Returns:
         The exclusive product with the same shape as the input tensor.
...
@@ -173,16 +243,42 @@ def exclusive_prod(
         tensor([1., 1., 1., 3., 12., 1., 6., 42., 336.], device='cuda:0')
 
     """
-    if packed_info is None:
+    if indices is not None and packed_info is not None:
+        raise ValueError(
+            "Only one of `indices` and `packed_info` can be specified."
+        )
+    if indices is not None:
+        assert (
+            indices.dim() == 1 and indices.shape == inputs.shape
+        ), "indices must be 1-D with the same shape as inputs."
+        if _C.is_cub_available():
+            # Use CUB if available
+            outputs = _ExclusiveProdCUB.apply(indices, inputs)
+        else:
+            warnings.warn(
+                "Passing in `indices` without CUB available is slow. Considering passing in `packed_info` instead."
+            )
+            packed_info = pack_info(ray_indices=indices)
+    if packed_info is not None:
+        assert inputs.dim() == 1, "inputs must be flattened."
+        assert (
+            packed_info.dim() == 2 and packed_info.shape[-1] == 2
+        ), "packed_info must be 2-D with shape (B, 2)."
+        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
+        outputs = _ExclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
+    if indices is None and packed_info is None:
         # Batched exclusive product on the last dimension.
         outputs = torch.cumprod(
             torch.cat(
                 [torch.ones_like(inputs[..., :1]), inputs[..., :-1]], dim=-1
             ),
             dim=-1,
         )
-    else:
-        chunk_starts, chunk_cnts = packed_info.unbind(dim=-1)
-        outputs = _ExclusiveProd.apply(chunk_starts, chunk_cnts, inputs)
     return outputs
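In volume rendering this exclusive product is the piece that turns per-sample opacities into per-ray transmittance; a hedged sketch of that use (names here are illustrative, not taken from this commit):

    import torch
    from nerfacc.scan import exclusive_prod

    def transmittance(alphas: torch.Tensor, ray_indices: torch.Tensor) -> torch.Tensor:
        # Transmittance before a sample is the product of (1 - alpha) over the
        # samples in front of it on the same ray, i.e. an exclusive product per ray.
        return exclusive_prod(1.0 - alphas, indices=ray_indices)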
...
@@ -286,3 +382,87 @@ class _ExclusiveProd(torch.autograd.Function):
             chunk_starts, chunk_cnts, inputs, outputs, grad_outputs
         )
         return None, None, grad_inputs
+
+
+class _InclusiveSumCUB(torch.autograd.Function):
+    """Inclusive Sum on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.inclusive_sum_cub(indices, inputs, False)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        (indices,) = ctx.saved_tensors
+        grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
+        return None, grad_inputs
+
+
+class _ExclusiveSumCUB(torch.autograd.Function):
+    """Exclusive Sum on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.exclusive_sum_cub(indices, inputs, False)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        (indices,) = ctx.saved_tensors
+        grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
+        return None, grad_inputs
+
+
+class _InclusiveProdCUB(torch.autograd.Function):
+    """Inclusive Product on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.inclusive_prod_cub_forward(indices, inputs)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices, inputs, outputs)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        indices, inputs, outputs = ctx.saved_tensors
+        grad_inputs = _C.inclusive_prod_cub_backward(
+            indices, inputs, outputs, grad_outputs
+        )
+        return None, grad_inputs
+
+
+class _ExclusiveProdCUB(torch.autograd.Function):
+    """Exclusive Product on a Flattened Tensor with CUB."""
+
+    @staticmethod
+    def forward(ctx, indices, inputs):
+        indices = indices.contiguous()
+        inputs = inputs.contiguous()
+        outputs = _C.exclusive_prod_cub_forward(indices, inputs)
+        if ctx.needs_input_grad[1]:
+            ctx.save_for_backward(indices, inputs, outputs)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, grad_outputs):
+        grad_outputs = grad_outputs.contiguous()
+        indices, inputs, outputs = ctx.saved_tensors
+        grad_inputs = _C.exclusive_prod_cub_backward(
+            indices, inputs, outputs, grad_outputs
+        )
+        return None, grad_inputs
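Note how the two sum classes reuse the forward kernel in `backward`, only flipping the last boolean argument: the gradient of a (segment-wise) inclusive sum is the inclusive sum of the upstream gradients taken in reverse order within each segment, which is presumably what that flag selects. A plain-PyTorch check of the identity on a single segment:

    import torch

    x = torch.randn(6, requires_grad=True)
    g = torch.randn(6)
    torch.cumsum(x, dim=0).backward(g)
    # Gradient of an inclusive sum == reversed inclusive sum of the gradients.
    assert torch.allclose(x.grad, torch.cumsum(g.flip(0), dim=0).flip(0))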
nerfacc/scan_cub.py (deleted, 100644 → 0)

"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""
import torch
from torch import Tensor

from . import cuda as _C


def inclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Inclusive Sum that supports flattened tensor with CUB."""
    # Flattened inclusive sum.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _InclusiveSum.apply(indices, inputs)
    return outputs


def exclusive_sum_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Exclusive Sum that supports flattened tensor with CUB."""
    # Flattened inclusive sum.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _ExclusiveSum.apply(indices, inputs)
    return outputs


def inclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Inclusive Prod that supports flattened tensor with CUB."""
    # Flattened inclusive prod.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _InclusiveProd.apply(indices, inputs)
    return outputs


def exclusive_prod_cub(inputs: Tensor, indices: Tensor) -> Tensor:
    """Exclusive Prod that supports flattened tensor with CUB."""
    # Flattened inclusive prod.
    assert inputs.dim() == 1, "inputs must be flattened."
    assert (
        indices.dim() == 1 and indices.shape[0] == inputs.shape[0]
    ), "indices must be 1-D with the same shape as inputs."
    outputs = _ExclusiveProd.apply(indices, inputs)
    return outputs


class _InclusiveSum(torch.autograd.Function):
    """Inclusive Sum on a Flattened Tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.inclusive_sum_cub(indices, inputs, False)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        (indices,) = ctx.saved_tensors
        grad_inputs = _C.inclusive_sum_cub(indices, grad_outputs, True)
        return None, grad_inputs


class _ExclusiveSum(torch.autograd.Function):
    """Exclusive Sum on a Flattened Tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.exclusive_sum_cub(indices, inputs, False)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        (indices,) = ctx.saved_tensors
        grad_inputs = _C.exclusive_sum_cub(indices, grad_outputs, True)
        return None, grad_inputs


class _InclusiveProd(torch.autograd.Function):
    """Inclusive Product on a Flattened Tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.inclusive_prod_cub_forward(indices, inputs)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices, inputs, outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        indices, inputs, outputs = ctx.saved_tensors
        grad_inputs = _C.inclusive_prod_cub_backward(
            indices, inputs, outputs, grad_outputs
        )
        return None, grad_inputs


class _ExclusiveProd(torch.autograd.Function):
    """Exclusive Product on a Flattened Tensor with CUB."""

    @staticmethod
    def forward(ctx, indices, inputs):
        indices = indices.contiguous()
        inputs = inputs.contiguous()
        outputs = _C.exclusive_prod_cub_forward(indices, inputs)
        if ctx.needs_input_grad[1]:
            ctx.save_for_backward(indices, inputs, outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        grad_outputs = grad_outputs.contiguous()
        indices, inputs, outputs = ctx.saved_tensors
        grad_inputs = _C.exclusive_prod_cub_backward(
            indices, inputs, outputs, grad_outputs
        )
        return None, grad_inputs
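For code that imported the removed module, the migration is mechanical: the dedicated `*_cub` wrappers become the `indices=` keyword on the regular scan functions, exactly as the test updates below do:

    # before (nerfacc.scan_cub, now removed)
    #   outputs = inclusive_sum_cub(flatten_data, indices)

    # after
    from nerfacc.scan import inclusive_sum
    outputs = inclusive_sum(flatten_data, indices=indices)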
tests/test_scan.py

...
@@ -7,7 +7,6 @@ device = "cuda:0"
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_inclusive_sum():
     from nerfacc.scan import inclusive_sum
-    from nerfacc.scan_cub import inclusive_sum_cub
 
     torch.manual_seed(42)
...
@@ -34,7 +33,7 @@ def test_inclusive_sum():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = inclusive_sum_cub(flatten_data, indices)
+    outputs3 = inclusive_sum(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
...
@@ -49,7 +48,6 @@ def test_inclusive_sum():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_exclusive_sum():
     from nerfacc.scan import exclusive_sum
-    from nerfacc.scan_cub import exclusive_sum_cub
 
     torch.manual_seed(42)
...
@@ -76,7 +74,7 @@ def test_exclusive_sum():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = exclusive_sum_cub(flatten_data, indices)
+    outputs3 = exclusive_sum(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
...
@@ -93,7 +91,6 @@ def test_exclusive_sum():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_inclusive_prod():
     from nerfacc.scan import inclusive_prod
-    from nerfacc.scan_cub import inclusive_prod_cub
 
     torch.manual_seed(42)
...
@@ -120,7 +117,7 @@ def test_inclusive_prod():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = inclusive_prod_cub(flatten_data, indices)
+    outputs3 = inclusive_prod(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
...
@@ -135,7 +132,6 @@ def test_inclusive_prod():
 @pytest.mark.skipif(not torch.cuda.is_available, reason="No CUDA device")
 def test_exclusive_prod():
     from nerfacc.scan import exclusive_prod
-    from nerfacc.scan_cub import exclusive_prod_cub
 
     torch.manual_seed(42)
...
@@ -162,7 +158,7 @@ def test_exclusive_prod():
     indices = torch.arange(data.shape[0], device=device, dtype=torch.long)
     indices = indices.repeat_interleave(data.shape[1])
     indices = indices.flatten()
-    outputs3 = exclusive_prod_cub(flatten_data, indices)
+    outputs3 = exclusive_prod(flatten_data, indices=indices)
     outputs3.sum().backward()
     grad3 = data.grad.clone()
     data.grad.zero_()
...
@@ -175,10 +171,11 @@ def test_exclusive_prod():
     assert torch.allclose(outputs1, outputs3)
     assert torch.allclose(grad1, grad3)
 
 
 def profile():
     import tqdm
 
     from nerfacc.scan import inclusive_sum
-    from nerfacc.scan_cub import inclusive_sum_cub
 
     torch.manual_seed(42)
...
@@ -202,7 +199,7 @@ def profile():
     indices = indices.flatten()
 
     torch.cuda.synchronize()
     for _ in tqdm.trange(2000):
-        outputs3 = inclusive_sum_cub(flatten_data, indices)
+        outputs3 = inclusive_sum(flatten_data, indices=indices)
         outputs3.sum().backward()
...
@@ -211,4 +208,4 @@ if __name__ == "__main__":
     test_exclusive_sum()
     test_inclusive_prod()
     test_exclusive_prod()
-    # profile()
+    profile()
\ No newline at end of file
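Beyond the bundled `profile()` loop, the two flattened paths can also be timed against each other with CUDA events; a rough sketch with made-up sizes (assuming `nerfacc.pack.pack_info` is importable, as `nerfacc/scan.py` itself uses it):

    import torch
    from nerfacc.pack import pack_info
    from nerfacc.scan import inclusive_sum

    torch.manual_seed(42)
    data = torch.rand(1024, 128, device="cuda", requires_grad=True)
    flatten_data = data.flatten()
    indices = torch.arange(1024, device="cuda", dtype=torch.long).repeat_interleave(128)
    packed_info = pack_info(ray_indices=indices)

    for kwargs in ({"indices": indices}, {"packed_info": packed_info}):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(100):
            inclusive_sum(flatten_data, **kwargs)
        end.record()
        torch.cuda.synchronize()
        print(list(kwargs), f"{start.elapsed_time(end):.1f} ms")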