Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
one
spconv
Commits
6c767a51
Commit
6c767a51
authored
May 21, 2020
by
Yan Yan
Browse files
working on remove functor
parent
19e73bbe
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
584 additions
and
413 deletions
+584
-413
include/spconv/fused_spconv_ops.h
include/spconv/fused_spconv_ops.h
+6
-30
include/spconv/indice.h
include/spconv/indice.h
+15
-0
include/spconv/reordering.h
include/spconv/reordering.h
+10
-15
include/spconv/spconv_ops.h
include/spconv/spconv_ops.h
+9
-1
include/tensorview/tensor.h
include/tensorview/tensor.h
+34
-6
spconv/ops.py
spconv/ops.py
+2
-11
src/spconv/all.cc
src/spconv/all.cc
+4
-5
src/spconv/indice.cc
src/spconv/indice.cc
+74
-0
src/spconv/indice.cu
src/spconv/indice.cu
+61
-43
src/spconv/reordering.cc
src/spconv/reordering.cc
+49
-48
src/spconv/reordering.cu
src/spconv/reordering.cu
+131
-120
src/spconv/spconv_ops.cc
src/spconv/spconv_ops.cc
+183
-129
test/test_conv.py
test/test_conv.py
+6
-5
No files found.
include/spconv/fused_spconv_ops.h
View file @
6c767a51
...
...
@@ -24,7 +24,6 @@
namespace
spconv
{
// torch.jit's doc says only support int64, so we need to convert to int32.
template
<
typename
T
>
torch
::
Tensor
fusedIndiceConvBatchNorm
(
torch
::
Tensor
features
,
torch
::
Tensor
filters
,
torch
::
Tensor
bias
,
torch
::
Tensor
indicePairs
,
...
...
@@ -80,31 +79,17 @@ fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor filters,
continue
;
}
// auto timer = spconv::CudaContextTimer<>();
auto
outputBufferBlob
=
torch
::
from_blob
(
outputBuffer
.
data_ptr
<
T
>
(),
auto
outputBufferBlob
=
torch
::
from_blob
(
outputBuffer
.
data_ptr
(),
{
nHot
,
numOutPlanes
},
options
);
auto
inputBufferBlob
=
torch
::
from_blob
(
inputBuffer
.
data_ptr
<
T
>
(),
auto
inputBufferBlob
=
torch
::
from_blob
(
inputBuffer
.
data_ptr
(),
{
nHot
,
numInPlanes
},
options
);
if
(
device
==
torch
::
kCPU
)
{
functor
::
SparseGatherFunctor
<
tv
::
CPU
,
T
,
int
>
gatherFtor
;
gatherFtor
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
T
>
(
features
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
sparse_gather_cpu
(
inputBuffer
,
features
,
indicePairs
[
i
][
inverse
],
nHot
);
}
#ifdef TV_CUDA
else
if
(
device
==
torch
::
kCUDA
)
{
functor
::
SparseGatherFunctor
<
tv
::
GPU
,
T
,
int
>
gatherFtor
;
gatherFtor
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
T
>
(
features
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
TV_CHECK_CUDA_ERR
();
/* slower than SparseGatherFunctor, may due to int->long conversion
auto indicePairLong = indicePairs[i][inverse].to(torch::kInt64);
auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(),
{nHot}, indicePairOptions); torch::index_select_out(inputBufferBlob,
features, 0, indicePairBlob);*/
sparse_gather_cuda
(
inputBuffer
,
features
,
indicePairs
[
i
][
inverse
],
nHot
);
}
#endif
else
{
...
...
@@ -116,20 +101,11 @@ fusedIndiceConvBatchNorm(torch::Tensor features, torch::Tensor filters,
// totalGEMMTime += timer.report() / 1000.0;
if
(
device
==
torch
::
kCPU
)
{
functor
::
SparseScatterAddFunctor
<
tv
::
CPU
,
T
,
int
>
scatterFtor
;
scatterFtor
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
output
),
tv
::
torch2tv
<
const
T
>
(
outputBuffer
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
,
true
);
sparse_scatter_add_cpu
(
outputBuffer
,
output
,
indicePairs
[
i
][
!
inverse
],
nHot
);
}
#ifdef TV_CUDA
else
if
(
device
==
torch
::
kCUDA
)
{
functor
::
SparseScatterAddFunctor
<
tv
::
GPU
,
T
,
int
>
scatterFtor
;
scatterFtor
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
output
),
tv
::
torch2tv
<
const
T
>
(
outputBuffer
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
,
true
);
TV_CHECK_CUDA_ERR
();
sparse_scatter_add_cuda
(
outputBuffer
,
output
,
indicePairs
[
i
][
!
inverse
],
nHot
);
}
#endif
else
{
...
...
include/spconv/indice.h
View file @
6c767a51
...
...
@@ -97,6 +97,21 @@ int create_submconv_indice_pair_cuda(
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
,
bool
useHash
);
int
create_conv_indice_pair_cpu
(
torch
::
Tensor
indicesIn
,
torch
::
Tensor
indicesOut
,
torch
::
Tensor
gridsOut
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
std
::
vector
<
int64_t
>
kernelSize
,
std
::
vector
<
int64_t
>
stride
,
std
::
vector
<
int64_t
>
padding
,
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
,
bool
useHash
);
int
create_submconv_indice_pair_cpu
(
torch
::
Tensor
indicesIn
,
torch
::
Tensor
gridsOut
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
std
::
vector
<
int64_t
>
kernelSize
,
std
::
vector
<
int64_t
>
stride
,
std
::
vector
<
int64_t
>
padding
,
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
,
bool
useHash
);
}
// namespace spconv
#endif
\ No newline at end of file
include/spconv/reordering.h
View file @
6c767a51
...
...
@@ -15,24 +15,19 @@
#ifndef SPARSE_REORDERING_FUNCTOR_H_
#define SPARSE_REORDERING_FUNCTOR_H_
#include <tensorview/tensorview.h>
#include <torch/script.h>
namespace
spconv
{
namespace
functor
{
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseGatherFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
buffer
,
tv
::
TensorView
<
const
T
>
features
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
);
};
void
sparse_gather_cuda
(
torch
::
Tensor
buffer
,
torch
::
Tensor
features
,
torch
::
Tensor
indices
,
int
size
);
void
sparse_scatter_add_cuda
(
torch
::
Tensor
buffer
,
torch
::
Tensor
outFeatures
,
torch
::
Tensor
indices
,
int
size
);
void
sparse_gather_cpu
(
torch
::
Tensor
buffer
,
torch
::
Tensor
features
,
torch
::
Tensor
indices
,
int
size
);
void
sparse_scatter_add_cpu
(
torch
::
Tensor
buffer
,
torch
::
Tensor
outFeatures
,
torch
::
Tensor
indices
,
int
size
);
template
<
typename
Device
,
typename
T
,
typename
Index
>
struct
SparseScatterAddFunctor
{
void
operator
()(
const
Device
&
d
,
tv
::
TensorView
<
T
>
out_features
,
tv
::
TensorView
<
const
T
>
buffer
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
,
bool
stable
=
false
);
};
}
// namespace functor
}
// namespace spconv
#endif
\ No newline at end of file
include/spconv/spconv_ops.h
View file @
6c767a51
...
...
@@ -198,6 +198,15 @@ getIndicePair(torch::Tensor indices, int64_t batchSize,
}
}
std
::
vector
<
torch
::
Tensor
>
getIndicePairV2
(
torch
::
Tensor
indices
,
int64_t
batchSize
,
std
::
vector
<
int64_t
>
outSpatialShape
,
std
::
vector
<
int64_t
>
spatialShape
,
std
::
vector
<
int64_t
>
kernelSize
,
std
::
vector
<
int64_t
>
stride
,
std
::
vector
<
int64_t
>
padding
,
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outPadding
,
int64_t
_subM
,
int64_t
_transpose
,
int64_t
_useHash
);
template
<
unsigned
NDim
>
std
::
vector
<
torch
::
Tensor
>
getIndicePairPreGrid
(
torch
::
Tensor
indices
,
torch
::
Tensor
gridOut
,
int64_t
batchSize
,
...
...
@@ -333,7 +342,6 @@ std::vector<torch::Tensor> getIndicePairPreGrid(
torch
::
Tensor
indiceConv
(
torch
::
Tensor
features
,
torch
::
Tensor
filters
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
int64_t
numActOut
,
int64_t
_inverse
,
int64_t
_subM
);
std
::
vector
<
torch
::
Tensor
>
indiceConvBackward
(
torch
::
Tensor
features
,
torch
::
Tensor
filters
,
torch
::
Tensor
outGrad
,
torch
::
Tensor
indicePairs
,
...
...
include/tensorview/tensor.h
View file @
6c767a51
...
...
@@ -52,6 +52,10 @@ enum DType {
namespace
detail
{
using
dtype_collection_t
=
tv
::
mp_list_c
<
int
,
float32
,
int32
,
int16
,
int8
,
float64
,
bool_
,
uint8
,
float16
,
int64
,
uint16
,
uint32
,
uint64
>
;
using
all_tensor_types_t
=
std
::
tuple
<
float
,
double
,
int8_t
,
int16_t
,
int32_t
,
int64_t
,
uint8_t
,
uint16_t
,
uint32_t
,
uint64_t
,
bool
>
;
...
...
@@ -305,7 +309,7 @@ template <class... Ts, typename F> void dispatch(DType t, F &&f) {
static_assert
(
sizeof
...(
Ts
)
>
0
,
"you need to provide at least one type"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list
<
Ts
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
type_v
<
decltype
(
I
)
>
==
t
)
{
if
(
type_v
<
decltype
(
I
)
>
==
t
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
decltype
(
I
)());
notFound
=
false
;
}
...
...
@@ -325,7 +329,7 @@ template <typename T, T... Is, typename F> void dispatch_scalar(T idx, F &&f) {
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list_c
<
T
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
T
(
I
)
==
idx
)
{
if
(
T
(
I
)
==
idx
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
}
...
...
@@ -343,7 +347,27 @@ template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
int
(
I
)
==
idx
)
{
if
(
decltype
(
I
)
::
value
==
idx
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
}
});
if
(
notFound
)
{
std
::
stringstream
ss
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
(
[
=
,
&
ss
](
auto
I
)
{
ss
<<
decltype
(
I
)
::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
}
}
template
<
int
...
Is
,
typename
F
,
class
BinaryPredicate
>
void
dispatch_int
(
int
idx
,
BinaryPredicate
p
,
F
&&
f
)
{
// BinaryPredicate: BinaryPredicate(idx, candidate)
static_assert
(
sizeof
...(
Is
)
>
0
,
"you need to provide at least one candidate"
);
bool
notFound
=
true
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
([
=
,
&
notFound
,
&
f
](
auto
I
)
{
if
(
p
(
idx
,
decltype
(
I
)
::
value
)
&&
notFound
)
{
std
::
forward
<
F
>
(
f
)(
I
);
notFound
=
false
;
}
...
...
@@ -351,7 +375,7 @@ template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
if
(
notFound
)
{
std
::
stringstream
ss
;
mp_for_each
<
mp_list_c
<
int
,
Is
...
>>
(
[
=
,
&
ss
](
auto
I
)
{
ss
<<
int
(
I
)
<<
" "
;
});
[
=
,
&
ss
](
auto
I
)
{
ss
<<
decltype
(
I
)
::
value
<<
" "
;
});
TV_THROW_RT_ERR
(
"unknown value"
,
idx
,
", available:"
,
ss
.
str
());
}
}
...
...
@@ -373,12 +397,16 @@ struct Dispatch<T<Args...>> {
template
<
class
T
>
struct
DispatchInt
;
template
<
template
<
int
...
>
class
T
,
int
...
Ints
>
struct
DispatchInt
<
T
<
Ints
...
>>
{
template
<
template
<
class
...
>
class
Tin
,
template
<
class
,
int
>
class
T
,
int
...
Ints
>
struct
DispatchInt
<
T
in
<
T
<
int
,
Ints
>
...
>>
{
template
<
typename
F
>
inline
void
operator
()(
int
t
,
F
&&
f
)
{
return
dispatch_int
<
Ints
...
>
(
t
,
std
::
forward
<
F
>
(
f
));
}
template
<
typename
F
,
typename
BinaryPredicate
>
inline
void
operator
()(
int
t
,
BinaryPredicate
p
,
F
&&
f
)
{
return
dispatch_int
<
Ints
...
>
(
t
,
p
,
std
::
forward
<
F
>
(
f
));
}
};
constexpr
size_t
kTensorMaxDim
=
10
;
using
TensorShape
=
ShapeBase
<
kTensorMaxDim
,
int64_t
>
;
...
...
spconv/ops.py
View file @
6c767a51
...
...
@@ -81,16 +81,7 @@ def get_indice_pairs(indices,
else
:
out_shape
=
spatial_shape
if
grid
is
None
:
if
ndim
==
2
:
get_indice_pairs_func
=
torch
.
ops
.
spconv
.
get_indice_pairs_2d
elif
ndim
==
3
:
get_indice_pairs_func
=
torch
.
ops
.
spconv
.
get_indice_pairs_3d
elif
ndim
==
4
:
get_indice_pairs_func
=
torch
.
ops
.
spconv
.
get_indice_pairs_4d
else
:
raise
NotImplementedError
res
=
get_indice_pairs_func
(
indices
,
batch_size
,
out_shape
,
res
=
torch
.
ops
.
spconv
.
get_indice_pairs_v2
(
indices
,
batch_size
,
out_shape
,
spatial_shape
,
ksize
,
stride
,
padding
,
dilation
,
out_padding
,
int
(
subm
),
int
(
transpose
),
int
(
use_hash
))
...
...
@@ -115,7 +106,7 @@ def indice_conv(features,
num_activate_out
,
inverse
=
False
,
subm
=
False
):
return
torch
.
ops
.
spconv
.
indice_conv
(
features
,
filters
,
indice_pairs
,
return
torch
.
ops
.
spconv
.
indice_conv
_v2
(
features
,
filters
,
indice_pairs
,
indice_pair_num
,
num_activate_out
,
int
(
inverse
),
int
(
subm
))
...
...
src/spconv/all.cc
View file @
6c767a51
...
...
@@ -12,28 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <spconv/fused_spconv_ops.h>
#include <spconv/nms_ops.h>
#include <spconv/pillar_scatter_ops.h>
#include <spconv/pool_ops.h>
#include <spconv/spconv_ops.h>
#include <torch/script.h>
#include <spconv/fused_spconv_ops.h>
static
auto
registry
=
torch
::
RegisterOperators
()
.
op
(
"spconv::get_indice_pairs_2d"
,
&
spconv
::
getIndicePair
<
2
>
)
.
op
(
"spconv::get_indice_pairs_3d"
,
&
spconv
::
getIndicePair
<
3
>
)
.
op
(
"spconv::get_indice_pairs_4d"
,
&
spconv
::
getIndicePair
<
4
>
)
.
op
(
"spconv::get_indice_pairs_v2"
,
&
spconv
::
getIndicePairV2
)
.
op
(
"spconv::get_indice_pairs_grid_2d"
,
&
spconv
::
getIndicePairPreGrid
<
2
>
)
.
op
(
"spconv::get_indice_pairs_grid_3d"
,
&
spconv
::
getIndicePairPreGrid
<
3
>
)
.
op
(
"spconv::indice_conv"
,
&
spconv
::
indiceConv
)
.
op
(
"spconv::indice_conv_backward"
,
&
spconv
::
indiceConvBackward
)
.
op
(
"spconv::fused_indice_conv_fp32"
,
&
spconv
::
fusedIndiceConvBatchNorm
<
float
>
)
.
op
(
"spconv::fused_indice_conv_half"
,
&
spconv
::
fusedIndiceConvBatchNorm
<
at
::
Half
>
)
.
op
(
"spconv::fused_indice_conv_bn"
,
&
spconv
::
fusedIndiceConvBatchNorm
)
.
op
(
"spconv::indice_maxpool_fp32"
,
&
spconv
::
indiceMaxPool
<
float
>
)
.
op
(
"spconv::indice_maxpool_backward_fp32"
,
&
spconv
::
indiceMaxPoolBackward
<
float
>
)
...
...
src/spconv/indice.cc
View file @
6c767a51
...
...
@@ -16,6 +16,7 @@
#include <spconv/geometry.h>
#include <spconv/indice.h>
#include <spconv/spconv_ops.h>
#include <tensorview/tensor.h>
#include <torch/script.h>
namespace
spconv
{
...
...
@@ -253,6 +254,79 @@ Index getIndicePairsSubM(tv::TensorView<const Index> indicesIn,
}
#endif
int
create_conv_indice_pair_cpu
(
torch
::
Tensor
indicesIn
,
torch
::
Tensor
indicesOut
,
torch
::
Tensor
gridsOut
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
std
::
vector
<
int64_t
>
kernelSize
,
std
::
vector
<
int64_t
>
stride
,
std
::
vector
<
int64_t
>
padding
,
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
,
bool
useHash
)
{
auto
ndim
=
outSpatialShape
.
size
();
auto
numActIn
=
indicesIn
.
size
(
0
);
int
batchSize
=
gridsOut
.
size
(
0
);
auto
kernelVolume
=
indicePairs
.
size
(
0
);
if
(
numActIn
==
0
)
return
0
;
tv
::
dispatch_torch
<
int32_t
,
int64_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
using
Index
=
decltype
(
V
);
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
di
(
dilation
.
begin
(),
dilation
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
ou
(
outSpatialShape
.
begin
(),
outSpatialShape
.
end
());
if
(
transpose
)
numActIn
=
getIndicePairsDeConv
<
Index
,
IndexGrid
,
NDim
>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
Index
>
(
indicesOut
),
tv
::
torch2tv
<
IndexGrid
>
(
gridsOut
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
ks
.
data
(),
st
.
data
(),
pa
.
data
(),
di
.
data
(),
ou
.
data
());
else
numActIn
=
getIndicePairsConv
<
Index
,
IndexGrid
,
NDim
>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
Index
>
(
indicesOut
),
tv
::
torch2tv
<
IndexGrid
>
(
gridsOut
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
ks
.
data
(),
st
.
data
(),
pa
.
data
(),
di
.
data
(),
ou
.
data
());
});
});
return
numActIn
;
}
int
create_submconv_indice_pair_cpu
(
torch
::
Tensor
indicesIn
,
torch
::
Tensor
gridsOut
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
std
::
vector
<
int64_t
>
kernelSize
,
std
::
vector
<
int64_t
>
stride
,
std
::
vector
<
int64_t
>
padding
,
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outSpatialShape
,
bool
transpose
,
bool
resetGrid
,
bool
useHash
)
{
auto
ndim
=
outSpatialShape
.
size
();
auto
numActIn
=
indicesIn
.
size
(
0
);
int
batchSize
=
gridsOut
.
size
(
0
);
auto
kernelVolume
=
indicePairs
.
size
(
0
);
if
(
numActIn
==
0
)
return
0
;
tv
::
dispatch_torch
<
int32_t
,
int64_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
using
Index
=
decltype
(
V
);
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
di
(
dilation
.
begin
(),
dilation
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
ou
(
outSpatialShape
.
begin
(),
outSpatialShape
.
end
());
numActIn
=
getIndicePairsSubM
<
Index
,
IndexGrid
,
NDim
>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
IndexGrid
>
(
gridsOut
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
ks
.
data
(),
st
.
data
(),
pa
.
data
(),
di
.
data
(),
ou
.
data
());
});
});
return
numActIn
;
}
namespace
functor
{
template
<
typename
Index
,
typename
IndexGrid
,
unsigned
NDim
>
struct
CreateConvIndicePairFunctor
<
tv
::
CPU
,
Index
,
IndexGrid
,
NDim
>
{
...
...
src/spconv/indice.cu
View file @
6c767a51
...
...
@@ -38,37 +38,43 @@ int create_conv_indice_pair_p1_cuda(
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
ndim
=
kernelSize
.
size
();
auto
numActIn
=
indicesIn
.
size
(
0
);
auto
kernelVolume
=
indicePairs
.
size
(
0
);
if
(
numActIn
==
0
)
return
0
;
// dispatch_torch must be in outside, this is a gcc bug, fixed in gcc 8.
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
using
Index
=
decltype
(
V
);
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
I
;
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
di
(
dilation
.
begin
(),
dilation
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
ou
(
outSpatialShape
.
begin
(),
outSpatialShape
.
end
());
if
(
transpose
)
{
prepareDeConvIndicePairsKernel
<
Index
,
NDim
,
4096
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
tv
::
torch2tv
<
Index
>
(
indicePairUnique
),
ks
,
st
,
pa
,
di
,
ou
);
}
else
{
prepareIndicePairsKernel
<
Index
,
NDim
,
4096
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
tv
::
torch2tv
<
Index
>
(
indicePairUnique
),
ks
,
st
,
pa
,
di
,
ou
);
}
tv
::
dispatch_int
<
16
,
32
,
256
,
4096
>
(
kernelVolume
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
constexpr
int
MaxKernelVolume
=
decltype
(
I2
)
::
value
;
if
(
transpose
)
{
prepareDeConvIndicePairsKernel
<
Index
,
NDim
,
MaxKernelVolume
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
tv
::
torch2tv
<
Index
>
(
indicePairUnique
),
ks
,
st
,
pa
,
di
,
ou
);
TV_CHECK_CUDA_ERR_V2
(
"prepareDeConvIndicePairsKernel failed"
);
}
else
{
prepareIndicePairsKernel
<
Index
,
NDim
,
MaxKernelVolume
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
tv
::
torch2tv
<
Index
>
(
indicePairUnique
),
ks
,
st
,
pa
,
di
,
ou
);
TV_CHECK_CUDA_ERR_V2
(
"prepareIndicePairsKernel failed"
);
}
});
});
});
return
1
;
...
...
@@ -88,12 +94,11 @@ int create_conv_indice_pair_p2_cuda(
auto
kernelVolume
=
indicePairs
.
size
(
0
);
if
(
numActIn
==
0
)
return
0
;
// dispatch_torch must be in outside, this is a gcc bug, fixed in gcc 8.
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
using
Index
=
decltype
(
V
);
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
I
;
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
using
IndexGrid
=
int32_t
;
tv
::
SimpleVector
<
Index
,
NDim
>
ou
(
outSpatialShape
.
begin
(),
outSpatialShape
.
end
());
...
...
@@ -122,6 +127,8 @@ int create_conv_indice_pair_p2_cuda(
<<<
tv
::
cuda
::
getBlocks
(
numAct
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesOut
),
numAct
,
tv
::
torch2tv
<
Index
>
(
indicePairUnique
),
ou
,
batchSize
);
TV_CHECK_CUDA_ERR_V2
(
"assignIndiceOutKernel failed"
);
auto
tableSize
=
table
.
get_table_size
();
auto
tableData
=
table
.
data
();
auto
constants
=
table
.
get_constants_4
();
...
...
@@ -133,6 +140,7 @@ int create_conv_indice_pair_p2_cuda(
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indicePairUnique
),
tableSize
,
tableData
,
constants
,
stash_constants
,
stash_count
);
TV_CHECK_CUDA_ERR_V2
(
"assignIndicePairsHashKernel failed"
);
}
else
{
assignGridAndIndiceOutKernel
<
Index
,
IndexGrid
,
NDim
>
...
...
@@ -145,7 +153,7 @@ int create_conv_indice_pair_p2_cuda(
assignIndicePairsKernel
<
Index
,
IndexGrid
,
NDim
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesOut
),
tv
::
torch2tv
<
IndexGrid
>
(
gridsOut
),
numAct
,
tv
::
torch2tv
<
IndexGrid
>
(
gridsOut
),
numAct
In
,
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indicePairUnique
),
ou
);
TV_CHECK_CUDA_ERR_V2
(
"assignIndicePairsKernel failed"
);
...
...
@@ -177,11 +185,11 @@ int create_submconv_indice_pair_cuda(
auto
kernelVolume
=
indicePairs
.
size
(
0
);
if
(
numActIn
==
0
)
return
0
;
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
V
)
{
using
Index
=
decltype
(
V
);
tv
::
dispatch_torch
<
int32_t
>
(
indicesIn
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
using
IndexGrid
=
int32_t
;
tv
::
dispatch_int
<
2
,
3
,
4
>
(
ndim
,
[
&
](
auto
I
)
{
constexpr
int
NDim
=
I
;
constexpr
int
NDim
=
decltype
(
I
)
::
value
;
tv
::
SimpleVector
<
Index
,
NDim
>
ks
(
kernelSize
.
begin
(),
kernelSize
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
st
(
stride
.
begin
(),
stride
.
end
());
tv
::
SimpleVector
<
Index
,
NDim
>
pa
(
padding
.
begin
(),
padding
.
end
());
...
...
@@ -214,26 +222,36 @@ int create_submconv_indice_pair_cuda(
auto
constants
=
table
.
get_constants_4
();
auto
stash_constants
=
table
.
get_stash_constants
();
auto
stash_count
=
table
.
get_stash_count
();
getSubMIndicePairsHashKernel
<
Index
,
NDim
,
4096
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
ks
,
st
,
pa
,
di
,
ou
,
tableSize
,
tableData
,
constants
,
stash_constants
,
stash_count
);
tv
::
dispatch_int
<
16
,
32
,
256
,
4096
>
(
kernelVolume
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
constexpr
int
MaxKernelVolume
=
decltype
(
I2
)
::
value
;
getSubMIndicePairsHashKernel
<
Index
,
NDim
,
MaxKernelVolume
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
ks
,
st
,
pa
,
di
,
ou
,
tableSize
,
tableData
,
constants
,
stash_constants
,
stash_count
);
TV_CHECK_CUDA_ERR_V2
(
"getSubMIndicePairsHashKernel failed"
);
});
}
else
{
prepareSubMGridKernel
<
Index
,
IndexGrid
,
NDim
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
IndexGrid
>
(
gridsOut
),
ou
);
TV_CHECK_CUDA_ERR_V2
(
"prepareSubMGridKernel failed"
);
getSubMIndicePairsKernel
<
Index
,
IndexGrid
,
NDim
,
4096
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
IndexGrid
>
(
gridsOut
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
ks
,
st
,
pa
,
di
,
ou
);
TV_CHECK_CUDA_ERR_V2
(
"assignIndicePairsKernel failed"
);
tv
::
dispatch_int
<
16
,
32
,
256
,
4096
>
(
ndim
,
std
::
less_equal
<
int
>
(),
[
&
](
auto
I2
)
{
constexpr
int
MaxKernelVolume
=
decltype
(
I2
)
::
value
;
getSubMIndicePairsKernel
<
Index
,
IndexGrid
,
NDim
,
MaxKernelVolume
>
<<<
tv
::
cuda
::
getBlocks
(
numActIn
),
tv
::
cuda
::
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
tv
::
torch2tv
<
Index
>
(
indicesIn
),
tv
::
torch2tv
<
IndexGrid
>
(
gridsOut
),
tv
::
torch2tv
<
Index
>
(
indicePairs
),
tv
::
torch2tv
<
Index
>
(
indiceNum
),
ks
,
st
,
pa
,
di
,
ou
);
TV_CHECK_CUDA_ERR_V2
(
"assignIndicePairsKernel failed"
);
});
}
if
(
resetGrid
&&
(
!
useHash
))
{
...
...
src/spconv/reordering.cc
View file @
6c767a51
...
...
@@ -14,59 +14,60 @@
#include <ATen/Parallel.h>
#include <spconv/reordering.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h>
namespace
spconv
{
namespace
functor
{
template
<
typename
T
,
typename
Index
>
struct
SparseGatherFunctor
<
tv
::
CPU
,
T
,
Index
>
{
void
operator
()(
const
tv
::
CPU
&
d
,
tv
::
TensorView
<
T
>
buffer
,
tv
::
TensorView
<
const
T
>
features
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
)
{
int
numPlanes
=
features
.
dim
(
1
);
at
::
parallel_for
(
0
,
size
,
0
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int
i
=
begin
;
i
<
end
;
++
i
)
{
std
::
memcpy
(
buffer
.
data
()
+
i
*
numPlanes
,
features
.
data
()
+
indices
[
i
]
*
numPlanes
,
sizeof
(
T
)
*
numPlanes
);
}
});
}
};
template
<
typename
T
,
typename
Index
>
struct
SparseScatterAddFunctor
<
tv
::
CPU
,
T
,
Index
>
{
void
operator
()(
const
tv
::
CPU
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
buffer
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
,
bool
stable
)
{
int
numPlanes
=
outFeatures
.
dim
(
1
);
const
T
*
buf
=
buffer
.
data
();
T
*
out
=
outFeatures
.
data
();
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
buf
=
buffer
.
data
()
+
i
*
numPlanes
;
out
=
outFeatures
.
data
()
+
indices
[
i
]
*
numPlanes
;
for
(
int
j
=
0
;
j
<
numPlanes
;
++
j
)
{
out
[
j
]
+=
buf
[
j
];
}
}
}
};
}
// namespace functor
using
float_types_t
=
tv
::
mp_list
<
float
,
double
,
at
::
Half
>
;
using
int_types_t
=
tv
::
mp_list
<
int32_t
,
int64_t
>
;
#define DECLARE_CPU_SPECS_T_INDEX(T, Index) \
template struct functor::SparseGatherFunctor<tv::CPU, T, Index>; \
template struct functor::SparseScatterAddFunctor<tv::CPU, T, Index>;
#define DECLARE_CPU_SPECS(T) \
DECLARE_CPU_SPECS_T_INDEX(T, int); \
DECLARE_CPU_SPECS_T_INDEX(T, long);
void
sparse_gather_cpu
(
torch
::
Tensor
buffer
,
torch
::
Tensor
features
,
torch
::
Tensor
indices
,
int
size
)
{
int
numPlanes
=
features
.
size
(
1
);
auto
dtype
=
features
.
scalar_type
();
auto
int_dtype
=
indices
.
scalar_type
();
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
Index
*
indices_data
=
indices
.
data_ptr
<
Index
>
();
T
*
buffer_data
=
buffer
.
data_ptr
<
T
>
();
const
T
*
features_data
=
features
.
data_ptr
<
T
>
();
at
::
parallel_for
(
0
,
size
,
0
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int
i
=
begin
;
i
<
end
;
++
i
)
{
std
::
memcpy
(
buffer_data
+
i
*
numPlanes
,
features_data
+
indices_data
[
i
]
*
numPlanes
,
sizeof
(
T
)
*
numPlanes
);
}
});
});
});
}
DECLARE_CPU_SPECS
(
float
);
DECLARE_CPU_SPECS
(
double
);
DECLARE_CPU_SPECS
(
at
::
Half
);
void
sparse_scatter_add_cpu
(
torch
::
Tensor
buffer
,
torch
::
Tensor
outFeatures
,
torch
::
Tensor
indices
,
int
size
)
{
int
numPlanes
=
outFeatures
.
size
(
1
);
auto
dtype
=
outFeatures
.
scalar_type
();
auto
int_dtype
=
indices
.
scalar_type
();
#undef DECLARE_CPU_SPECS
#undef DECLARE_CPU_SPECS_T_INDEX
tv
::
DispatchTorch
<
float_types_t
>
()(
dtype
,
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
tv
::
DispatchTorch
<
int_types_t
>
()(
int_dtype
,
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
Index
*
indices_data
=
indices
.
data_ptr
<
Index
>
();
const
T
*
buffer_data
=
buffer
.
data_ptr
<
T
>
();
T
*
features_data
=
outFeatures
.
data_ptr
<
T
>
();
const
T
*
buf
=
buffer
.
data_ptr
<
T
>
();
T
*
out
=
outFeatures
.
data_ptr
<
T
>
();
for
(
int
i
=
0
;
i
<
size
;
++
i
)
{
buf
=
buffer_data
+
i
*
numPlanes
;
out
=
features_data
+
indices_data
[
i
]
*
numPlanes
;
for
(
int
j
=
0
;
j
<
numPlanes
;
++
j
)
{
out
[
j
]
+=
buf
[
j
];
}
}
});
});
}
}
// namespace spconv
src/spconv/reordering.cu
View file @
6c767a51
...
...
@@ -20,137 +20,148 @@
#include <tensorview/cuda_utils.h>
#include <tensorview/kernel_utils.h>
#include <tensorview/mp_helper.h>
#include <tensorview/tensor.h>
#include <tensorview/tensorview.h>
#include <tensorview/torch_utils.h>
#include <type_traits>
#include <utility/timer.h>
namespace
spconv
{
namespace
functor
{
template
<
typename
T
,
typename
Index
>
struct
SparseGatherFunctor
<
tv
::
GPU
,
T
,
Index
>
{
using
vecload_type_t
=
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int2
,
int4
>
;
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
void
operator
()(
const
tv
::
GPU
&
d
,
tv
::
TensorView
<
T
>
buffer
,
tv
::
TensorView
<
const
T
>
features
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
)
{
if
(
size
<=
0
)
return
;
int
numPlanes
=
features
.
dim
(
1
);
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
tv
::
mp_for_each
<
kernel_block_t
>
([
=
,
&
buffer
,
&
features
,
&
indices
,
&
notFound
](
auto
NumTLP
)
{
constexpr
int
NumILP
=
NumTLP
/
4
;
// constexpr int NumILP = NumTLP / (64 / (NumTLP / vecloadFactor));
int
nHotBlock
=
(
size
/
NumTLP
)
*
NumTLP
;
if
(
notFound
)
{
if
(
numPlanes
%
NumTLP
==
0
)
{
if
(
nHotBlock
>=
NumTLP
)
{
gatherVecBlockKernel
<
T
,
Index
,
int
(
NumTLP
),
NumILP
,
vecload_type_t
>
<<<
dim3
(
numPlanes
/
NumTLP
,
size
/
NumTLP
),
dim3
(
NumTLP
/
vecloadFactor
,
NumTLP
/
NumILP
),
0
,
d
.
getStream
()
>>>
(
buffer
.
data
(),
features
.
data
(),
indices
.
data
(),
nHotBlock
,
numPlanes
/
vecloadFactor
);
TV_CHECK_CUDA_ERR
();
}
if
(
size
-
nHotBlock
>
0
)
{
gatherVecKernel
<
T
,
Index
,
int
(
NumTLP
),
NumILP
,
vecload_type_t
>
<<<
dim3
(
1
,
numPlanes
/
NumTLP
),
dim3
(
NumTLP
/
NumILP
,
NumTLP
/
vecloadFactor
),
0
,
d
.
getStream
()
>>>
(
buffer
.
data
()
+
nHotBlock
*
numPlanes
,
features
.
data
(),
indices
.
data
()
+
nHotBlock
,
size
-
nHotBlock
,
numPlanes
/
vecloadFactor
);
TV_CHECK_CUDA_ERR
();
}
notFound
=
false
;
}
}
});
void
sparse_gather_cuda
(
torch
::
Tensor
buffer
,
torch
::
Tensor
features
,
torch
::
Tensor
indices
,
int
size
)
{
if
(
size
<=
0
)
return
;
int
numPlanes
=
features
.
size
(
1
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
notFound
)
{
constexpr
int
NumTLP
=
64
;
constexpr
int
NumILP
=
NumTLP
/
4
;
gatherGenericKernel
<
T
,
Index
,
NumTLP
,
NumILP
>
<<<
dim3
(
tv
::
cuda
::
DivUp
(
size
,
NumTLP
),
tv
::
cuda
::
DivUp
(
numPlanes
,
NumTLP
)),
dim3
(
NumTLP
/
NumILP
,
NumTLP
),
0
,
d
.
getStream
()
>>>
(
buffer
.
data
(),
features
.
data
(),
indices
.
data
(),
size
,
numPlanes
);
TV_CHECK_CUDA_ERR
();
}
}
};
template
<
typename
T
,
typename
Index
>
struct
SparseScatterAddFunctor
<
tv
::
GPU
,
T
,
Index
>
{
using
vecload_type_t
=
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int2
,
int4
>
;
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
void
operator
()(
const
tv
::
GPU
&
d
,
tv
::
TensorView
<
T
>
outFeatures
,
tv
::
TensorView
<
const
T
>
buffer
,
tv
::
TensorView
<
const
Index
>
indices
,
int
size
,
bool
stable
)
{
if
(
size
<=
0
)
return
;
int
numPlanes
=
outFeatures
.
dim
(
1
);
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
// important for half.
tv
::
mp_for_each
<
kernel_block_t
>
([
=
,
&
d
,
&
outFeatures
,
&
buffer
,
&
indices
,
&
notFound
](
auto
NumTLP
)
{
// constexpr int NumILP = NumTLP / (64 / (NumTLP / vecloadFactor));
constexpr
int
NumILP
=
NumTLP
/
4
;
int
nHotBlock
=
(
size
/
NumTLP
)
*
NumTLP
;
if
(
notFound
)
{
if
(
numPlanes
%
NumTLP
==
0
)
{
if
(
nHotBlock
>=
NumTLP
)
{
scatterAddVecBlockKernel
<
T
,
Index
,
int
(
NumTLP
),
NumILP
,
vecload_type_t
>
<<<
dim3
(
numPlanes
/
NumTLP
,
size
/
NumTLP
),
dim3
(
NumTLP
/
vecloadFactor
,
NumTLP
/
NumILP
),
0
,
d
.
getStream
()
>>>
(
outFeatures
.
data
(),
buffer
.
data
(),
indices
.
data
(),
nHotBlock
,
numPlanes
/
vecloadFactor
);
TV_CHECK_CUDA_ERR
();
}
if
(
size
-
nHotBlock
>
0
)
{
scatterAddGenericKernel
<
T
,
Index
,
int
(
NumTLP
),
NumILP
>
<<<
dim3
(
1
,
numPlanes
/
NumTLP
),
dim3
(
NumTLP
/
NumILP
,
NumTLP
),
0
,
d
.
getStream
()
>>>
(
outFeatures
.
data
(),
buffer
.
data
()
+
nHotBlock
*
numPlanes
,
indices
.
data
()
+
nHotBlock
,
size
-
nHotBlock
,
numPlanes
);
tv
::
dispatch_torch
<
float
,
double
,
at
::
Half
>
(
features
.
scalar_type
(),
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
vecload_type_t
=
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int2
,
int4
>
;
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
tv
::
dispatch_torch
<
int32_t
,
int64_t
>
(
indices
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
tv
::
mp_for_each
<
kernel_block_t
>
([
=
,
&
buffer
,
&
features
,
&
indices
,
&
notFound
](
auto
NumTLP
)
{
constexpr
int
NumILP
=
NumTLP
/
4
;
// constexpr int NumILP = NumTLP / (64 / (NumTLP / vecloadFactor));
int
nHotBlock
=
(
size
/
NumTLP
)
*
NumTLP
;
if
(
notFound
)
{
if
(
numPlanes
%
NumTLP
==
0
)
{
if
(
nHotBlock
>=
NumTLP
)
{
gatherVecBlockKernel
<
T
,
Index
,
int
(
NumTLP
),
NumILP
,
vecload_type_t
>
<<<
dim3
(
numPlanes
/
NumTLP
,
size
/
NumTLP
),
dim3
(
NumTLP
/
vecloadFactor
,
NumTLP
/
NumILP
),
0
,
stream
>>>
(
buffer
.
data_ptr
<
T
>
(),
features
.
data_ptr
<
T
>
(),
indices
.
data_ptr
<
Index
>
(),
nHotBlock
,
numPlanes
/
vecloadFactor
);
TV_CHECK_CUDA_ERR
();
}
if
(
size
-
nHotBlock
>
0
)
{
gatherVecKernel
<
T
,
Index
,
int
(
NumTLP
),
NumILP
,
vecload_type_t
>
<<<
dim3
(
1
,
numPlanes
/
NumTLP
),
dim3
(
NumTLP
/
NumILP
,
NumTLP
/
vecloadFactor
),
0
,
stream
>>>
(
buffer
.
data_ptr
<
T
>
()
+
nHotBlock
*
numPlanes
,
features
.
data_ptr
<
T
>
(),
indices
.
data_ptr
<
Index
>
()
+
nHotBlock
,
size
-
nHotBlock
,
numPlanes
/
vecloadFactor
);
TV_CHECK_CUDA_ERR
();
}
notFound
=
false
;
}
}
});
if
(
notFound
)
{
constexpr
int
NumTLP
=
64
;
constexpr
int
NumILP
=
NumTLP
/
4
;
gatherGenericKernel
<
T
,
Index
,
NumTLP
,
NumILP
>
<<<
dim3
(
tv
::
cuda
::
DivUp
(
size
,
NumTLP
),
tv
::
cuda
::
DivUp
(
numPlanes
,
NumTLP
)),
dim3
(
NumTLP
/
NumILP
,
NumTLP
),
0
,
stream
>>>
(
buffer
.
data_ptr
<
T
>
(),
features
.
data_ptr
<
T
>
(),
indices
.
data_ptr
<
Index
>
(),
size
,
numPlanes
);
TV_CHECK_CUDA_ERR
();
}
notFound
=
false
;
}
}
});
if
(
notFound
)
{
constexpr
int
NumTLP
=
64
;
constexpr
int
NumILP
=
NumTLP
/
4
;
scatterAddGenericKernel
<
T
,
Index
,
NumTLP
,
NumILP
>
<<<
dim3
(
tv
::
cuda
::
DivUp
(
size
,
NumTLP
),
tv
::
cuda
::
DivUp
(
numPlanes
,
NumTLP
)),
dim3
(
NumTLP
/
NumILP
,
NumTLP
),
0
,
d
.
getStream
()
>>>
(
outFeatures
.
data
(),
buffer
.
data
(),
indices
.
data
(),
size
,
numPlanes
);
TV_CHECK_CUDA_ERR
();
}
}
};
}
// namespace functor
});
});
}
#define DECLARE_GPU_SPECS_T_INDEX(T, Index) \
template struct functor::SparseGatherFunctor<tv::GPU, T, Index>; \
template struct functor::SparseScatterAddFunctor<tv::GPU, T, Index>;
void
sparse_scatter_add_cuda
(
torch
::
Tensor
buffer
,
torch
::
Tensor
outFeatures
,
torch
::
Tensor
indices
,
int
size
)
{
if
(
size
<=
0
)
return
;
int
numPlanes
=
outFeatures
.
size
(
1
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPECS_T_INDEX(T, int);
tv
::
dispatch_torch
<
float
,
double
,
at
::
Half
>
(
outFeatures
.
scalar_type
(),
[
&
](
auto
TValue
)
{
using
T
=
decltype
(
TValue
);
using
vecload_type_t
=
std
::
conditional_t
<
std
::
is_same
<
T
,
at
::
Half
>::
value
,
int2
,
int4
>
;
using
kernel_block_t
=
tv
::
mp_list_c
<
int
,
64
,
32
,
16
>
;
DECLARE_GPU_SPECS
(
float
);
DECLARE_GPU_SPECS
(
double
);
DECLARE_GPU_SPECS
(
at
::
Half
);
tv
::
dispatch_torch
<
int32_t
,
int64_t
>
(
indices
.
scalar_type
(),
[
&
](
auto
IndexValue
)
{
using
Index
=
decltype
(
IndexValue
);
bool
notFound
=
true
;
constexpr
int
vecloadFactor
=
sizeof
(
vecload_type_t
)
/
sizeof
(
T
);
// important for half.
tv
::
mp_for_each
<
kernel_block_t
>
(
[
=
,
&
outFeatures
,
&
buffer
,
&
indices
,
&
notFound
](
auto
NumTLP
)
{
// constexpr int NumILP = NumTLP / (64 / (NumTLP /
// vecloadFactor));
constexpr
int
NumILP
=
NumTLP
/
4
;
int
nHotBlock
=
(
size
/
NumTLP
)
*
NumTLP
;
if
(
notFound
)
{
if
(
numPlanes
%
NumTLP
==
0
)
{
if
(
nHotBlock
>=
NumTLP
)
{
scatterAddVecBlockKernel
<
T
,
Index
,
int
(
NumTLP
),
NumILP
,
vecload_type_t
>
<<<
dim3
(
numPlanes
/
NumTLP
,
size
/
NumTLP
),
dim3
(
NumTLP
/
vecloadFactor
,
NumTLP
/
NumILP
),
0
,
stream
>>>
(
outFeatures
.
data_ptr
<
T
>
(),
buffer
.
data_ptr
<
T
>
(),
indices
.
data_ptr
<
Index
>
(),
nHotBlock
,
numPlanes
/
vecloadFactor
);
TV_CHECK_CUDA_ERR
();
}
if
(
size
-
nHotBlock
>
0
)
{
scatterAddGenericKernel
<
T
,
Index
,
int
(
NumTLP
),
NumILP
>
<<<
dim3
(
1
,
numPlanes
/
NumTLP
),
dim3
(
NumTLP
/
NumILP
,
NumTLP
),
0
,
stream
>>>
(
outFeatures
.
data_ptr
<
T
>
(),
buffer
.
data_ptr
<
T
>
()
+
nHotBlock
*
numPlanes
,
indices
.
data_ptr
<
Index
>
()
+
nHotBlock
,
size
-
nHotBlock
,
numPlanes
);
TV_CHECK_CUDA_ERR
();
}
notFound
=
false
;
}
}
});
if
(
notFound
)
{
constexpr
int
NumTLP
=
64
;
constexpr
int
NumILP
=
NumTLP
/
4
;
scatterAddGenericKernel
<
T
,
Index
,
NumTLP
,
NumILP
>
<<<
dim3
(
tv
::
cuda
::
DivUp
(
size
,
NumTLP
),
tv
::
cuda
::
DivUp
(
numPlanes
,
NumTLP
)),
dim3
(
NumTLP
/
NumILP
,
NumTLP
),
0
,
stream
>>>
(
outFeatures
.
data_ptr
<
T
>
(),
buffer
.
data_ptr
<
T
>
(),
indices
.
data_ptr
<
Index
>
(),
size
,
numPlanes
);
TV_CHECK_CUDA_ERR
();
}
});
});
}
#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_T_INDEX
}
// namespace spconv
\ No newline at end of file
src/spconv/spconv_ops.cc
View file @
6c767a51
#include <spconv/spconv_ops.h>
namespace
spconv
{
std
::
vector
<
torch
::
Tensor
>
getIndicePairV2
(
torch
::
Tensor
indices
,
int64_t
batchSize
,
std
::
vector
<
int64_t
>
outSpatialShape
,
std
::
vector
<
int64_t
>
spatialShape
,
std
::
vector
<
int64_t
>
kernelSize
,
std
::
vector
<
int64_t
>
stride
,
std
::
vector
<
int64_t
>
padding
,
std
::
vector
<
int64_t
>
dilation
,
std
::
vector
<
int64_t
>
outPadding
,
int64_t
_subM
,
int64_t
_transpose
,
int64_t
_useHash
)
{
// auto timer = spconv::CudaContextTimer<>();
bool
subM
=
_subM
!=
0
;
bool
transpose
=
_transpose
!=
0
;
auto
NDim
=
kernelSize
.
size
();
// CPU always use hash (tsl::robin_map).
bool
useHash
=
_useHash
!=
0
||
indices
.
device
().
type
()
==
torch
::
kCPU
;
auto
numAct
=
indices
.
size
(
0
);
auto
coorDim
=
indices
.
size
(
1
)
-
1
;
// batchIdx + xyz
TV_ASSERT_RT_ERR
(
NDim
==
coorDim
,
"error"
);
TV_ASSERT_RT_ERR
(
kernelSize
.
size
()
==
coorDim
,
"error"
);
TV_ASSERT_RT_ERR
(
outSpatialShape
.
size
()
==
coorDim
,
"error"
);
TV_ASSERT_RT_ERR
(
stride
.
size
()
==
coorDim
,
"error"
);
TV_ASSERT_RT_ERR
(
padding
.
size
()
==
coorDim
,
"error"
);
TV_ASSERT_RT_ERR
(
outPadding
.
size
()
==
coorDim
,
"error"
);
TV_ASSERT_RT_ERR
(
dilation
.
size
()
==
coorDim
,
"error"
);
auto
kernelVolume
=
kernelSize
[
0
];
for
(
int
i
=
1
;
i
<
kernelSize
.
size
();
++
i
)
{
kernelVolume
*=
kernelSize
[
i
];
}
TV_ASSERT_RT_ERR
(
kernelVolume
<=
4096
,
"error"
);
auto
outputVolume
=
outSpatialShape
[
0
];
for
(
int
i
=
1
;
i
<
outSpatialShape
.
size
();
++
i
)
{
outputVolume
*=
outSpatialShape
[
i
];
}
std
::
string
msg
=
"due to limits of cuda hash, the volume of dense space "
"include batch size "
;
msg
+=
"must less than std::numeric_limits<int>::max() = 2e9"
;
TV_ASSERT_RT_ERR
(
batchSize
*
outputVolume
<
std
::
numeric_limits
<
int
>::
max
(),
msg
);
torch
::
Tensor
indicePairs
=
torch
::
full
({
kernelVolume
,
2
,
numAct
},
-
1
,
torch
::
dtype
(
torch
::
kInt32
).
device
(
indices
.
device
()));
torch
::
Tensor
indiceNum
=
torch
::
zeros
(
{
kernelVolume
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
indices
.
device
()));
auto
gridSize
=
batchSize
*
outputVolume
;
if
(
useHash
)
{
gridSize
=
batchSize
;
}
torch
::
Tensor
gridOut
=
torch
::
full
(
{
gridSize
},
-
1
,
torch
::
dtype
(
torch
::
kInt32
).
device
(
indices
.
device
()));
gridOut
=
gridOut
.
view
({
batchSize
,
-
1
});
int64_t
numActOut
=
-
1
;
for
(
int
i
=
0
;
i
<
NDim
;
++
i
)
{
if
(
subM
)
{
padding
[
i
]
=
kernelSize
[
i
]
/
2
;
stride
[
i
]
=
1
;
}
}
if
(
subM
)
{
if
(
indices
.
device
().
type
()
==
torch
::
kCPU
)
{
numActOut
=
create_submconv_indice_pair_cpu
(
indices
,
gridOut
,
indicePairs
,
indiceNum
,
kernelSize
,
stride
,
padding
,
dilation
,
outSpatialShape
,
transpose
,
false
,
useHash
);
}
#ifdef TV_CUDA
else
if
(
indices
.
device
().
type
()
==
torch
::
kCUDA
)
{
numActOut
=
create_submconv_indice_pair_cuda
(
indices
,
gridOut
,
indicePairs
,
indiceNum
,
kernelSize
,
stride
,
padding
,
dilation
,
outSpatialShape
,
transpose
,
false
,
useHash
);
}
#endif
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
return
{
indices
,
indicePairs
,
indiceNum
};
}
else
{
auto
indicePairUnique
=
torch
::
full
(
{
indicePairs
.
numel
()
/
2
+
1
},
std
::
numeric_limits
<
int
>::
max
(),
torch
::
dtype
(
torch
::
kInt32
).
device
(
indices
.
device
()));
torch
::
Tensor
outInds
=
torch
::
zeros
({
numAct
*
kernelVolume
,
coorDim
+
1
},
torch
::
dtype
(
torch
::
kInt32
).
device
(
indices
.
device
()));
if
(
indices
.
device
().
type
()
==
torch
::
kCPU
)
{
numActOut
=
create_conv_indice_pair_cpu
(
indices
,
outInds
,
gridOut
,
indicePairs
,
indiceNum
,
kernelSize
,
stride
,
padding
,
dilation
,
outSpatialShape
,
transpose
,
false
,
useHash
);
}
#ifdef TV_CUDA
else
if
(
indices
.
device
().
type
()
==
torch
::
kCUDA
)
{
numActOut
=
create_conv_indice_pair_p1_cuda
(
indices
,
indicePairs
,
indiceNum
,
indicePairUnique
,
kernelSize
,
stride
,
padding
,
dilation
,
outSpatialShape
,
transpose
);
if
(
numActOut
>
0
)
{
auto
res
=
torch
::
_unique
(
indicePairUnique
);
indicePairUnique
=
std
::
get
<
0
>
(
res
);
numActOut
=
create_conv_indice_pair_p2_cuda
(
indices
,
outInds
,
gridOut
,
indicePairs
,
indiceNum
,
indicePairUnique
,
outSpatialShape
,
transpose
,
false
,
useHash
);
}
}
#endif
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
return
{
outInds
.
slice
(
0
,
0
,
numActOut
),
indicePairs
,
indiceNum
};
}
}
torch
::
Tensor
indiceConv
(
torch
::
Tensor
features
,
torch
::
Tensor
filters
,
torch
::
Tensor
indicePairs
,
torch
::
Tensor
indiceNum
,
int64_t
numActOut
,
int64_t
_inverse
,
int64_t
_subM
)
{
...
...
@@ -47,81 +153,59 @@ torch::Tensor indiceConv(torch::Tensor features, torch::Tensor filters,
double
totalGatherTime
=
0
;
double
totalGEMMTime
=
0
;
double
totalSAddTime
=
0
;
tv
::
dispatch_torch
<
float
,
double
,
at
::
Half
>
(
features
.
scalar_type
(),
[
&
](
auto
I
)
{
using
T
=
decltype
(
I
);
for
(
int
i
=
0
;
i
<
kernelVolume
;
++
i
)
{
auto
nHot
=
indicePairNumCpu
.
data_ptr
<
int
>
()[
i
];
if
(
nHot
<=
0
||
(
subM
&&
i
==
indicePairMaxOffset
))
{
continue
;
}
// auto timer = spconv::CudaContextTimer<>();
auto
outputBufferBlob
=
torch
::
from_blob
(
outputBuffer
.
data_ptr
<
T
>
(),
{
nHot
,
numOutPlanes
},
options
);
auto
inputBufferBlob
=
torch
::
from_blob
(
inputBuffer
.
data_ptr
<
T
>
(),
{
nHot
,
numInPlanes
},
options
);
if
(
device
==
torch
::
kCPU
)
{
functor
::
SparseGatherFunctor
<
tv
::
CPU
,
T
,
int
>
gatherFtor
;
gatherFtor
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
T
>
(
features
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
}
for
(
int
i
=
0
;
i
<
kernelVolume
;
++
i
)
{
auto
nHot
=
indicePairNumCpu
.
data_ptr
<
int
>
()[
i
];
if
(
nHot
<=
0
||
(
subM
&&
i
==
indicePairMaxOffset
))
{
continue
;
}
// auto timer = spconv::CudaContextTimer<>();
auto
outputBufferBlob
=
torch
::
from_blob
(
outputBuffer
.
data_ptr
(),
{
nHot
,
numOutPlanes
},
options
);
auto
inputBufferBlob
=
torch
::
from_blob
(
inputBuffer
.
data_ptr
(),
{
nHot
,
numInPlanes
},
options
);
if
(
device
==
torch
::
kCPU
)
{
sparse_gather_cpu
(
inputBuffer
,
features
,
indicePairs
[
i
][
inverse
],
nHot
);
}
#ifdef TV_CUDA
else
if
(
device
==
torch
::
kCUDA
)
{
functor
::
SparseGatherFunctor
<
tv
::
GPU
,
T
,
int
>
gatherFtor
;
gatherFtor
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
T
>
(
features
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
TV_CHECK_CUDA_ERR
();
/* slower than SparseGatherFunctor, may due to int->long conversion
auto indicePairLong = indicePairs[i][inverse].to(torch::kInt64);
auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(),
{nHot}, indicePairOptions); torch::index_select_out(inputBufferBlob,
features, 0, indicePairBlob);*/
}
else
if
(
device
==
torch
::
kCUDA
)
{
sparse_gather_cuda
(
inputBuffer
,
features
,
indicePairs
[
i
][
inverse
],
nHot
);
/* slower than SparseGatherFunctor, may due to int->long conversion
auto indicePairLong = indicePairs[i][inverse].to(torch::kInt64);
auto indicePairBlob = torch::from_blob(indicePairLong.data<long>(),
{nHot}, indicePairOptions); torch::index_select_out(inputBufferBlob,
features, 0, indicePairBlob);*/
}
#endif
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
// totalGatherTime += timer.report() / 1000.0;
torch
::
mm_out
(
outputBufferBlob
,
inputBufferBlob
,
filters
[
i
]);
// totalGEMMTime += timer.report() / 1000.0;
if
(
device
==
torch
::
kCPU
)
{
functor
::
SparseScatterAddFunctor
<
tv
::
CPU
,
T
,
int
>
scatterFtor
;
scatterFtor
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
output
),
tv
::
torch2tv
<
const
T
>
(
outputBuffer
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
,
true
);
}
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
// totalGatherTime += timer.report() / 1000.0;
torch
::
mm_out
(
outputBufferBlob
,
inputBufferBlob
,
filters
[
i
]);
// totalGEMMTime += timer.report() / 1000.0;
if
(
device
==
torch
::
kCPU
)
{
sparse_scatter_add_cpu
(
outputBuffer
,
output
,
indicePairs
[
i
][
!
inverse
],
nHot
);
}
#ifdef TV_CUDA
else
if
(
device
==
torch
::
kCUDA
)
{
functor
::
SparseScatterAddFunctor
<
tv
::
GPU
,
T
,
int
>
scatterFtor
;
scatterFtor
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
output
),
tv
::
torch2tv
<
const
T
>
(
outputBuffer
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
,
true
);
TV_CHECK_CUDA_ERR
();
}
else
if
(
device
==
torch
::
kCUDA
)
{
sparse_scatter_add_cuda
(
outputBuffer
,
output
,
indicePairs
[
i
][
!
inverse
],
nHot
);
}
#endif
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
// totalSAddTime += timer.report() / 1000.0;
}
});
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
// totalSAddTime += timer.report() / 1000.0;
}
// std::cout << "gather time " << totalGatherTime << std::endl;
// std::cout << "gemm time " << totalGEMMTime << std::endl;
// std::cout << "scatteradd time " << totalSAddTime << std::endl;
return
output
;
}
std
::
vector
<
torch
::
Tensor
>
indiceConvBackward
(
torch
::
Tensor
features
,
torch
::
Tensor
filters
,
torch
::
Tensor
outGrad
,
torch
::
Tensor
indicePairs
,
...
...
@@ -158,77 +242,47 @@ indiceConvBackward(torch::Tensor features, torch::Tensor filters,
torch
::
mm_out
(
filterGradSub
,
features
.
t
(),
outGrad
);
torch
::
mm_out
(
inputGrad
,
outGrad
,
filters
[
indicePairMaxOffset
].
t
());
}
tv
::
dispatch_torch
<
float
,
double
,
at
::
Half
>
(
features
.
scalar_type
(),
[
&
](
auto
I
)
{
using
T
=
decltype
(
I
);
for
(
int
i
=
0
;
i
<
kernelVolume
;
++
i
)
{
auto
nHot
=
indicePairNumCpu
.
data_ptr
<
int
>
()[
i
];
if
(
nHot
<=
0
||
(
subM
&&
i
==
indicePairMaxOffset
))
{
continue
;
}
if
(
device
==
torch
::
kCPU
)
{
functor
::
SparseGatherFunctor
<
tv
::
CPU
,
T
,
int
>
gatherFtor
;
functor
::
SparseGatherFunctor
<
tv
::
CPU
,
T
,
int
>
gatherFtorOut
;
gatherFtor
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
T
>
(
features
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
gatherFtorOut
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
outputBuffer
),
tv
::
torch2tv
<
const
T
>
(
outGrad
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
);
}
for
(
int
i
=
0
;
i
<
kernelVolume
;
++
i
)
{
auto
nHot
=
indicePairNumCpu
.
data_ptr
<
int
>
()[
i
];
if
(
nHot
<=
0
||
(
subM
&&
i
==
indicePairMaxOffset
))
{
continue
;
}
if
(
device
==
torch
::
kCPU
)
{
sparse_gather_cpu
(
inputBuffer
,
features
,
indicePairs
[
i
][
inverse
],
nHot
);
sparse_gather_cpu
(
outputBuffer
,
outGrad
,
indicePairs
[
i
][
!
inverse
],
nHot
);
}
#ifdef TV_CUDA
else
if
(
device
==
torch
::
kCUDA
)
{
functor
::
SparseGatherFunctor
<
tv
::
GPU
,
T
,
int
>
gatherFtor
;
functor
::
SparseGatherFunctor
<
tv
::
GPU
,
T
,
int
>
gatherFtorOut
;
gatherFtor
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
T
>
(
features
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
TV_CHECK_CUDA_ERR
();
gatherFtorOut
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
outputBuffer
),
tv
::
torch2tv
<
const
T
>
(
outGrad
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
!
inverse
),
nHot
);
TV_CHECK_CUDA_ERR
();
}
else
if
(
device
==
torch
::
kCUDA
)
{
sparse_gather_cuda
(
inputBuffer
,
features
,
indicePairs
[
i
][
inverse
],
nHot
);
sparse_gather_cuda
(
outputBuffer
,
outGrad
,
indicePairs
[
i
][
!
inverse
],
nHot
);
}
#endif
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
auto
filterGradSub
=
filtersGrad
[
i
];
auto
outputBufferBlob
=
torch
::
from_blob
(
outputBuffer
.
data_ptr
<
T
>
(),
{
nHot
,
numOutPlanes
},
options
);
auto
inputBufferBlob
=
torch
::
from_blob
(
inputBuffer
.
data_ptr
<
T
>
(),
{
nHot
,
numInPlanes
},
options
);
torch
::
mm_out
(
filterGradSub
,
inputBufferBlob
.
t
(),
outputBufferBlob
);
torch
::
mm_out
(
inputBufferBlob
,
outputBufferBlob
,
filters
[
i
].
t
());
if
(
device
==
torch
::
kCPU
)
{
functor
::
SparseScatterAddFunctor
<
tv
::
CPU
,
T
,
int
>
scatterFtor
;
scatterFtor
(
tv
::
CPU
(),
tv
::
torch2tv
<
T
>
(
inputGrad
),
tv
::
torch2tv
<
const
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
}
auto
filterGradSub
=
filtersGrad
[
i
];
auto
outputBufferBlob
=
torch
::
from_blob
(
outputBuffer
.
data_ptr
(),
{
nHot
,
numOutPlanes
},
options
);
auto
inputBufferBlob
=
torch
::
from_blob
(
inputBuffer
.
data_ptr
(),
{
nHot
,
numInPlanes
},
options
);
torch
::
mm_out
(
filterGradSub
,
inputBufferBlob
.
t
(),
outputBufferBlob
);
torch
::
mm_out
(
inputBufferBlob
,
outputBufferBlob
,
filters
[
i
].
t
());
if
(
device
==
torch
::
kCPU
)
{
sparse_scatter_add_cpu
(
inputBuffer
,
inputGrad
,
indicePairs
[
i
][
inverse
],
nHot
);
}
#ifdef TV_CUDA
else
if
(
device
==
torch
::
kCUDA
)
{
functor
::
SparseScatterAddFunctor
<
tv
::
GPU
,
T
,
int
>
scatterFtor
;
scatterFtor
(
tv
::
TorchGPU
(),
tv
::
torch2tv
<
T
>
(
inputGrad
),
tv
::
torch2tv
<
const
T
>
(
inputBuffer
),
tv
::
torch2tv
<
const
int
>
(
indicePairs
).
subview
(
i
,
inverse
),
nHot
);
TV_CHECK_CUDA_ERR
();
}
else
if
(
device
==
torch
::
kCUDA
)
{
sparse_scatter_add_cuda
(
inputBuffer
,
inputGrad
,
indicePairs
[
i
][
inverse
],
nHot
);
}
#endif
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
else
{
TV_ASSERT_INVALID_ARG
(
false
,
"unknown device type"
);
}
}
);
}
return
{
inputGrad
,
filtersGrad
.
view
(
filterShape
)};
}
}
// namespace spconv
\ No newline at end of file
test/test_conv.py
View file @
6c767a51
...
...
@@ -392,7 +392,7 @@ class TestSpConv(TestCase):
def
testSpDeConv3d
(
self
):
np
.
random
.
seed
(
484
)
devices
=
[
"cuda:0"
,
"cpu:0"
]
devices
=
[
"cuda:0"
]
shapes
=
[[
19
,
18
,
17
]]
batchsizes
=
[
1
,
2
]
...
...
@@ -598,9 +598,9 @@ def main():
shapes
=
[[
50
,
30
,
30
]]
batchsizes
=
[
2
]
in_channels
=
[
2
56
]
out_channels
=
[
25
6
]
ksizes
=
[(
3
,
1
,
1
)]
in_channels
=
[
3
2
]
out_channels
=
[
6
4
]
ksizes
=
[(
3
,
3
,
3
)]
strides
=
[
1
]
paddings
=
[
0
]
dilations
=
[
1
]
...
...
@@ -654,5 +654,6 @@ def main():
if
__name__
==
'__main__'
:
main
()
#
main()
# unittest.main()
TestSpConv
().
testSpDeConv3d
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment