Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
06933f89
Commit
06933f89
authored
Jul 30, 2018
by
rusty1s
Browse files
cuda sparse sparse mm implementation
parent
41458598
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
97 additions
and
70 deletions
+97
-70
backprop.py
backprop.py
+12
-0
cuda/matmul.cpp
cuda/matmul.cpp
+5
-5
cuda/matmul_cuda.cu
cuda/matmul_cuda.cu
+68
-45
test/test_matmul.py
test/test_matmul.py
+11
-18
torch_sparse/matmul.py
torch_sparse/matmul.py
+1
-2
No files found.
backprop.py
0 → 100644
View file @
06933f89
import torch
from torch.autograd import Variable

# Scratch script: builds two toy 2x2 sparse matrices to experiment with
# sparse-sparse matmul / backprop.
size = torch.Size([2, 2])

# A: anti-diagonal ones, allocated on the GPU.
index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long).cuda()
value = torch.tensor([1, 1], dtype=torch.float).cuda()
A = torch.cuda.sparse.FloatTensor(index, value, size)

# B: diagonal ones, kept on the CPU.
index = torch.tensor([[0, 1], [0, 1]], dtype=torch.long)
value = torch.tensor([1, 1], dtype=torch.float)
B = torch.sparse.FloatTensor(index, value, size)
cuda/matmul.cpp
View file @
06933f89
...
...
@@ -2,12 +2,12 @@
#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
// Implemented in cuda/matmul_cuda.cu: sparse-sparse matmul via cuSPARSE.
std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B);

// Entry point for the extension: C = A * B for two sparse CUDA tensors.
// Validates that both operands live on the GPU, then forwards to the
// CUDA implementation. Returns the (index, value) COO pair of the product.
std::tuple<at::Tensor, at::Tensor> spspmm(at::Tensor A, at::Tensor B) {
  CHECK_CUDA(A);
  CHECK_CUDA(B);
  return spspmm_cuda(A, B);
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
...
...
cuda/matmul_cuda.cu
View file @
06933f89
...
...
@@ -2,6 +2,23 @@
#include <cusparse.h>
// Dtype-dispatch macro for cuSPARSE csrgemm: inspects the scalar type of
// TYPE and calls the matching routine (float -> cusparseScsrgemm,
// double -> cusparseDcsrgemm). Inside the expansion, `scalar_t` names the
// matching C type so call sites can write e.g. `values.data<scalar_t>()`.
// Other scalar types raise via AT_ERROR. (Comments cannot go inside the
// macro body: `//` would swallow the `\` line continuations.)
#define CSRGEMM(TYPE, ...) \
[&] { \
const at::Type &the_type = TYPE; \
switch (the_type.scalarType()) { \
case at::ScalarType::Float: { \
using scalar_t = float; \
return cusparseScsrgemm(__VA_ARGS__); \
} \
case at::ScalarType::Double: { \
using scalar_t = double; \
return cusparseDcsrgemm(__VA_ARGS__); \
} \
default: \
AT_ERROR("Not implemented for '%s'", the_type.toString()); \
} \
}()
static
cusparseHandle_t
cusparse_handle
=
0
;
static
void
init_cusparse
()
{
...
...
@@ -10,51 +27,57 @@ static void init_cusparse() {
}
}
at
::
Tensor
spspmm_cuda
(
at
::
Tensor
matrix1
,
at
::
Tensor
matrix2
)
{
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
>
spspmm_cuda
(
at
::
Tensor
A
,
at
::
Tensor
B
)
{
init_cusparse
();
auto
nnz
=
matrix1
.
_nnz
();
auto
inDim
=
matrix1
.
size
(
1
);
auto
row
=
matrix1
.
_indices
()[
0
].
toType
(
at
::
kInt
);
auto
row_ptrs
=
at
::
empty
(
row
.
type
(),
{
inDim
+
1
});
cusparseXcoo2csr
(
cusparse_handle
,
row
.
data
<
int
>
(),
nnz
,
inDim
,
row_ptrs
.
data
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
printf
(
"%lli
\n
"
,
nnz
);
printf
(
"%lli
\n
"
,
inDim
);
/* colbuf at::empty(nnz); */
/* auto colPtrs = at::empty(inDim + 1, at::kInt); */
/* auto row = matrix1._indices(); */
/* for (int i = 0; i < 5; i++) { */
/* row_buf.data<int>()[i] = (int)row.data<int64_t>()[i]; */
/* } */
/* printf("%lli\n", row.numel()); */
return
matrix1
;
auto
m
=
A
.
size
(
0
);
auto
n
=
B
.
size
(
1
);
auto
k
=
A
.
size
(
1
);
auto
nnzA
=
A
.
_nnz
();
auto
nnzB
=
B
.
_nnz
();
auto
valueA
=
A
.
_values
();
auto
indexA
=
A
.
_indices
().
toType
(
at
::
kInt
);
auto
row_ptrA
=
at
::
empty
(
indexA
.
type
(),
{
m
+
1
});
cusparseXcoo2csr
(
cusparse_handle
,
indexA
[
0
].
data
<
int
>
(),
nnzA
,
k
,
row_ptrA
.
data
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
auto
colA
=
indexA
[
1
];
auto
valueB
=
B
.
_values
();
auto
indexB
=
B
.
_indices
().
toType
(
at
::
kInt
);
auto
row_ptrB
=
at
::
empty
(
indexB
.
type
(),
{
k
+
1
});
cusparseXcoo2csr
(
cusparse_handle
,
indexB
[
0
].
data
<
int
>
(),
nnzB
,
k
,
row_ptrB
.
data
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
auto
colB
=
indexB
[
1
];
cusparseMatDescr_t
descr
=
0
;
cusparseCreateMatDescr
(
&
descr
);
cusparseSetMatType
(
descr
,
CUSPARSE_MATRIX_TYPE_GENERAL
);
cusparseSetMatIndexBase
(
descr
,
CUSPARSE_INDEX_BASE_ZERO
);
int
nnzC
;
auto
row_ptrC
=
at
::
empty
(
indexA
.
type
(),
{
m
+
1
});
cusparseXcsrgemmNnz
(
cusparse_handle
,
CUSPARSE_OPERATION_NON_TRANSPOSE
,
CUSPARSE_OPERATION_NON_TRANSPOSE
,
m
,
n
,
k
,
descr
,
nnzA
,
row_ptrA
.
data
<
int
>
(),
colA
.
data
<
int
>
(),
descr
,
nnzB
,
row_ptrB
.
data
<
int
>
(),
colB
.
data
<
int
>
(),
descr
,
row_ptrC
.
data
<
int
>
(),
&
nnzC
);
auto
colC
=
at
::
empty
(
indexA
.
type
(),
{
nnzC
});
auto
valueC
=
at
::
empty
(
valueA
.
type
(),
{
nnzC
});
CSRGEMM
(
valueC
.
type
(),
cusparse_handle
,
CUSPARSE_OPERATION_NON_TRANSPOSE
,
CUSPARSE_OPERATION_NON_TRANSPOSE
,
m
,
n
,
k
,
descr
,
nnzA
,
valueA
.
data
<
scalar_t
>
(),
row_ptrA
.
data
<
int
>
(),
colA
.
data
<
int
>
(),
descr
,
nnzB
,
valueB
.
data
<
scalar_t
>
(),
row_ptrB
.
data
<
int
>
(),
colB
.
data
<
int
>
(),
descr
,
valueC
.
data
<
scalar_t
>
(),
row_ptrC
.
data
<
int
>
(),
colC
.
data
<
int
>
());
auto
rowC
=
at
::
empty
(
indexA
.
type
(),
{
nnzC
});
cusparseXcsr2coo
(
cusparse_handle
,
row_ptrC
.
data
<
int
>
(),
nnzC
,
m
,
rowC
.
data
<
int
>
(),
CUSPARSE_INDEX_BASE_ZERO
);
auto
indexC
=
at
::
stack
({
rowC
,
colC
},
0
).
toType
(
at
::
kLong
);
return
std
::
make_tuple
(
indexC
,
valueC
);
}
/* #include <ATen/SparseTensorImpl.h> */
/* namespace at { */
/* namespace native { */
/* using SparseTensor = Tensor; */
/* namespace { */
/* at::SparseTensor spspmm_cuda(at::SparseTensor matrix1, */
/* at::SparseTensor matrix2) { */
/* return matrix1; */
/* } */
/* } // namespace */
/* } // namespace native */
/* } // namespace at */
// defined in aten/src/THCUNN/SparseLinear.cu as
/* cusparseXcoo2csr(cusparse_handle, THCudaIntTensor_data(state, colbuf), nnz,
*/
/* inDim, THCudaIntTensor_data(state, colPtrs), */
/* CUSPARSE_INDEX_BASE_ONE); */
test/test_matmul.py
View file @
06933f89
...
...
@@ -2,31 +2,24 @@ from itertools import product
import pytest
import torch
from torch_sparse import spspmm, SparseTensor

# The CUDA kernel currently only supports float on the GPU, so the
# parametrization is restricted accordingly.
devices = [torch.device('cuda')]
dtypes = [torch.float]


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_spspmm(dtype, device):
    # A: 3x3 sparse matrix in COO form.
    index = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
    value = tensor([1, 2, 3, 4, 5], dtype, device)
    A = (index, value, torch.Size([3, 3]))

    # B: 3x2 sparse matrix in COO form.
    index = torch.tensor([[0, 2], [1, 0]], device=device)
    value = tensor([2, 4], dtype, device)
    B = (index, value, torch.Size([3, 2]))

    index, value = spspmm(*A, *B)
    out = SparseTensor(index, value, torch.Size([3, 2]))
    assert out.to_dense().tolist() == [[8, 0], [0, 6], [0, 8]]

    # TODO: test backward
    # value.sum().backward()
torch_sparse/matmul.py
View file @
06933f89
...
...
@@ -46,8 +46,7 @@ def mm(e1, v1, s1, e2, v2, s2):
def mm_cuda(e1, v1, s1, e2, v2, s2):
    """Sparse-sparse matmul on the GPU via the CUDA extension.

    Wraps the two (index, value, size) triples into ``SparseTensor``s and
    forwards to ``matmul_cuda.spspmm``, returning its (index, value) result
    unchanged.
    """
    matrix1 = SparseTensor(e1, v1, s1)
    matrix2 = SparseTensor(e2, v2, s2)
    return matmul_cuda.spspmm(matrix1, matrix2)
def
mm_cpu
(
e1
,
v1
,
s1
,
e2
,
v2
,
s2
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment