OpenDAS / torch-sparse

Commit c4484dbb, authored Jan 27, 2020 by rusty1s
Parent: 6e87043a

    jit support

Showing 14 changed files with 108 additions and 315 deletions (+108 −315)
cpu/diag.cpp            +1  −1
cpu/spmm.cpp            +14 −15
cuda/diag_kernel.cu     +1  −1
cuda/spmm.cpp           +18 −17
cuda/spmm_kernel.cu     +12 −11
cuda/spspmm.cpp         +13 −47
cuda/spspmm_kernel.cu   +21 −187
torch_sparse/add.py     +1  −4
torch_sparse/diag.py    +1  −1
torch_sparse/matmul.py  +9  −23
torch_sparse/mul.py     +1  −4
torch_sparse/storage.py +1  −1
torch_sparse/tensor.py  +8  −3
torch_sparse/utils.py   +7  −0
cpu/diag.cpp

@@ -15,7 +15,7 @@ torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
   auto row_data = row.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto mask = torch::zeros(E + num_diag, row.options().dtype(at::kBool));
+  auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
   auto mask_data = mask.DATA_PTR<bool>();
   int64_t r, c;
cpu/spmm.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

 #include "compat.h"

@@ -85,9 +85,10 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   }
 };

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
-     at::Tensor mat, std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+spmm(torch::Tensor rowptr, torch::Tensor col,
+     torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
+     std::string reduce) {
   CHECK_CPU(rowptr);
   CHECK_CPU(col);

@@ -105,12 +106,12 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   auto sizes = mat.sizes().vec();
   sizes[mat.dim() - 2] = rowptr.numel() - 1;
-  auto out = at::empty(sizes, mat.options());
+  auto out = torch::empty(sizes, mat.options());

-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, col.numel(), rowptr.options());
+    arg_out = torch::full_like(out, col.numel(), rowptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }

@@ -174,8 +175,9 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return std::make_tuple(out, arg_out);
 }

-at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
-                       at::Tensor mat, at::Tensor grad, std::string reduce) {
+torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
+                          torch::Tensor col, torch::Tensor mat,
+                          torch::Tensor grad, std::string reduce) {
   CHECK_CPU(row);
   CHECK_CPU(rowptr);
   CHECK_CPU(col);

@@ -191,7 +193,7 @@ at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);

-  auto out = at::zeros(row.numel(), grad.options());
+  auto out = torch::zeros(row.numel(), grad.options());
   auto row_data = row.DATA_PTR<int64_t>();
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();

@@ -224,8 +226,5 @@ at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
   return out;
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("spmm", &spmm, "Sparse-Dense Matrix Multiplication (CPU)");
-  m.def("spmm_val_bw", &spmm_val_bw,
-        "Sparse-Dense Matrix Multiplication Value Backward (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cpu::spmm", &spmm)
+        .op("torch_sparse_cpu::spmm_val_bw", &spmm_val_bw);
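The PYBIND11_MODULE to torch::RegisterOperators swap is the core of the "jit support" change: instead of binding functions into an ad-hoc extension module, each kernel becomes a named operator that the TorchScript JIT can see. A minimal sketch of how the CPU op registered above is then called from Python (the .so path matches the one loaded in torch_sparse/utils.py below; the example tensors are made up):

import torch

# Loading the compiled library populates the "torch_sparse_cpu" namespace.
torch.ops.load_library('torch_sparse/spmm_cpu.so')

# CSR inputs for the 2x2 sparse matrix [[1, 2], [0, 3]].
rowptr = torch.tensor([0, 2, 3])
col = torch.tensor([0, 1, 1])
value = torch.tensor([1., 2., 3.])
mat = torch.rand(2, 4)

# Registered ops are callable (and scriptable) under
# torch.ops.<namespace>.<name>; spmm returns the reduced output plus
# optional argmin/argmax indices for 'min'/'max' reductions.
out, arg_out = torch.ops.torch_sparse_cpu.spmm(rowptr, col, value, mat, 'sum')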
cuda/diag_kernel.cu

@@ -46,7 +46,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
   auto row_data = row.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto mask = torch::zeros(E + num_diag, row.options().dtype(at::kBool));
+  auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
   auto mask_data = mask.DATA_PTR<bool>();

   auto stream = at::cuda::getCurrentCUDAStream();
cuda/spmm.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
-          at::Tensor mat, std::string reduce);
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
+          torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
+          std::string reduce);

-at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
-                            at::Tensor mat, at::Tensor grad, std::string reduce);
+torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
+                               torch::Tensor col, torch::Tensor mat,
+                               torch::Tensor grad, std::string reduce);

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
-     at::Tensor mat, std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+spmm(torch::Tensor rowptr, torch::Tensor col,
+     torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
+     std::string reduce) {
   CHECK_CUDA(rowptr);
   CHECK_CUDA(col);
   if (value_opt.has_value())

@@ -21,8 +23,9 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   return spmm_cuda(rowptr, col, value_opt, mat, reduce);
 }

-at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
-                       at::Tensor mat, at::Tensor grad, std::string reduce) {
+torch::Tensor spmm_val_bw(torch::Tensor row, torch::Tensor rowptr,
+                          torch::Tensor col, torch::Tensor mat,
+                          torch::Tensor grad, std::string reduce) {
   CHECK_CUDA(row);
   CHECK_CUDA(rowptr);
   CHECK_CUDA(col);

@@ -31,8 +34,6 @@ at::Tensor spmm_val_bw(at::Tensor row, at::Tensor rowptr, at::Tensor col,
   return spmm_val_bw_cuda(row, rowptr, col, mat, grad, reduce);
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
-  m.def("spmm_val_bw", &spmm_val_bw,
-        "Sparse-Dense Matrix Multiplication Value Backward (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cuda::spmm", &spmm)
+        .op("torch_sparse_cuda::spmm_val_bw", &spmm_val_bw);
cuda/spmm_kernel.cu

 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

 #include "compat.cuh"

@@ -155,9 +155,10 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
   }
 }

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
-          at::Tensor mat, std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
+          torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
+          std::string reduce) {
   AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
   AT_ASSERTM(col.dim() == 1, "Input mismatch");

@@ -169,12 +170,12 @@ spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   auto sizes = mat.sizes().vec();
   sizes[mat.dim() - 2] = rowptr.numel() - 1;
-  auto out = at::empty(sizes, mat.options());
+  auto out = torch::empty(sizes, mat.options());

-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, col.numel(), rowptr.options());
+    arg_out = torch::full_like(out, col.numel(), rowptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }

@@ -247,9 +248,9 @@ spmm_val_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
   }
 }

-at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
-                            at::Tensor mat, at::Tensor grad, std::string reduce) {
+torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
+                               torch::Tensor col, torch::Tensor mat,
+                               torch::Tensor grad, std::string reduce) {
   mat = mat.contiguous();
   grad = grad.contiguous();

@@ -261,7 +262,7 @@ at::Tensor spmm_val_bw_cuda(at::Tensor row, at::Tensor rowptr, at::Tensor col,
   auto B = mat.numel() / (N * K);
   auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);

-  auto out = at::zeros(row.numel(), grad.options());
+  auto out = torch::zeros(row.numel(), grad.options());
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
cuda/spspmm.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
-spspmm_cuda(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
-            at::Tensor rowptrB, at::Tensor colB, at::optional<at::Tensor> valueB,
-            int M, int N, int K);
+std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
+spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
+            torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
+            torch::Tensor colB, torch::optional<torch::Tensor> valueB,
+            int64_t M, int64_t N, int64_t K);

-std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
-spspmm(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
-       at::Tensor rowptrB, at::Tensor colB, at::optional<at::Tensor> valueB,
-       int M, int N, int K) {
+std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
+spspmm(torch::Tensor rowptrA, torch::Tensor colA,
+       torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
+       torch::Tensor colB, torch::optional<torch::Tensor> valueB,
+       int64_t M, int64_t N, int64_t K) {
   CHECK_CUDA(rowptrA);
   CHECK_CUDA(colA);
   if (valueA.has_value())

@@ -23,40 +24,5 @@ spspmm(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
   return spspmm_cuda(rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K);
 }
// std::tuple<at::Tensor, at::Tensor>
// spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t m, size_t k, size_t n);
// at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
// at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t rowA_max, size_t
// rowB_max);
// std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor indexA, at::Tensor
// valueA,
// at::Tensor indexB, at::Tensor
// valueB, size_t m, size_t k, size_t
// n) {
// CHECK_CUDA(indexA);
// CHECK_CUDA(valueA);
// CHECK_CUDA(indexB);
// CHECK_CUDA(valueB);
// return spspmm_cuda(indexA, valueA, indexB, valueB, m, k, n);
// }
// at::Tensor spspmm_bw(at::Tensor index, at::Tensor indexA, at::Tensor valueA,
// at::Tensor indexB, at::Tensor valueB, size_t rowA_max,
// size_t rowB_max) {
// CHECK_CUDA(index);
// CHECK_CUDA(indexA);
// CHECK_CUDA(valueA);
// CHECK_CUDA(indexB);
// CHECK_CUDA(valueB);
// return spspmm_bw_cuda(index, indexA, valueA, indexB, valueB, rowA_max,
// rowB_max);
// }
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("spspmm", &spspmm, "Sparse-Sparse Matrix Multiplication (CUDA)");
-  // m.def("spspmm_bw", &spspmm_bw,
-  //       "Sparse-Sparse Matrix Multiplication Backward (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_sparse_cuda::spspmm", &spspmm);
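Only a single operator is registered here. Assuming the spspmm_cuda.so library from torch_sparse/utils.py has been loaded and CSR tensors for both operands already live on the GPU, the call mirrors what torch_sparse/matmul.py does through ext(True) (a sketch; the operand names are assumed):

# C = A @ B in CSR form; the value tensors may be None for binary matrices.
rowptrC, colC, valueC = torch.ops.torch_sparse_cuda.spspmm(
    rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K)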
cuda/spspmm_kernel.cu

 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

 #include <cusparse.h>

@@ -8,13 +8,13 @@
 #define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...)  \
   [&] {                                                                  \
     switch (TYPE) {                                                      \
-    case at::ScalarType::Float: {                                        \
+    case torch::ScalarType::Float: {                                     \
       using scalar_t = float;                                            \
       const auto &cusparsecsrgemm2_bufferSizeExt =                       \
           cusparseScsrgemm2_bufferSizeExt;                               \
       return __VA_ARGS__();                                              \
     }                                                                    \
-    case at::ScalarType::Double: {                                       \
+    case torch::ScalarType::Double: {                                    \
       using scalar_t = double;                                           \
       const auto &cusparsecsrgemm2_bufferSizeExt =                       \
           cusparseDcsrgemm2_bufferSizeExt;                               \

@@ -28,12 +28,12 @@
 #define AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(TYPE, ...)                  \
   [&] {                                                                  \
     switch (TYPE) {                                                      \
-    case at::ScalarType::Float: {                                        \
+    case torch::ScalarType::Float: {                                     \
       using scalar_t = float;                                            \
       const auto &cusparsecsrgemm2 = cusparseScsrgemm2;                  \
       return __VA_ARGS__();                                              \
     }                                                                    \
-    case at::ScalarType::Double: {                                       \
+    case torch::ScalarType::Double: {                                    \
       using scalar_t = double;                                           \
       const auto &cusparsecsrgemm2 = cusparseDcsrgemm2;                  \
       return __VA_ARGS__();                                              \

@@ -43,17 +43,17 @@
   }                                                                      \
   }()

-std::tuple<at::Tensor, at::Tensor, at::optional<at::Tensor>>
-spspmm_cuda(at::Tensor rowptrA, at::Tensor colA, at::optional<at::Tensor> valueA,
-            at::Tensor rowptrB, at::Tensor colB, at::optional<at::Tensor> valueB,
-            int M, int N, int K) {
+std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
+spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
+            torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
+            torch::Tensor colB, torch::optional<torch::Tensor> valueB,
+            int64_t M, int64_t N, int64_t K) {
   cusparseMatDescr_t descr = 0;
   cusparseCreateMatDescr(&descr);
   auto handle = at::cuda::getCurrentCUDASparseHandle();

-  rowptrA = rowptrA.toType(at::kInt), colA = colA.toType(at::kInt);
-  rowptrB = rowptrB.toType(at::kInt), colB = colB.toType(at::kInt);
+  rowptrA = rowptrA.toType(torch::kInt), colA = colA.toType(torch::kInt);
+  rowptrB = rowptrB.toType(torch::kInt), colB = colB.toType(torch::kInt);
   auto rowptrA_data = rowptrA.DATA_PTR<int>(), colA_data = colA.DATA_PTR<int>();
   auto rowptrB_data = rowptrB.DATA_PTR<int>(), colB_data = colB.DATA_PTR<int>();

@@ -61,7 +61,7 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
   csrgemm2Info_t info = NULL;
   cusparseCreateCsrgemm2Info(&info);

-  auto scalar_type = at::ScalarType::Float;
+  auto scalar_type = torch::ScalarType::Float;
   if (valueA.has_value())
     scalar_type = valueA.value().scalar_type();
   if (valueB.has_value())

@@ -80,25 +80,25 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
   cudaMalloc(&buffer, bufferSize);

   int nnzC;
-  auto rowptrC = at::empty(M + 1, rowptrA.options());
+  auto rowptrC = torch::empty(M + 1, rowptrA.options());
   auto rowptrC_data = rowptrC.DATA_PTR<int>();
   cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
                        colA_data, descr, colB.numel(), rowptrB_data, colB_data,
                        descr, 0, NULL, NULL, descr, rowptrC_data, &nnzC, info,
                        buffer);
-  auto colC = at::empty(nnzC, colA.options());
+  auto colC = torch::empty(nnzC, colA.options());
   auto colC_data = colC.DATA_PTR<int>();

   if (!valueA.has_value() && valueB.has_value())
-    valueA = at::ones_like(valueB.value());
+    valueA = torch::ones_like(valueB.value());

   if (!valueB.has_value() && valueA.has_value())
-    valueB = at::ones_like(valueA.value());
+    valueB = torch::ones_like(valueA.value());

-  at::optional<at::Tensor> valueC = at::nullopt;
+  torch::optional<torch::Tensor> valueC = torch::nullopt;
   if (valueA.has_value())
-    valueC = at::empty(nnzC, valueA.value().options());
+    valueC = torch::empty(nnzC, valueA.value().options());

   AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(scalar_type, [&] {
     scalar_t alpha = (scalar_t)1;

@@ -121,174 +121,8 @@ spspmm_cuda(at::Tensor rowptrA, at::Tensor colA,
                    descr, valueC_data, rowptrC_data, colC_data, info, buffer);
   });

-  rowptrC = rowptrC.toType(at::kLong);
-  colC = colC.toType(at::kLong);
+  rowptrC = rowptrC.toType(torch::kLong);
+  colC = colC.toType(torch::kLong);

   return std::make_tuple(rowptrC, colC, valueC);
 }
// #define THREADS 1024
// #define BLOCKS(N) (N + THREADS - 1) / THREADS
// #define CSRGEMM(TYPE, ...) \
// [&] { \
// const auto &the_type = TYPE; \
// (void)the_type; \
// at::ScalarType _st = ::detail::scalar_type(TYPE); \
// switch (_st) { \
// case at::ScalarType::Float: { \
// using scalar_t = float; \
// return cusparseScsrgemm(__VA_ARGS__); \
// } \
// case at::ScalarType::Double: { \
// using scalar_t = double; \
// return cusparseDcsrgemm(__VA_ARGS__); \
// } \
// default: \
// AT_ERROR("Not implemented for '", toString(_st), "'"); \
// } \
// }()
// static cusparseHandle_t cusparse_handle = 0;
// static void init_cusparse() {
// if (cusparse_handle == 0) {
// cusparseStatus_t status = cusparseCreate(&cusparse_handle);
// }
// }
// std::tuple<at::Tensor, at::Tensor>
// spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t m, size_t k, size_t n) {
// cudaSetDevice(indexA.get_device());
// init_cusparse();
// indexA = indexA.contiguous();
// valueA = valueA.contiguous();
// indexB = indexB.contiguous();
// valueB = valueB.contiguous();
// auto nnzA = valueA.size(0);
// auto nnzB = valueB.size(0);
// indexA = indexA.toType(at::kInt);
// indexB = indexB.toType(at::kInt);
// // Convert A to CSR format.
// auto row_ptrA = at::empty(m + 1, indexA.options());
// cusparseXcoo2csr(cusparse_handle, indexA[0].DATA_PTR<int>(), nnzA, k,
// row_ptrA.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto colA = indexA[1];
// cudaMemcpy(row_ptrA.DATA_PTR<int>() + m, &nnzA, sizeof(int),
// cudaMemcpyHostToDevice);
// // Convert B to CSR format.
// auto row_ptrB = at::empty(k + 1, indexB.options());
// cusparseXcoo2csr(cusparse_handle, indexB[0].DATA_PTR<int>(), nnzB, k,
// row_ptrB.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto colB = indexB[1];
// cudaMemcpy(row_ptrB.DATA_PTR<int>() + k, &nnzB, sizeof(int),
// cudaMemcpyHostToDevice);
// cusparseMatDescr_t descr = 0;
// cusparseCreateMatDescr(&descr);
// cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL);
// cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO);
// int nnzC;
// auto row_ptrC = at::empty(m + 1, indexB.options());
// cusparseXcsrgemmNnz(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr,
// nnzA, row_ptrA.DATA_PTR<int>(),
// colA.DATA_PTR<int>(), descr, nnzB,
// row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(),
// descr, row_ptrC.DATA_PTR<int>(), &nnzC);
// auto colC = at::empty(nnzC, indexA.options());
// auto valueC = at::empty(nnzC, valueA.options());
// CSRGEMM(valueC.scalar_type(), cusparse_handle,
// CUSPARSE_OPERATION_NON_TRANSPOSE,
// CUSPARSE_OPERATION_NON_TRANSPOSE, m, n, k, descr, nnzA,
// valueA.DATA_PTR<scalar_t>(), row_ptrA.DATA_PTR<int>(),
// colA.DATA_PTR<int>(), descr, nnzB, valueB.DATA_PTR<scalar_t>(),
// row_ptrB.DATA_PTR<int>(), colB.DATA_PTR<int>(), descr,
// valueC.DATA_PTR<scalar_t>(), row_ptrC.DATA_PTR<int>(),
// colC.DATA_PTR<int>());
// auto rowC = at::empty(nnzC, indexA.options());
// cusparseXcsr2coo(cusparse_handle, row_ptrC.DATA_PTR<int>(), nnzC, m,
// rowC.DATA_PTR<int>(), CUSPARSE_INDEX_BASE_ZERO);
// auto indexC = at::stack({rowC, colC}, 0).toType(at::kLong);
// return std::make_tuple(indexC, valueC);
// }
// at::Tensor degree(at::Tensor row, int64_t num_nodes) {
// auto zero = at::zeros(num_nodes, row.options());
// auto one = at::ones(row.size(0), row.options());
// return zero.scatter_add_(0, row, one);
// }
// std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
// int64_t num_nodes) {
// // Assert already coalesced input.
// row = degree(row, num_nodes).cumsum(0);
// row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
// return std::make_tuple(row, col);
// }
// template <typename scalar_t>
// __global__ void spspmm_bw_kernel(
// const int64_t *__restrict__ index, scalar_t *__restrict__ value,
// const int64_t *__restrict__ rowA, const int64_t *__restrict__ colA,
// const scalar_t *__restrict__ valueA, const int64_t *__restrict__
// rowB, const int64_t *__restrict__ colB, const scalar_t *__restrict__
// valueB, const size_t numel) {
// const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
// const size_t stride = blockDim.x * gridDim.x;
// for (ptrdiff_t e = idx; e < numel; e += stride) {
// int64_t i = index[e], j = index[numel + e];
// for (ptrdiff_t dA = rowA[i]; dA < rowA[i + 1]; dA++) {
// int64_t cA = colA[dA];
// for (ptrdiff_t dB = rowB[j]; dB < rowB[j + 1]; dB++) {
// int64_t cB = colB[dB];
// if (cA == cB) {
// value[e] += valueA[dA] * valueB[dB];
// }
// if (cB >= cA) {
// break;
// }
// }
// }
// }
// }
// at::Tensor spspmm_bw_cuda(at::Tensor index, at::Tensor indexA,
// at::Tensor valueA, at::Tensor indexB,
// at::Tensor valueB, size_t rowA_max, size_t
// rowB_max) {
// cudaSetDevice(index.get_device());
// auto value = at::zeros(index.size(1), valueA.options());
// at::Tensor rowA, colA;
// std::tie(rowA, colA) = to_csr(indexA[0], indexA[1], rowA_max);
// at::Tensor rowB, colB;
// std::tie(rowB, colB) = to_csr(indexB[0], indexB[1], rowB_max);
// AT_DISPATCH_FLOATING_TYPES(valueA.scalar_type(), "spspmm_bw", [&] {
// spspmm_bw_kernel<scalar_t><<<BLOCKS(value.numel()), THREADS>>>(
// index.DATA_PTR<int64_t>(), value.DATA_PTR<scalar_t>(),
// rowA.DATA_PTR<int64_t>(), colA.DATA_PTR<int64_t>(),
// valueA.DATA_PTR<scalar_t>(), rowB.DATA_PTR<int64_t>(),
// colB.DATA_PTR<int64_t>(), valueB.DATA_PTR<scalar_t>(),
// value.numel());
// });
// return value;
// }
torch_sparse/add.py

 import torch
 from torch_scatter import gather_csr
-
-
-def is_scalar(other):
-    return isinstance(other, int) or isinstance(other, float)
+from torch_sparse.utils import is_scalar


 def sparse_add(matA, matB):
torch_sparse/diag.py

 import torch

-from .utils import ext
+from torch_sparse.utils import ext


 def remove_diag(src, k=0):
torch_sparse/matmul.py

 import torch
 import scipy.sparse
-from torch_sparse import spmm_cpu
 from torch_scatter import scatter_add
-
-try:
-    from torch_sparse import spmm_cuda
-except ImportError:
-    spmm_cuda = None
-
-try:
-    from torch_sparse import spspmm_cuda
-except ImportError:
-    spspmm_cuda = None
-
-
-def spmm(is_cuda):
-    return spmm_cuda if is_cuda else spmm_cpu
+from torch_sparse.utils import ext


 class SPMM(torch.autograd.Function):
     @staticmethod
     def forward(ctx, row, rowptr, col, value, mat, rowcount, colptr, csr2csc,
                 reduce):
-        out, arg_out = spmm(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
+        out, arg_out = ext(mat.is_cuda).spmm(rowptr, col, value, mat, reduce)
         ctx.reduce = reduce
         ctx.save_for_backward(row, rowptr, col, value, mat, rowcount, colptr,

@@ -48,7 +34,7 @@ class SPMM(torch.autograd.Function):
         grad_value = None
         if ctx.needs_input_grad[3]:
             if ctx.reduce in ['sum', 'add', 'mean']:
-                grad_value = spmm(grad_out.is_cuda).spmm_val_bw(
+                grad_value = ext(grad_out.is_cuda).spmm_val_bw(
                     row, rowptr, col, mat, grad_out, ctx.reduce)
             elif ctx.reduce in ['min', 'max']:

@@ -63,7 +49,7 @@ class SPMM(torch.autograd.Function):
         if ctx.needs_input_grad[4]:
             if ctx.reduce in ['sum', 'add']:
                 value = value[csr2csc] if value is not None else value
-                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
+                grad_mat, _ = ext(grad_out.is_cuda).spmm(
                     colptr, row[csr2csc], value, grad_out, 'sum')
             elif ctx.reduce == 'mean':

@@ -71,7 +57,7 @@ class SPMM(torch.autograd.Function):
                 value = count.pow_(-1) if value is None else value / count
                 row = row[csr2csc]
                 value = value[csr2csc] if value is not None else value
-                grad_mat, _ = spmm(grad_out.is_cuda).spmm(
+                grad_mat, _ = ext(grad_out.is_cuda).spmm(
                     colptr, row, value, grad_out, 'sum')
             elif ctx.reduce in ['min', 'max']:

@@ -92,9 +78,9 @@ class SPSPMM(torch.autograd.Function):
     @staticmethod
     def forward(ctx, rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K):
         if rowptrA.is_cuda:
-            rowptrC, colC, valueC = spspmm_cuda.spspmm(
+            rowptrC, colC, valueC = ext(True).spspmm(
                 rowptrA, colA, valueA, rowptrB, colB, valueB, M, N, K)
         else:
             dtype = None
             if valueA is not None:

@@ -149,7 +135,7 @@ def matmul(src, other, reduce='sum'):
     row = None
     if reduce in ['sum', 'add', 'mean'] and (src.requires_grad
-                                             or other.reuqires_grad):
+                                             or other.requires_grad):
         row = src.storage.row
     rowcount = None
torch_sparse/mul.py

 import torch
 from torch_scatter import gather_csr
-
-
-def is_scalar(other):
-    return isinstance(other, int) or isinstance(other, float)
+from torch_sparse.utils import is_scalar


 def mul(src, other):
torch_sparse/storage.py

@@ -2,7 +2,7 @@ import warnings
 import torch
 from torch_scatter import segment_csr, scatter_add

-from .utils import ext
+from torch_sparse.utils import ext

 __cache__ = {'enabled': True}
torch_sparse/tensor.py

@@ -15,6 +15,7 @@ from torch_sparse.diag import remove_diag, set_diag
 from torch_sparse.matmul import matmul
 from torch_sparse.add import add, add_, add_nnz, add_nnz_
 from torch_sparse.mul import mul, mul_, mul_nnz, mul_nnz_
+from torch_sparse.utils import is_scalar


 class SparseTensor(object):

@@ -501,10 +502,14 @@ TORCH_MINOR = int(torch.__version__.split('.')[1])
 if (TORCH_MAJOR < 1) or (TORCH_MAJOR == 1 and TORCH_MINOR < 4):

     def add(self, other):
-        return self.add(other) if torch.is_tensor(other) else NotImplemented
+        if torch.is_tensor(other) or is_scalar(other):
+            return self.add(other)
+        return NotImplemented

     def mul(self, other):
-        return self.mul(other) if torch.is_tensor(other) else NotImplemented
+        if torch.is_tensor(other) or is_scalar(other):
+            return self.mul(other)
+        return NotImplemented

     torch.Tensor.__add__ = add
-    torch.Tensor.__mul__ = add
+    torch.Tensor.__mul__ = mul
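For PyTorch < 1.4 the package monkey-patches torch.Tensor.__add__ and torch.Tensor.__mul__, and the old patch returned NotImplemented for anything that was not a tensor, which broke plain scalar arithmetic on dense tensors; it also accidentally bound __mul__ to add. A quick sketch of the behavior after this change (assuming torch < 1.4 with torch_sparse imported):

import torch

a = torch.ones(3)

# Scalars now pass the is_scalar() check and fall through to Tensor.add,
# instead of the patched operator returning NotImplemented and raising.
print(a + 1)    # tensor([2., 2., 2.])

# __mul__ is now bound to mul rather than add, so this dispatches to
# Tensor.mul as expected.
print(a * 2.)   # tensor([2., 2., 2.])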
torch_sparse/utils.py

@@ -2,10 +2,13 @@ import torch
 torch.ops.load_library('torch_sparse/convert_cpu.so')
 torch.ops.load_library('torch_sparse/diag_cpu.so')
+torch.ops.load_library('torch_sparse/spmm_cpu.so')

 try:
     torch.ops.load_library('torch_sparse/convert_cuda.so')
     torch.ops.load_library('torch_sparse/diag_cuda.so')
+    torch.ops.load_library('torch_sparse/spmm_cuda.so')
+    torch.ops.load_library('torch_sparse/spspmm_cuda.so')
 except OSError as e:
     if torch.cuda.is_available():
         raise e

@@ -14,3 +17,7 @@ except OSError as e:
 def ext(is_cuda):
     name = 'torch_sparse_cuda' if is_cuda else 'torch_sparse_cpu'
     return getattr(torch.ops, name)
+
+
+def is_scalar(other):
+    return isinstance(other, int) or isinstance(other, float)
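The new is_scalar() helper replaces the duplicated local definitions deleted from add.py and mul.py, and ext() becomes the single dispatch point for the registered operator namespaces. A short usage sketch (assuming the .so files above were built and are on the load path):

from torch_sparse.utils import ext, is_scalar

assert is_scalar(3) and is_scalar(0.5) and not is_scalar('x')

# ext() maps a device flag to the matching torch.ops namespace, so call
# sites reduce to ext(mat.is_cuda).spmm(...) instead of importing
# separate CPU and CUDA extension modules.
ops = ext(False)   # -> torch.ops.torch_sparse_cpu
print(ops.spmm)    # the operator registered in cpu/spmm.cpp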