OpenDAS / torch-sparse

Commit fc650310, authored Jan 24, 2020 by rusty1s

    spspmm kernel

Parent: 335dfed0

Showing 4 changed files with 401 additions and 165 deletions:

    cuda/spspmm.cpp         +51   -27
    cuda/spspmm_kernel.cu   +266  -134
    test/test_matmul.py     +16   -0
    torch_sparse/matmul.py  +68   -4
cuda/spspmm.cpp
@@ -2,37 +2,61 @@

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm_cuda(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
            at::Tensor rowptrB, at::Tensor colB, at::optional<at::Tensor> valueB,
            int M, int N, int K);

std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
       at::Tensor rowptrB, at::Tensor colB, at::optional<at::Tensor> valueB,
       int M, int N, int K) {
  CHECK_CUDA(rowptrA);
  CHECK_CUDA(colA);
  if (valueA.has_value())
    CHECK_CUDA(valueA.value());
  CHECK_CUDA(rowptrB);
  CHECK_CUDA(colB);
  if (valueB.has_value())
    CHECK_CUDA(valueB.value());
  return spspmm_cuda(rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K);
}

// std::tuple<at::Tensor, at::Tensor>
// spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
//             at::Tensor valueB, size_t m, size_t k, size_t n);

// at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
//                           at::Tensor valueA, at::Tensor indexB,
//                           at::Tensor valueB, size_t rowA_max, size_t
//                           rowB_max);

// std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor
//                                           valueA,
//                                           at::Tensor indexB, at::Tensor
//                                           valueB, size_t m, size_t k, size_t
//                                           n) {
//   CHECK_CUDA(indexA);
//   CHECK_CUDA(valueA);
//   CHECK_CUDA(indexB);
//   CHECK_CUDA(valueB);
//   return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
// }

// at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
//                      at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
//                      size_t rowB_max) {
//   CHECK_CUDA(index);
//   CHECK_CUDA(indexA);
//   CHECK_CUDA(valueA);
//   CHECK_CUDA(indexB);
//   CHECK_CUDA(valueB);
//   return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
//                         rowB_max);
// }

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
  // m.def("spspmm_bw", &spspmm_bw,
  //       "Sparse-Sparse Matrix Multiplication Backward (CUDA)");
}
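The rebound `spspmm` entry point now takes CSR inputs (rowptr, col, optional value) instead of COO index/value pairs. A minimal sketch of how the compiled extension could be driven directly from Python, assuming the extension built with CUDA support; the example tensors are illustrative, and the call pattern mirrors the one used in torch_sparse/matmul.py further down. With values supplied on both sides, the optional third output should come back as a value tensor:

    import torch
    from torch_sparse import spspmm_cuda  # extension built from cuda/spspmm.cpp

    # A = 3x3 identity in CSR form (illustrative data).
    rowptr = torch.tensor([0, 1, 2, 3], device='cuda')
    col = torch.tensor([0, 1, 2], device='cuda')
    value = torch.ones(3, device='cuda')

    # C = A @ A with M = N = K = 3; returns the COO index of C, its CSR
    # rowptr, and (optionally) its values.
    indexC, rowptrC, valueC = spspmm_cuda.spspmm(rowptr, col, value,
                                                 rowptr, col, value, 3, 3, 3)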
cuda/spspmm_kernel.cu

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cusparse.h>

#include "compat.cuh"

#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...)       \
  [&] {                                                                        \
    switch (TYPE) {                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t = float;                                                  \
      const auto &cusparsecsrgemm2_bufferSizeExt =                             \
          cusparseScsrgemm2_bufferSizeExt;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case at::ScalarType::Double: {                                             \
      using scalar_t = double;                                                 \
      const auto &cusparsecsrgemm2_bufferSizeExt =                             \
          cusparseDcsrgemm2_bufferSizeExt;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    default:                                                                   \
      AT_ERROR("Not implemented for '", toString(TYPE), "'");                  \
    }                                                                          \
  }()

#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(TYPE, ...)                        \
  [&] {                                                                        \
    switch (TYPE) {                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t = float;                                                  \
      const auto &cusparsecsrgemm2 = cusparseScsrgemm2;                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case at::ScalarType::Double: {                                             \
      using scalar_t = double;                                                 \
      const auto &cusparsecsrgemm2 = cusparseDcsrgemm2;                        \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    default:                                                                   \
      AT_ERROR("Not implemented for '", toString(TYPE), "'");                  \
    }                                                                          \
  }()

std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
spspmm_cuda(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
            at::Tensor rowptrB, at::Tensor colB, at::optional<at::Tensor> valueB,
            int M, int N, int K) {

  cusparseMatDescr_t descr = 0;
  cusparseCreateMatDescr(&descr);
  auto handle = at::cuda::getCurrentCUDASparseHandle();

  rowptrA = rowptrA.toType(at::kInt), colA = colA.toType(at::kInt);
  rowptrB = rowptrB.toType(at::kInt), colB = colB.toType(at::kInt);

  auto rowptrA_data = rowptrA.DATA_PTR<int>(), colA_data = colA.DATA_PTR<int>();
  auto rowptrB_data = rowptrB.DATA_PTR<int>(), colB_data = colB.DATA_PTR<int>();

  csrgemm2Info_t info = NULL;
  cusparseCreateCsrgemm2Info(&info);

  auto scalar_type = at::ScalarType::Float;
  if (valueA.has_value())
    scalar_type = valueA.value().scalar_type();
  if (valueB.has_value())
    scalar_type = valueB.value().scalar_type();

  size_t bufferSize;
  AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(scalar_type, [&] {
    scalar_t alpha = (scalar_t)1;
    cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(),
                                   rowptrA_data, colA_data, descr, colB.numel(),
                                   rowptrB_data, colB_data, NULL, descr, 0,
                                   NULL, NULL, info, &bufferSize);
  });

  void *buffer = NULL;
  cudaMalloc(&buffer, bufferSize);

  int nnzC;
  auto rowptrC = at::empty(M + 1, rowptrA.options());
  auto rowptrC_data = rowptrC.DATA_PTR<int>();
  cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
                       colA_data, descr, colB.numel(), rowptrB_data, colB_data,
                       descr, 0, NULL, NULL, descr, rowptrC_data, &nnzC, info,
                       buffer);

  auto colC = at::empty(nnzC, colA.options());
  auto colC_data = colC.DATA_PTR<int>();

  if (!valueA.has_value() && valueB.has_value())
    valueA = at::ones_like(valueB.value());
  if (!valueB.has_value() && valueA.has_value())
    valueB = at::ones_like(valueA.value());

  at::optional<at::Tensor> valueC = at::nullopt;
  if (valueA.has_value())
    valueC = at::empty(nnzC, valueA.value().options());

  AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(scalar_type, [&] {
    scalar_t alpha = (scalar_t)1;

    scalar_t *valueA_data = NULL;
    if (valueA.has_value())
      valueA_data = valueA.value().DATA_PTR<scalar_t>();

    scalar_t *valueB_data = NULL;
    if (valueB.has_value())
      valueB_data = valueB.value().DATA_PTR<scalar_t>();

    scalar_t *valueC_data = NULL;
    if (valueC.has_value())
      valueC_data = valueC.value().DATA_PTR<scalar_t>();

    cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valueA_data,
                     rowptrA_data, colA_data, descr, colB.numel(), valueB_data,
                     rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL,
                     descr, valueC_data, rowptrC_data, colC_data, info, buffer);
  });

  auto rowC = at::empty_like(colC);
  auto rowC_data = rowC.DATA_PTR<int>();
  cusparseXcsr2coo(handle, rowptrC_data, nnzC, M, rowC_data,
                   CUSPARSE_INDEX_BASE_ZERO);
  cusparseDestroyCsrgemm2Info(info);

  auto indexC = at::stack({rowC.toType(at::kLong), colC.toType(at::kLong)}, 0);

  return std::make_tuple(indexC, rowptrC.toType(at::kLong), valueC);
}
// #define THREADS 1024
// #define BLOCKS(N) (N + THREADS - 1) / THREADS
// #define CSRGEMM(TYPE, ...) \
// [&] { \
// const auto &the_type = TYPE; \
// (void)the_type; \
// at::ScalarType _st = ::detail::scalar_type(TYPE); \
// switch (_st) { \
// case at::ScalarType::Float: { \
// using scalar_t = float; \
// return cusparseScsrgemm(__VA_ARGS__); \
// } \
// case at::ScalarType::Double: { \
// using scalar_t = double; \
// return cusparseDcsrgemm(__VA_ARGS__); \
// } \
// default: \
// AT_ERROR("Not implemented for '", toString(_st), "'"); \
// } \
// }()
// static cusparseHandle_t cusparse_handle = 0;
// static void init_cusparse() {
// if (cusparse_handle == 0) {
// cusparseStatus_t status = cusparseCreate(&cusparse_handle);
// }
// }
// std::tuple<at::Tensor, at::Tensor>
// spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t m, size_t k, size_t n) {
// cudaSetDevice(indexA.get_device());
// init_cusparse();
// indexA = indexA.contiguous();
// valueA = valueA.contiguous();
// indexB = indexB.contiguous();
// valueB = valueB.contiguous();
// auto nnzA = valueA.size(0);
// auto nnzB = valueB.size(0);
// indexA = indexA.toType(at::kInt);
// indexB = indexB.toType(at::kInt);
// // Convert A to CSR format.
// auto row_ptrA = at::empty(m + 1, indexA.options());
// cusparseXcoo2csr(cusparse_handle, indexA[0].DATA_PTR<int>(), nnzA, k,
// row_ptrA.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto colA = indexA[1];
// cudaMemcpy(row_ptrA.DATA_PTR<int>() + m, &nnzA, sizeof(int),
// cudaMemcpyHostToDevice);
// // Convert B to CSR format.
// auto row_ptrB = at::empty(k + 1, indexB.options());
// cusparseXcoo2csr(cusparse_handle, indexB[0].DATA_PTR<int>(), nnzB, k,
// row_ptrB.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto colB = indexB[1];
// cudaMemcpy(row_ptrB.DATA_PTR<int>() + k, &nnzB, sizeof(int),
// cudaMemcpyHostToDevice);
// cusparseMatDescr_t descr = 0;
// cusparseCreateMatDescr(&descr);
// cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
// cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
// int nnzC;
// auto row_ptrC = at::empty(m + 1, indexB.options());
// cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr,
// nnzA, row_ptrA.DATA_PTR<int>(),
// colA.DATA_PTR<int>(), descr, nnzB,
// row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(),
// descr, row_ptrC.DATA_PTR<int>(), &nnzC);
// auto colC = at::empty(nnzC, indexA.options());
// auto valueC = at::empty(nnzC, valueA.options());
// CSRGEMM(valueC.scalar_type(), cusparse_handle,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
// valueA.DATA_PTR<scalar_t>(), row_ptrA.DATA_PTR<int>(),
// colA.DATA_PTR<int>(), descr, nnzB, valueB.DATA_PTR<scalar_t>(),
// row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(), descr,
// valueC.DATA_PTR<scalar_t>(), row_ptrC.DATA_PTR<int>(),
// colC.DATA_PTR<int>());
// auto rowC = at::empty(nnzC, indexA.options());
// cusparseXcsr2coo(cusparse_handle, row_ptrC.DATA_PTR<int>(), nnzC, m,
// rowC.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);
// return std::make_tuple(indexC, valueC);
// }
// at::Tensor degree(at::Tensor row, int64_t num_nodes) {
// auto zero = at::zeros(num_nodes, row.options());
// auto one = at::ones(row.size(0), row.options());
// return zero.scatter_add_(0, row, one);
// }
// std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
// int64_t num_nodes) {
// // Assert already coalesced input.
// row = degree(row, num_nodes).cumsum(0);
// row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
// return std::make_tuple(row, col);
// }
// template <typename scalar_t>
// __global__ void spspmm_bw_kernel(
// const int64_t *__restrict__ index, scalar_t *__restrict__ value,
// const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
// const scalar_t *__restrict__ valueA, const int64_t *__restrict__
// rowB, const int64_t *__restrict__ colB, const scalar_t *__restrict__
// valueB, const size_t numel) {
// const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
// const size_t stride = blockDim.x * gridDim.x;
// for (ptrdiff_t e = idx; e < numel; e += stride) {
// int64_t i = index[e], j = index[numel + e];
// for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
// int64_t cA = colA[dA];
// for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
// int64_t cB = colB[dB];
// if (cA == cB) {
// value[e] += valueA[dA] * valueB[dB];
// }
// if (cB >= cA) {
// break;
// }
// }
// }
// }
// }
// at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
// at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t rowA_max, size_t
// rowB_max) {
// cudaSetDevice(index.get_device());
// auto value = at::zeros(index.size(1), valueA.options());
// at::Tensor rowA, colA;
// std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
// at::Tensor rowB, colB;
// std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
// AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
// spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
// index.DATA_PTR<int64_t>(), value.DATA_PTR<scalar_t>(),
// rowA.DATA_PTR<int64_t>(), colA.DATA_PTR<int64_t>(),
// valueA.DATA_PTR<scalar_t>(), rowB.DATA_PTR<int64_t>(),
// colB.DATA_PTR<int64_t>(), valueB.DATA_PTR<scalar_t>(),
// value.numel());
// });
// return value;
// }
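The new kernel computes C = A·B entirely in CSR form via cusparseXcsrgemm2Nnz / csrgemm2 and then converts the result's row pointer back to COO rows, returning the COO index, the CSR rowptr, and optional values of C. A small SciPy sketch of the same computation and output layout, mirroring the CPU fallback added to torch_sparse/matmul.py below (the example data and variable names are illustrative):

    import numpy as np
    import scipy.sparse

    # A and B as CSR triples (value, col, rowptr); illustrative 3x3 identities.
    rowptr = np.array([0, 1, 2, 3])
    col = np.array([0, 1, 2])
    value = np.ones(3)

    A = scipy.sparse.csr_matrix((value, col, rowptr), shape=(3, 3))
    B = scipy.sparse.csr_matrix((value, col, rowptr), shape=(3, 3))

    C = A @ B                              # same product the CUDA kernel computes
    rowptrC, valueC = C.indptr, C.data     # CSR rowptr and values of C
    coo = C.tocoo()
    indexC = np.stack([coo.row, coo.col])  # COO index, as returned by spspmm_cuda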
test/test_matmul.py
@@ -48,3 +48,19 @@ def test_spmm(dtype, device, reduce):

    assert torch.allclose(expected, out)
    assert torch.allclose(expected_grad_value, value.grad)
    assert torch.allclose(expected_grad_other, other.grad)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
    src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                       device=device)
    src = SparseTensor.from_dense(src)

    out = src @ src
    assert out.size() == (3, 3)
    assert out.has_value()

    src.set_value_(None)
    out = src @ src
    assert out.size() == (3, 3)
    assert not out.has_value()
torch_sparse/matmul.py

import torch
import scipy.sparse

from torch_sparse import spmm_cpu
from torch_scatter import scatter_add

@@ -8,6 +8,11 @@ try:

except ImportError:
    spmm_cuda = None

try:
    from torch_sparse import spspmm_cuda
except ImportError:
    spspmm_cuda = None


def spmm(is_cuda):
    return spmm_cuda if is_cuda else spmm_cpu

@@ -61,10 +66,9 @@ class SPMM(torch.autograd.Function):

        grad_mat = None
        if ctx.needs_input_grad[6]:
            if ctx.reduce in ['sum', 'add']:
                value = value[csr2csc] if value is not None else value
                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
                    colptr, index[0][csr2csc], value, grad_out, 'sum')
            elif ctx.reduce == 'mean':
                count = rowcount[index[0]].to(mat.dtype).clamp_(min=1)

@@ -88,9 +92,61 @@ class SPMM(torch.autograd.Function):

        return None, None, None, None, None, grad_value, grad_mat, None


class SPSPMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
        if rowptrA.is_cuda:
            indexC, rowptrC, valueC = spspmm_cuda.spspmm(
                rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K)
        else:
            dtype = None
            if valueA is not None:
                dtype = valueA.dtype
            if valueB is not None:
                dtype = valueB.dtype

            if valueA is None:
                valueA = torch.ones(colA.numel(), dtype=dtype)
            A = scipy.sparse.csr_matrix((valueA, colA, rowptrA), (M, N))

            if valueB is None:
                valueB = torch.ones(colB.numel(), dtype=dtype)
            B = scipy.sparse.csr_matrix((valueB, colB, rowptrB), (N, K))

            C = A @ B

            valueC = torch.from_numpy(C.data).to(dtype) if dtype is not None else None
            rowptrC = torch.from_numpy(C.indptr).to(torch.int64)
            C = C.tocoo()
            rowC = torch.from_numpy(C.row).to(torch.int64)
            colC = torch.from_numpy(C.col).to(torch.int64)
            indexC = torch.stack([rowC, colC], dim=0)

        # We cannot return `NoneType` in torch.autograd :(
        if valueC is None:
            return indexC, rowptrC
        else:
            return indexC, rowptrC, valueC

    @staticmethod
    def backward(ctx, grad_indexC, grad_rowptrC, *args):
        grad_valueA = None
        if ctx.needs_input_grad[2]:
            raise NotImplementedError

        grad_valueB = None
        if ctx.needs_input_grad[5]:
            raise NotImplementedError

        return (None, None, grad_valueA, None, None, grad_valueB, None, None,
                None)


def matmul(src, other, reduce='sum'):
    assert src.dim() == 2 and src.size(-1) == other.size(-2)

    # Sparse-Dense Matrix Multiplication.
    if torch.is_tensor(other):
        assert reduce in ['sum', 'add', 'mean', 'min', 'max']
        (index, value), rowptr = src.coo(), src.storage.rowptr

@@ -106,8 +162,16 @@ def matmul(src, other, reduce='sum'):

        return SPMM.apply(index, rowcount, rowptr, colptr, csr2csc, value,
                          other, reduce)

    # Sparse-Sparse Matrix Multiplication.
    elif isinstance(other, src.__class__):
        assert reduce in ['sum', 'add']
        assert src.dim() == 2 and other.dim() == 2
        data = SPSPMM.apply(*src.csr(), *other.csr(), src.size(0), src.size(1),
                            other.size(1))
        data = data if len(data) == 3 else data + (None, )
        sparse_size = torch.Size([src.size(0), other.size(1)])
        out = src.__class__(data[0], data[2], sparse_size, is_sorted=True)
        out.storage._rowptr = data[1]
        return out

    raise ValueError