OpenDAS / torch-sparse · Commits

Commit 296b5048, authored Jul 31, 2018 by rusty1s
gradients
parent 76872e45

Showing 8 changed files with 142 additions and 89 deletions (+142, -89).
cuda/matmul_cuda.cu        +6   -2
setup.cfg                  +1   -1
test/test_matmul.py        +49  -27
test/utils.py              +1   -1
torch_sparse/__init__.py   +3   -2
torch_sparse/matmul.py     +32  -49
torch_sparse/sparse.py     +50  -0
torch_sparse/transpose.py  +0   -7
cuda/matmul_cuda.cu  (+6, -2)

@@ -28,11 +28,11 @@ static void init_cusparse() {
 }
 
 std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
-  init_cusparse();
   A = A.coalesce();
   B = B.coalesce();
+  init_cusparse();
   auto m = A.size(0);
   auto n = B.size(1);
   auto k = A.size(1);
@@ -46,6 +46,8 @@ std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
   cusparseXcoo2csr(cusparse_handle, indexA[0].data<int>(), nnzA, k,
                    row_ptrA.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
   auto colA = indexA[1];
+  cudaMemcpy(row_ptrA.data<int>() + m, &nnzA, sizeof(int), cudaMemcpyHostToDevice);
 
   auto valueB = B._values();
   auto indexB = B._indices().toType(at::kInt);
@@ -53,6 +55,8 @@ std::tuple<at::Tensor, at::Tensor> spspmm_cuda(at::Tensor A, at::Tensor B) {
   cusparseXcoo2csr(cusparse_handle, indexB[0].data<int>(), nnzB, k,
                    row_ptrB.data<int>(), CUSPARSE_INDEX_BASE_ZERO);
   auto colB = indexB[1];
+  cudaMemcpy(row_ptrB.data<int>() + k, &nnzB, sizeof(int), cudaMemcpyHostToDevice);
 
   cusparseMatDescr_t descr = 0;
   cusparseCreateMatDescr(&descr);
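In the CSR format a row-pointer array for an M-row matrix has M + 1 entries and its last entry equals the number of nonzeros; the two cudaMemcpy calls added above write nnzA and nnzB into that last slot of row_ptrA and row_ptrB. Below is a rough pure-PyTorch sketch of building such a row pointer from coalesced COO row indices (coo_to_csr_rowptr is a hypothetical helper for illustration, not code from this commit):

import torch

def coo_to_csr_rowptr(row, num_rows, nnz):
    # row: coalesced (sorted) COO row indices, length nnz.
    # The CSR row pointer has num_rows + 1 entries: entry i is the offset of the
    # first nonzero of row i, and the final entry must equal nnz.
    rowptr = torch.zeros(num_rows + 1, dtype=torch.long)
    rowptr.scatter_add_(0, row + 1, torch.ones_like(row))  # per-row nonzero counts
    rowptr = rowptr.cumsum(0)                               # prefix sums -> offsets
    rowptr[num_rows] = nnz                                  # same role as the added cudaMemcpy
    return rowptr

row = torch.tensor([0, 0, 1, 2, 2])               # row indices of the 3x3 A used in the tests
print(coo_to_csr_rowptr(row, num_rows=3, nnz=5))  # tensor([0, 2, 3, 5])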
setup.cfg  (+1, -1)

@@ -5,4 +5,4 @@ description-file = README.md
 test = pytest
 
 [tool:pytest]
-addopts = --capture=no --cov
+addopts = --capture=no
test/test_matmul.py  (+49, -27)

@@ -3,51 +3,73 @@ from itertools import product
 import pytest
 import torch
 from torch.autograd import gradcheck
-from torch_sparse import spspmm
-from torch_sparse.matmul import SpSpMM
+from torch_sparse import SparseTensor, spspmm, to_value
 from .utils import dtypes, devices, tensor
 
+devices = [torch.device('cpu')]
 dtypes = [torch.double]
 
 
 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_coalesced_spspmm(dtype, device):
     indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
-    valueA = tensor([1, 2, 3, 4, 5], dtype, device, requires_grad=True)
+    valueA = tensor([1, 2, 3, 4, 5], dtype, device)
     sizeA = torch.Size([3, 3])
-    A = (indexA, valueA, sizeA)
+    A = torch.sparse_coo_tensor(indexA, valueA, sizeA, device=device)
+    A_dense = torch.sparse_coo_tensor(indexA, valueA, sizeA).to_dense()
+    A_dense = A_dense.requires_grad_()
+    print('A', A_dense)
 
     indexB = torch.tensor([[0, 2], [1, 0]], device=device)
-    valueB = tensor([2, 4], dtype, device, requires_grad=True)
+    valueB = tensor([2, 4], dtype, device)
     sizeB = torch.Size([3, 2])
-    B = (indexB, valueB, sizeB)
+    B = torch.sparse_coo_tensor(indexB, valueB, sizeB, device=device)
+    B_dense = torch.sparse_coo_tensor(indexB, valueB, sizeB).to_dense()
+    B_dense = B_dense.requires_grad_()
 
-    index, value, size = spspmm(*A, *B)
+    assert spspmm(A, B).to_dense().tolist() == [[8, 0], [0, 6], [0, 8]]
-    # out = torch.sparse_coo_tensor(index, value, size)
+    expected = torch.matmul(A_dense, B_dense)
-    # assert out.to_dense().tolist() == expected.tolist()
-    # valueA = valueA.requires_grad_()
-    # A.requires_grad_()
-    # valueB = valueB.requires_grad_()
-    # B.requires_grad_()
-    # data = (indexA, valueA, sizeA, indexB, valueB, sizeB)
-    # assert gradcheck(SpSpMM.apply, data, eps=1e-6, atol=1e-4) is True
-    # print(expected)
-    # A.requires_grad_()
-    # B.requires_grad_()
-    value.sum().backward()
+    # to_value(C).sum().backward()
+    expected.sum().backward()
+    # print(valueA)
+    # print(valueA.grad)
+    # print(valueB)
+    # print(valueB.grad)
+    print(valueA.grad)
+    # A_dense.requires_grad_()
+    print(A_dense.grad)
+    # B_dense.requires_grad_()
+    # print(valueB.grad)
+    # C_dense = torch.matmul(A_dense, B_dense)
+    # C_dense[C_dense > 0].sum().backward()
+    # print(A_dense)
+    # print(A_dense.grad)
+    # print(B_dense)
+    # print(B_dense.grad)
+    # print(B_dense.grad)
-    # # TODO TEST backward
-    # A.requires_grad_()
-    # # value.sum().backward()
-    # B = B.to_dense()
-    # B.requires_grad_()
-    # torch.spmm(A, B).sum().backward()
-    # print(B.grad)
+    # valueA.requires_grad_()
+    valueB.requires_grad_()
+
+    def pipeline(valueA, valueB):
+        A = SparseTensor(indexA, valueA, sizeA)
+        B = SparseTensor(indexB, valueB, sizeB)
+        C = spspmm(A, B)
+        value = to_value(C)
+        return value
+
+    # out = pipeline(valueA, valueB).sum().backward()
+    # print(valueA.grad)
+    # print(valueB.grad)
+    print(gradcheck(pipeline, (valueA, valueB), eps=1e-6, atol=1e-4))
+    # A, B = Sparsetensor(SparseTensor(index, valueB, sizeB)
+    # print(A.requires_grad)
+    # to_value(C).sum().backward()
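The test leans on torch.autograd.gradcheck, which compares the analytical gradients of a function against finite-difference estimates; it is only reliable in double precision, which is why dtypes is pinned to torch.double above. A minimal, self-contained sketch of the pattern on a toy function (not from this commit):

import torch
from torch.autograd import gradcheck

def toy(x, y):
    # any differentiable composition of tensor ops can stand in for spspmm here
    return (x * y).sum()

x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = torch.randn(5, dtype=torch.double, requires_grad=True)
print(gradcheck(toy, (x, y), eps=1e-6, atol=1e-4))  # True when analytical and numerical gradients match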
test/utils.py  (+1, -1)

@@ -8,5 +8,5 @@ if torch.cuda.is_available():  # pragma: no cover
 def tensor(x, dtype, device, requires_grad=False):
-    return None if x is None else torch.tensor(
+    return torch.tensor(
         x, dtype=dtype, device=device, requires_grad=requires_grad)
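For context, the tests build their inputs through this helper so that dtype and device come from the parametrize decorator; a minimal usage sketch (values copied from the test, the import path assumed for illustration):

import torch
from test.utils import tensor  # assumed path for illustration

valueA = tensor([1, 2, 3, 4, 5], torch.double, torch.device('cpu'))
valueB = tensor([2, 4], torch.double, torch.device('cpu'), requires_grad=True)
print(valueA.dtype, valueB.requires_grad)  # torch.float64 True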
torch_sparse/__init__.py  (+3, -2)

-from .transpose import transpose
+from .sparse import SparseTensor, to_value
 from .matmul import spspmm
 
 __all__ = [
-    'transpose',
+    'SparseTensor',
+    'to_value',
     'spspmm',
 ]
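After this change the package exposes SparseTensor, to_value, and spspmm together. A small usage sketch based on the pipeline in the test above (shapes and values copied from the test; whether the resulting gradients are correct is exactly what the gradcheck there probes):

import torch
from torch_sparse import SparseTensor, spspmm, to_value

indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
valueA = torch.tensor([1., 2., 3., 4., 5.], dtype=torch.double, requires_grad=True)
A = SparseTensor(indexA, valueA, torch.Size([3, 3]))   # differentiable w.r.t. valueA

indexB = torch.tensor([[0, 2], [1, 0]])
valueB = torch.tensor([2., 4.], dtype=torch.double, requires_grad=True)
B = SparseTensor(indexB, valueB, torch.Size([3, 2]))

C = spspmm(A, B)                 # sparse 3x2 product
to_value(C).sum().backward()     # route gradients back to valueA and valueB
print(valueA.grad, valueB.grad)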
torch_sparse/matmul.py  (+32, -49)

 import torch
 import scipy.sparse
 
-from torch_sparse import transpose
-
 if torch.cuda.is_available():
     import matmul_cuda
 
 
-def spspmm(indexA, valueA, sizeA, indexB, valueB, sizeB):
-    assert valueA.dtype == valueB.dtype
-    assert len(sizeA) == len(sizeB) == 2
-    assert sizeA[1] == sizeB[0]
-    index, value = SpSpMM.apply(indexA, valueA, sizeA, indexB, valueB, sizeB)
-    size = torch.Size([sizeA[0], sizeB[1]])
-    return index, value, size
-
-
 class SpSpMM(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, indexA, valueA, sizeA, indexB, valueB, sizeB):
-        index, value = mm(indexA, valueA, sizeA, indexB, valueB, sizeB)
-        ctx.sizeA, ctx.sizeB = sizeA, sizeB
-        ctx.save_for_backward(indexA, valueA, indexB, valueB, index)
-        return index, value
+    def forward(ctx, A, B):
+        ctx.save_for_backward(A, B)
+        return mm(A, B)
 
     @staticmethod
-    def backward(ctx, grad_index, grad_value):
-        indexA, valueA, indexB, valueB, index = ctx.saved_variables
-        grad_valueA = grad_valueB = None
-        grad = (index, grad_value, torch.Size([ctx.sizeA[0], ctx.sizeB[1]]))
-        if ctx.needs_input_grad[1]:
-            B_tranposed = transpose(indexB, valueB, ctx.sizeB)
-            _, grad_valueA = mm(*grad, *B_tranposed)
-        if ctx.needs_input_grad[4]:
-            A_tranposed = transpose(indexA, valueA, ctx.sizeA)
-            _, grad_valueB = mm(*A_tranposed, *grad)
-        return None, grad_valueA, None, None, grad_valueB, None
+    def backward(ctx, grad_C):
+        A, B = ctx.saved_variables
+        grad_A = grad_B = None
+        if ctx.needs_input_grad[0]:
+            grad_A = mm(grad_C, B.t().coalesce())
+        if ctx.needs_input_grad[1]:
+            grad_B = mm(A.t(), grad_C)
+        return grad_A, grad_B
 
 
-def mm(indexA, valueA, sizeA, indexB, valueB, sizeB):
-    if valueA.is_cuda:
-        return mm_cuda(indexA, valueA, sizeA, indexB, valueB, sizeB)
-    else:
-        return mm_cpu(indexA, valueA, sizeA, indexB, valueB, sizeB)
+spspmm = SpSpMM.apply
 
 
-def mm_cuda(indexA, valueA, sizeA, indexB, valueB, sizeB):
-    A = torch.sparse_coo_tensor(indexA, valueA, sizeA)
-    B = torch.sparse_coo_tensor(indexB, valueB, sizeB)
-    index, value = matmul_cuda.spspmm(A, B)
-    return index, value
+def mm(A, B):
+    assert A.dtype == B.dtype
+    assert A.size(1) == B.size(0)
+    return mm_cuda(A, B) if A.is_cuda else mm_cpu(A, B)
 
 
-def mm_cpu(indexA, valueA, sizeA, indexB, valueB, sizeB):
-    A, B, = to_scipy(indexA, valueA, sizeA), to_scipy(indexB, valueB, sizeB)
-    C = A.tocsr().dot(B.tocsr()).tocoo()
-    row, col = torch.from_numpy(C.row).long(), torch.from_numpy(C.col).long()
-    index = torch.stack([row, col], dim=0)
-    value = torch.from_numpy(C.data).type_as(valueA)
-    return index, value
+def mm_cuda(A, B):
+    index, value = matmul_cuda.spspmm(A, B)
+    size = torch.Size([A.size(0), B.size(1)])
+    return torch.sparse_coo_tensor(index, value, size, device=value.device)
 
 
-def to_scipy(index, value, size):
-    (row, col), value = index.detach().numpy(), value.detach().numpy()
-    return scipy.sparse.coo_matrix((value, (row, col)), tuple(size))
+def mm_cpu(A, B):
+    return from_scipy(to_scipy(A).dot(to_scipy(B)))
+
+
+def to_scipy(A):
+    (row, col), data, shape = A._indices(), A._values(), tuple(A.size())
+    row, col, data = row.detach(), col.detach(), data.detach()
+    return scipy.sparse.coo_matrix((data, (row, col)), shape).tocsr()
+
+
+def from_scipy(A):
+    A = A.tocoo()
+    row, col, value, size = A.row, A.col, A.data, torch.Size(A.shape)
+    value = torch.from_numpy(value)
+    index = torch.stack([torch.from_numpy(row), torch.from_numpy(col)], dim=0)
+    index = index.to(torch.long)
+    return torch.sparse_coo_tensor(index, value, size)
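The rewritten backward uses the standard matrix-product gradient identities: for C = A·B, the gradient w.r.t. A is grad_C·Bᵀ and the gradient w.r.t. B is Aᵀ·grad_C, computed here with the same sparse-sparse mm. A quick dense sanity check of those identities (toy sizes, not code from the commit):

import torch

A = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
B = torch.randn(4, 2, dtype=torch.double, requires_grad=True)
C = A @ B
grad_C = torch.randn_like(C)
C.backward(grad_C)

# autograd agrees with the closed-form expressions used in SpSpMM.backward
assert torch.allclose(A.grad, grad_C @ B.t())
assert torch.allclose(B.grad, A.t() @ grad_C)

In the sparse setting these products are full gradients; restricting them to the entries actually stored in A and B happens one step further back, in _SparseTensor.backward below.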
torch_sparse/sparse.py  (new file, +50)

import torch


class _SparseTensor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, index, value, size):
        ctx.size = size
        ctx.save_for_backward(index)
        return torch.sparse_coo_tensor(index, value, size, device=value.device)

    @staticmethod
    def backward(ctx, grad_out):
        index = ctx.saved_variables[0]
        grad_in = None

        if ctx.needs_input_grad[1]:
            value = grad_out._values()
            id1 = index[0] * ctx.size[1] + index[1]
            index = grad_out._indices()
            id2 = index[0] * ctx.size[1] + index[1]
            grad_in = value.new_zeros(id1.max().item() + 1)
            grad_in[id2] = value
            grad_in = grad_in[id1]

        return None, grad_in, None


SparseTensor = _SparseTensor.apply


class ToValue(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
        return A._values()

    @staticmethod
    def backward(ctx, grad_out):
        A = ctx.saved_variables[0]
        grad_in = None

        if ctx.needs_input_grad[0]:
            grad_in = torch.sparse_coo_tensor(
                A._indices(), grad_out, A.size(), device=grad_out.device)

        return grad_in


to_value = ToValue.apply
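_SparseTensor.backward routes the incoming gradient, which lives on the coalesced output entries, back onto the original (index, value) entries by linearizing coordinates as row * size[1] + col, scattering by the output ids and gathering by the input ids. A small standalone sketch of that step using the 3x3 index from the test (illustrative values only, not code from the commit):

import torch

size = torch.Size([3, 3])
index = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])  # input entries (may contain duplicates)
grad_index = index                                         # entries of the coalesced output; identical here
grad_value = torch.tensor([10., 20., 30., 40., 50.])       # gradient w.r.t. the output values

id1 = index[0] * size[1] + index[1]             # linear ids of the input entries
id2 = grad_index[0] * size[1] + grad_index[1]   # linear ids of the gradient entries
grad_in = grad_value.new_zeros(id1.max().item() + 1)
grad_in[id2] = grad_value                       # scatter gradient by linear id
grad_in = grad_in[id1]                          # gather back onto the input entries
print(grad_in)                                  # tensor([10., 20., 30., 40., 50.])

If index contained duplicate coordinates that coalescing merged, each duplicate would pick up the gradient of its merged slot, matching the fact that coalescing sums duplicate values.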
torch_sparse/transpose.py  (deleted, -7)

import torch


def transpose(index, value, size):
    (row, col), (dim1, dim2) = index, size
    index, size = torch.stack([col, row], dim=0), torch.Size([dim2, dim1])
    return index, value, size