Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
43b284f1
Commit
43b284f1
authored
Feb 03, 2020
by
rusty1s
Browse files
clean up
parent
7636e1d1
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
114 additions
and
447 deletions
+114
-447
csrc/cpu/spspmm_cpu.cpp
csrc/cpu/spspmm_cpu.cpp
+13
-8
csrc/cuda/spspmm_cuda.cu
csrc/cuda/spspmm_cuda.cu
+0
-39
test/test_jit.py
test/test_jit.py
+0
-91
test/test_spspmm.py
test/test_spspmm.py
+2
-18
test/test_spspmm_spmm.py
test/test_spspmm_spmm.py
+0
-22
test/test_storage.py
test/test_storage.py
+33
-58
torch_sparse/__init__.py
torch_sparse/__init__.py
+42
-16
torch_sparse/coalesce.py
torch_sparse/coalesce.py
+6
-26
torch_sparse/spmm.py
torch_sparse/spmm.py
+0
-76
torch_sparse/spspmm.py
torch_sparse/spspmm.py
+9
-83
torch_sparse/transpose.py
torch_sparse/transpose.py
+9
-10
No files found.
csrc/cpu/spspmm_cpu.cpp
View file @
43b284f1
...
...
@@ -53,19 +53,24 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
auto
rowptrC_data
=
rowptrC
.
data_ptr
<
int64_t
>
();
rowptrC_data
[
0
]
=
0
;
int64_t
rowA_start
=
0
,
rowA_end
,
rowB_start
,
rowB_end
,
cA
,
cB
;
int64_t
nnz
=
0
,
row_nnz
;
for
(
auto
n
=
1
;
n
<
rowptrA
.
numel
();
n
++
)
{
row
A_end
=
rowptrA_data
[
n
]
;
std
::
vector
<
int64_t
>
mask
(
K
,
-
1
)
;
int64_t
nnz
=
0
,
row_nnz
,
rowA_start
,
rowA_end
,
rowB_start
,
rowB_end
,
cA
,
cB
;
for
(
auto
n
=
0
;
n
<
rowptrA
.
numel
()
-
1
;
n
++
)
{
row
_nnz
=
0
;
for
(
auto
eA
=
row
A_start
;
eA
<
rowA_end
;
eA
++
)
{
for
(
auto
eA
=
row
ptrA_data
[
n
];
eA
<
rowptrA_data
[
n
+
1
]
;
eA
++
)
{
cA
=
colA_data
[
eA
];
row_nnz
=
rowptrB_data
[
cA
+
1
]
-
rowptrB_data
[
cA
];
for
(
auto
eB
=
rowptrB_data
[
cA
];
eB
<
rowptrB_data
[
cA
+
1
];
eB
++
)
{
cB
=
colB_data
[
eB
];
if
(
mask
[
cB
]
!=
n
)
{
mask
[
cB
]
=
n
;
row_nnz
++
;
}
}
}
nnz
+=
row_nnz
;
rowptrC_data
[
n
]
=
nnz
;
rowA_start
=
rowA_end
;
rowptrC_data
[
n
+
1
]
=
nnz
;
}
// Pass 2: Compute CSR entries.
...
...
csrc/cuda/spspmm_cuda.cu
View file @
43b284f1
...
...
@@ -27,44 +27,6 @@
} \
}()
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_BUFFER_SIZE_EXT_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseScsrgemm2_bufferSizeExt; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseDcsrgemm2_bufferSizeExt; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
#define AT_DISPATCH_CUSPARSE_CSR_GEMM2_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2 = cusparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2 = cusparseDcsrgemm2; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>>
spspmm_cuda
(
torch
::
Tensor
rowptrA
,
torch
::
Tensor
colA
,
torch
::
optional
<
torch
::
Tensor
>
optional_valueA
,
...
...
@@ -108,7 +70,6 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
scalar_type
=
optional_valueA
.
value
().
scalar_type
();
auto
handle
=
at
::
cuda
::
getCurrentCUDASparseHandle
();
cusparseSetPointerMode
(
handle
,
CUSPARSE_POINTER_MODE_HOST
);
cusparseMatDescr_t
descr
;
cusparseCreateMatDescr
(
&
descr
);
...
...
test/test_jit.py
deleted
100644 → 0
View file @
7636e1d1
import
torch
from
torch_sparse
import
SparseStorage
,
SparseTensor
from
typing
import
Dict
,
Any
# class MyTensor(dict):
# def __init__(self, rowptr, col):
# self['rowptr'] = rowptr
# self['col'] = col
# def rowptr(self: Dict[str, torch.Tensor]):
# return self['rowptr']
@
torch
.
jit
.
script
class
Foo
:
rowptr
:
torch
.
Tensor
col
:
torch
.
Tensor
def
__init__
(
self
,
rowptr
:
torch
.
Tensor
,
col
:
torch
.
Tensor
):
self
.
rowptr
=
rowptr
self
.
col
=
col
class
MyCell
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
(
MyCell
,
self
).
__init__
()
self
.
linear
=
torch
.
nn
.
Linear
(
2
,
4
)
# def forward(self, x: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
def
forward
(
self
,
x
:
torch
.
Tensor
,
adj
:
SparseTensor
)
->
torch
.
Tensor
:
out
,
_
=
torch
.
ops
.
torch_sparse_cpu
.
spmm
(
adj
.
storage
.
rowptr
(),
adj
.
storage
.
col
(),
None
,
x
,
'sum'
)
return
out
# ind = torch.ops.torch_sparse_cpu.ptr2ind(ptr, ptr[-1].item())
# # ind = ptr2ind(ptr, E)
# x_j = x[ind]
# out = self.linear(x_j)
# return out
def
test_jit
():
my_cell
=
MyCell
()
# x = torch.rand(3, 2)
# ptr = torch.tensor([0, 2, 4, 6])
# out = my_cell(x, ptr)
# print()
# print(out)
# traced_cell = torch.jit.trace(my_cell, (x, ptr))
# print(traced_cell)
# out = traced_cell(x, ptr)
# print(out)
x
=
torch
.
randn
(
3
,
2
)
# adj = torch.randn(3, 3)
# adj = SparseTensor.from_dense(adj)
# adj = Foo(adj.storage.rowptr, adj.storage.col)
# adj = adj.storage
rowptr
=
torch
.
tensor
([
0
,
1
,
4
,
7
])
col
=
torch
.
tensor
([
0
,
0
,
1
,
2
,
0
,
1
,
2
])
adj
=
SparseTensor
(
rowptr
=
rowptr
,
col
=
col
)
# scipy = adj.to_scipy(layout='csr')
# mat = SparseTensor.from_scipy(scipy)
print
()
# adj = t(adj)
adj
=
adj
.
t
()
adj
=
adj
.
remove_diag
(
k
=
0
)
print
(
adj
.
to_dense
())
adj
=
adj
+
torch
.
tensor
([
1
,
2
,
3
]).
view
(
1
,
3
)
print
(
adj
)
print
(
adj
.
to_dense
())
# print(adj.t)
# adj = {'rowptr': mat.storage.rowptr, 'col': mat.storage.col}
# foo = Foo(mat.storage.rowptr, mat.storage.col)
# adj = MyTensor(mat.storage.rowptr, mat.storage.col)
traced_cell
=
torch
.
jit
.
script
(
my_cell
)
print
(
traced_cell
)
out
=
traced_cell
(
x
,
adj
)
print
(
out
)
# # print(traced_cell.code)
test/test_spspmm.py
View file @
43b284f1
...
...
@@ -4,32 +4,16 @@ import pytest
import
torch
from
torch_sparse
import
spspmm
from
.utils
import
dtypes
,
devices
,
tensor
from
.utils
import
grad_
dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_
dtypes
,
devices
))
def
test_spspmm
(
dtype
,
device
):
indexA
=
torch
.
tensor
([[
0
,
0
,
1
,
2
,
2
],
[
1
,
2
,
0
,
0
,
1
]],
device
=
device
)
valueA
=
tensor
([
1
,
2
,
3
,
4
,
5
],
dtype
,
device
)
sizeA
=
torch
.
Size
([
3
,
3
])
indexB
=
torch
.
tensor
([[
0
,
2
],
[
1
,
0
]],
device
=
device
)
valueB
=
tensor
([
2
,
4
],
dtype
,
device
)
sizeB
=
torch
.
Size
([
3
,
2
])
indexC
,
valueC
=
spspmm
(
indexA
,
valueA
,
indexB
,
valueB
,
3
,
3
,
2
)
assert
indexC
.
tolist
()
==
[[
0
,
1
,
2
],
[
0
,
1
,
1
]]
assert
valueC
.
tolist
()
==
[
8
,
6
,
8
]
A
=
torch
.
sparse_coo_tensor
(
indexA
,
valueA
,
sizeA
,
device
=
device
)
A
=
A
.
to_dense
().
requires_grad_
()
B
=
torch
.
sparse_coo_tensor
(
indexB
,
valueB
,
sizeB
,
device
=
device
)
B
=
B
.
to_dense
().
requires_grad_
()
torch
.
matmul
(
A
,
B
).
sum
().
backward
()
valueA
=
valueA
.
requires_grad_
()
valueB
=
valueB
.
requires_grad_
()
indexC
,
valueC
=
spspmm
(
indexA
,
valueA
,
indexB
,
valueB
,
3
,
3
,
2
)
valueC
.
sum
().
backward
()
assert
valueA
.
grad
.
tolist
()
==
A
.
grad
[
indexA
[
0
],
indexA
[
1
]].
tolist
()
assert
valueB
.
grad
.
tolist
()
==
B
.
grad
[
indexB
[
0
],
indexB
[
1
]].
tolist
()
test/test_spspmm_spmm.py
deleted
100644 → 0
View file @
7636e1d1
from
itertools
import
product
import
pytest
import
torch
from
torch_sparse
import
spspmm
,
spmm
from
.utils
import
dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_spmm_spspmm
(
dtype
,
device
):
row
=
torch
.
tensor
([
0
,
0
,
1
,
2
,
2
],
device
=
device
)
col
=
torch
.
tensor
([
0
,
2
,
1
,
0
,
1
],
device
=
device
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
value
=
tensor
([
1
,
2
,
4
,
1
,
3
],
dtype
,
device
)
x
=
tensor
([[
1
,
4
],
[
2
,
5
],
[
3
,
6
]],
dtype
,
device
)
value
=
value
.
requires_grad_
(
True
)
out_index
,
out_value
=
spspmm
(
index
,
value
,
index
,
value
,
3
,
3
,
3
)
out
=
spmm
(
out_index
,
out_value
,
3
,
3
,
x
)
assert
out
.
size
()
==
(
3
,
2
)
test/test_storage.py
View file @
43b284f1
import
copy
from
itertools
import
product
import
pytest
...
...
@@ -13,18 +12,18 @@ def test_storage(dtype, device):
row
,
col
=
tensor
([[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
)
assert
storage
.
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
is
None
assert
storage
.
sparse_size
==
(
2
,
2
)
assert
storage
.
row
()
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
()
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
()
is
None
assert
storage
.
sparse_size
s
()
==
(
2
,
2
)
row
,
col
=
tensor
([[
0
,
0
,
1
,
1
],
[
1
,
0
,
1
,
0
]],
torch
.
long
,
device
)
value
=
tensor
([
2
,
1
,
4
,
3
],
dtype
,
device
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
)
assert
storage
.
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
storage
.
sparse_size
==
(
2
,
2
)
assert
storage
.
row
()
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
()
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
()
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
storage
.
sparse_size
s
()
==
(
2
,
2
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
...
...
@@ -41,7 +40,7 @@ def test_caching(dtype, device):
assert
storage
.
_colcount
is
None
assert
storage
.
_colptr
is
None
assert
storage
.
_csr2csc
is
None
assert
storage
.
cached_keys
()
==
[]
assert
storage
.
num_
cached_keys
()
==
0
storage
.
fill_cache_
()
assert
storage
.
_rowcount
.
tolist
()
==
[
2
,
2
]
...
...
@@ -50,16 +49,14 @@ def test_caching(dtype, device):
assert
storage
.
_colptr
.
tolist
()
==
[
0
,
2
,
4
]
assert
storage
.
_csr2csc
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
_csc2csr
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
cached_keys
()
==
[
'rowcount'
,
'colptr'
,
'colcount'
,
'csr2csc'
,
'csc2csr'
]
assert
storage
.
num_cached_keys
()
==
5
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
storage
.
rowptr
,
col
=
col
,
value
=
storage
.
value
,
sparse_size
=
storage
.
sparse_size
,
rowcount
=
storage
.
rowcount
,
colptr
=
storage
.
colptr
,
colcount
=
storage
.
colcount
,
csr2csc
=
storage
.
csr2csc
,
csc2csr
=
storage
.
csc2csr
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
storage
.
_
rowptr
,
col
=
col
,
value
=
storage
.
_
value
,
sparse_size
s
=
storage
.
_
sparse_size
s
,
rowcount
=
storage
.
_
rowcount
,
colptr
=
storage
.
_
colptr
,
colcount
=
storage
.
_
colcount
,
csr2csc
=
storage
.
_csr2csc
,
csc2csr
=
storage
.
_
csc2csr
)
assert
storage
.
_rowcount
.
tolist
()
==
[
2
,
2
]
assert
storage
.
_rowptr
.
tolist
()
==
[
0
,
2
,
4
]
...
...
@@ -67,9 +64,7 @@ def test_caching(dtype, device):
assert
storage
.
_colptr
.
tolist
()
==
[
0
,
2
,
4
]
assert
storage
.
_csr2csc
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
_csc2csr
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
cached_keys
()
==
[
'rowcount'
,
'colptr'
,
'colcount'
,
'csr2csc'
,
'csc2csr'
]
assert
storage
.
num_cached_keys
()
==
5
storage
.
clear_cache_
()
assert
storage
.
_rowcount
is
None
...
...
@@ -77,7 +72,7 @@ def test_caching(dtype, device):
assert
storage
.
_colcount
is
None
assert
storage
.
_colptr
is
None
assert
storage
.
_csr2csc
is
None
assert
storage
.
cached_keys
()
==
[]
assert
storage
.
num_
cached_keys
()
==
0
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
...
...
@@ -89,45 +84,25 @@ def test_utility(dtype, device):
assert
storage
.
has_value
()
storage
.
set_value_
(
value
,
layout
=
'csc'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
3
,
2
,
4
]
assert
storage
.
value
()
.
tolist
()
==
[
1
,
3
,
2
,
4
]
storage
.
set_value_
(
value
,
layout
=
'coo'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
storage
.
value
()
.
tolist
()
==
[
1
,
2
,
3
,
4
]
storage
=
storage
.
set_value
(
value
,
layout
=
'csc'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
3
,
2
,
4
]
assert
storage
.
value
()
.
tolist
()
==
[
1
,
3
,
2
,
4
]
storage
=
storage
.
set_value
(
value
,
layout
=
'coo'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
storage
.
value
()
.
tolist
()
==
[
1
,
2
,
3
,
4
]
storage
=
storage
.
sparse_resize
(
3
,
3
)
assert
storage
.
sparse_size
==
(
3
,
3
)
storage
=
storage
.
sparse_resize
(
[
3
,
3
]
)
assert
storage
.
sparse_size
s
()
==
[
3
,
3
]
new_storage
=
copy
.
copy
(
storage
)
new_storage
=
storage
.
copy
(
)
assert
new_storage
!=
storage
assert
new_storage
.
col
.
data_ptr
()
==
storage
.
col
.
data_ptr
()
assert
new_storage
.
col
()
.
data_ptr
()
==
storage
.
col
()
.
data_ptr
()
new_storage
=
storage
.
clone
()
assert
new_storage
!=
storage
assert
new_storage
.
col
.
data_ptr
()
!=
storage
.
col
.
data_ptr
()
new_storage
=
copy
.
deepcopy
(
storage
)
assert
new_storage
!=
storage
assert
new_storage
.
col
.
data_ptr
()
!=
storage
.
col
.
data_ptr
()
storage
.
apply_value_
(
lambda
x
:
x
+
1
)
assert
storage
.
value
.
tolist
()
==
[
2
,
3
,
4
,
5
]
storage
=
storage
.
apply_value
(
lambda
x
:
x
+
1
)
assert
storage
.
value
.
tolist
()
==
[
3
,
4
,
5
,
6
]
storage
.
apply_
(
lambda
x
:
x
.
to
(
torch
.
long
))
assert
storage
.
col
.
dtype
==
torch
.
long
assert
storage
.
value
.
dtype
==
torch
.
long
storage
=
storage
.
apply
(
lambda
x
:
x
.
to
(
torch
.
long
))
assert
storage
.
col
.
dtype
==
torch
.
long
assert
storage
.
value
.
dtype
==
torch
.
long
storage
.
clear_cache_
()
assert
storage
.
map
(
lambda
x
:
x
.
numel
())
==
[
4
,
4
,
4
]
assert
new_storage
.
col
().
data_ptr
()
!=
storage
.
col
().
data_ptr
()
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
...
...
@@ -136,14 +111,14 @@ def test_coalesce(dtype, device):
value
=
tensor
([
1
,
1
,
1
,
3
,
4
],
dtype
,
device
)
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
)
assert
storage
.
row
.
tolist
()
==
row
.
tolist
()
assert
storage
.
col
.
tolist
()
==
col
.
tolist
()
assert
storage
.
value
.
tolist
()
==
value
.
tolist
()
assert
storage
.
row
()
.
tolist
()
==
row
.
tolist
()
assert
storage
.
col
()
.
tolist
()
==
col
.
tolist
()
assert
storage
.
value
()
.
tolist
()
==
value
.
tolist
()
assert
not
storage
.
is_coalesced
()
storage
=
storage
.
coalesce
()
assert
storage
.
is_coalesced
()
assert
storage
.
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
storage
.
row
()
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
()
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
()
.
tolist
()
==
[
1
,
2
,
3
,
4
]
torch_sparse/__init__.py
View file @
43b284f1
from
.storage
import
SparseStorage
from
.tensor
import
SparseTensor
from
.transpose
import
t
from
.narrow
import
narrow
from
.select
import
select
from
.index_select
import
index_select
,
index_select_nnz
from
.masked_select
import
masked_select
,
masked_select_nnz
from
.diag
import
remove_diag
,
set_diag
,
fill_diag
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
from
.reduce
import
sum
,
mean
,
min
,
max
from
.matmul
import
matmul
from
.cat
import
cat
,
cat_diag
from
.convert
import
to_torch_sparse
,
from_torch_sparse
,
to_scipy
,
from_scipy
from
.coalesce
import
coalesce
from
.transpose
import
transpose
...
...
@@ -8,7 +22,33 @@ from .spspmm import spspmm
__version__
=
'0.4.3'
__all__
=
[
'__version__'
,
'SparseStorage'
,
'SparseTensor'
,
't'
,
'narrow'
,
'select'
,
'index_select'
,
'index_select_nnz'
,
'masked_select'
,
'masked_select_nnz'
,
'remove_diag'
,
'set_diag'
,
'fill_diag'
,
'add'
,
'add_'
,
'add_nnz'
,
'add_nnz_'
,
'mul'
,
'mul_'
,
'mul_nnz'
,
'mul_nnz_'
,
'sum'
,
'mean'
,
'min'
,
'max'
,
'matmul'
,
'cat'
,
'cat_diag'
,
'to_torch_sparse'
,
'from_torch_sparse'
,
'to_scipy'
,
...
...
@@ -18,19 +58,5 @@ __all__ = [
'eye'
,
'spmm'
,
'spspmm'
,
'__version__'
,
]
from
.storage
import
SparseStorage
from
.tensor
import
SparseTensor
from
.transpose
import
t
from
.narrow
import
narrow
from
.select
import
select
from
.index_select
import
index_select
,
index_select_nnz
from
.masked_select
import
masked_select
,
masked_select_nnz
from
.diag
import
remove_diag
,
set_diag
,
fill_diag
from
.add
import
add
,
add_
,
add_nnz
,
add_nnz_
from
.mul
import
mul
,
mul_
,
mul_nnz
,
mul_nnz_
from
.reduce
import
sum
,
mean
,
min
,
max
from
.matmul
import
(
spmm_sum
,
spmm_add
,
spmm_mean
,
spmm_min
,
spmm_max
,
spmm
,
spspmm_sum
,
spspmm_add
,
spspmm
,
matmul
)
from
.cat
import
cat
,
cat_diag
torch_sparse/coalesce.py
View file @
43b284f1
import
torch
import
torch_s
catter
from
torch_s
parse.storage
import
SparseStorage
# from .unique import unique
def
coalesce
(
index
,
value
,
m
,
n
,
op
=
'add'
,
fill_value
=
0
):
def
coalesce
(
index
,
value
,
m
,
n
,
op
=
"add"
):
"""Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
entries are removed by scattering them together. For scattering, any
operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_
...
...
@@ -17,29 +15,11 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
n (int): The second dimension of corresponding dense matrix.
op (string, optional): The scatter operation to use. (default:
:obj:`"add"`)
fill_value (int, optional): The initial fill value of scatter
operation. (default: :obj:`0`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
raise
NotImplementedError
row
,
col
=
index
if
value
is
None
:
_
,
perm
=
unique
(
row
*
n
+
col
)
index
=
torch
.
stack
([
row
[
perm
],
col
[
perm
]],
dim
=
0
)
return
index
,
value
uniq
,
inv
=
torch
.
unique
(
row
*
n
+
col
,
sorted
=
True
,
return_inverse
=
True
)
perm
=
torch
.
arange
(
inv
.
size
(
0
),
dtype
=
inv
.
dtype
,
device
=
inv
.
device
)
perm
=
inv
.
new_empty
(
uniq
.
size
(
0
)).
scatter_
(
0
,
inv
,
perm
)
index
=
torch
.
stack
([
row
[
perm
],
col
[
perm
]],
dim
=
0
)
op
=
getattr
(
torch_scatter
,
'scatter_{}'
.
format
(
op
))
value
=
op
(
value
,
inv
,
0
,
None
,
perm
.
size
(
0
),
fill_value
)
if
isinstance
(
value
,
tuple
):
value
=
value
[
0
]
return
index
,
value
storage
=
SparseStorage
(
row
=
index
[
0
],
col
=
index
[
1
],
value
=
value
,
sparse_sizes
=
torch
.
Size
([
m
,
n
],
is_sorted
=
False
))
storage
=
storage
.
coalesce
(
reduce
=
op
)
return
torch
.
stack
([
storage
.
row
(),
storage
.
col
()],
dim
=
0
),
storage
.
value
()
torch_sparse/spmm.py
View file @
43b284f1
# import torch
from
torch_scatter
import
scatter_add
# from torch_sparse.tensor import SparseTensor
# if torch.cuda.is_available():
# import torch_sparse.spmm_cuda
# def spmm_(sparse_mat, mat, reduce='add'):
# assert reduce in ['add', 'mean', 'min', 'max']
# assert sparse_mat.dim() == 2 and mat.dim() == 2
# assert sparse_mat.size(1) == mat.size(0)
# rowptr, col, value = sparse_mat.csr()
# mat = mat.contiguous()
# if reduce in ['add', 'mean']:
# return torch_sparse.spmm_cuda.spmm(rowptr, col, value, mat, reduce)
# else:
# return torch_sparse.spmm_cuda.spmm_arg(
# rowptr, col, value, mat, reduce)
def
spmm
(
index
,
value
,
m
,
n
,
matrix
):
"""Matrix product of sparse matrix with dense matrix.
...
...
@@ -44,60 +25,3 @@ def spmm(index, value, m, n, matrix):
out
=
scatter_add
(
out
,
row
,
dim
=
0
,
dim_size
=
m
)
return
out
# if __name__ == '__main__':
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# row = torch.tensor([0, 0, 0, 1, 1, 1], device=device)
# col = torch.tensor([0, 1, 2, 0, 1, 2], device=device)
# value = torch.ones_like(col, dtype=torch.float, device=device)
# value = None
# sparse_mat = SparseTensor(torch.stack([row, col], dim=0), value)
# mat = torch.tensor([[1, 4], [2, 5], [3, 6]], dtype=torch.float,
# device=device)
# out1 = spmm_(sparse_mat, mat, reduce='add')
# out2 = sparse_mat.to_dense() @ mat
# assert torch.allclose(out1, out2)
# from torch_geometric.datasets import Reddit, Planetoid # noqa
# import time # noqa
# # Warmup
# x = torch.randn((1000, 1000), device=device)
# for _ in range(100):
# x.sum()
# # dataset = Reddit('/tmp/Reddit')
# dataset = Planetoid('/tmp/PubMed', 'PubMed')
# # dataset = Planetoid('/tmp/Cora', 'Cora')
# data = dataset[0].to(device)
# mat = torch.randn((data.num_nodes, 1024), device=device)
# value = torch.ones(data.num_edges, device=device)
# sparse_mat = SparseTensor(data.edge_index, value)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out1 = spmm_(sparse_mat, mat, reduce='add')
# out1 = out1[0] if isinstance(out1, tuple) else out1
# torch.cuda.synchronize()
# print('My: ', time.perf_counter() - t)
# sparse_mat = torch.sparse_coo_tensor(data.edge_index, value)
# sparse_mat = sparse_mat.coalesce()
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# out2 = sparse_mat @ mat
# torch.cuda.synchronize()
# print('Torch: ', time.perf_counter() - t)
# torch.cuda.synchronize()
# t = time.perf_counter()
# for _ in range(100):
# spmm(data.edge_index, value, data.num_nodes, data.num_nodes, mat)
# torch.cuda.synchronize()
# print('Scatter:', time.perf_counter() - t)
# assert torch.allclose(out1, out2, atol=1e-2)
torch_sparse/spspmm.py
View file @
43b284f1
import
torch
from
torch_sparse
import
transpose
,
to_scipy
,
from_scipy
,
coalesce
# import torch_sparse.spspmm_cpu
# if torch.cuda.is_available():
# import torch_sparse.spspmm_cuda
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.matmul
import
matmul
def
spspmm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
,
coalesced
=
False
):
...
...
@@ -25,83 +21,13 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
raise
NotImplementedError
if
indexA
.
is_cuda
and
coalesced
:
indexA
,
valueA
=
coalesce
(
indexA
,
valueA
,
m
,
k
)
indexB
,
valueB
=
coalesce
(
indexB
,
valueB
,
k
,
n
)
index
,
value
=
SpSpMM
.
apply
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
)
return
index
.
detach
(),
value
class
SpSpMM
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
):
indexC
,
valueC
=
mm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
)
ctx
.
m
,
ctx
.
k
,
ctx
.
n
=
m
,
k
,
n
ctx
.
save_for_backward
(
indexA
,
valueA
,
indexB
,
valueB
,
indexC
)
return
indexC
,
valueC
@
staticmethod
def
backward
(
ctx
,
grad_indexC
,
grad_valueC
):
m
,
k
=
ctx
.
m
,
ctx
.
k
n
=
ctx
.
n
indexA
,
valueA
,
indexB
,
valueB
,
indexC
=
ctx
.
saved_tensors
grad_valueA
=
grad_valueB
=
None
if
not
grad_valueC
.
is_cuda
:
if
ctx
.
needs_input_grad
[
1
]
or
ctx
.
needs_input_grad
[
1
]:
grad_valueC
=
grad_valueC
.
clone
()
if
ctx
.
needs_input_grad
[
1
]:
grad_valueA
=
torch_sparse
.
spspmm_cpu
.
spspmm_bw
(
indexA
,
indexC
.
detach
(),
grad_valueC
,
indexB
.
detach
(),
valueB
,
m
,
k
)
if
ctx
.
needs_input_grad
[
3
]:
indexA
,
valueA
=
transpose
(
indexA
,
valueA
,
m
,
k
)
indexC
,
grad_valueC
=
transpose
(
indexC
,
grad_valueC
,
m
,
n
)
grad_valueB
=
torch_sparse
.
spspmm_cpu
.
spspmm_bw
(
indexB
,
indexA
.
detach
(),
valueA
,
indexC
.
detach
(),
grad_valueC
,
k
,
n
)
else
:
if
ctx
.
needs_input_grad
[
1
]:
grad_valueA
=
torch_sparse
.
spspmm_cuda
.
spspmm_bw
(
indexA
,
indexC
.
detach
(),
grad_valueC
.
clone
(),
indexB
.
detach
(),
valueB
,
m
,
k
)
if
ctx
.
needs_input_grad
[
3
]:
indexA_T
,
valueA_T
=
transpose
(
indexA
,
valueA
,
m
,
k
)
grad_indexB
,
grad_valueB
=
mm
(
indexA_T
,
valueA_T
,
indexC
,
grad_valueC
,
k
,
m
,
n
)
grad_valueB
=
lift
(
grad_indexB
,
grad_valueB
,
indexB
,
n
)
return
None
,
grad_valueA
,
None
,
grad_valueB
,
None
,
None
,
None
def
mm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
):
assert
valueA
.
dtype
==
valueB
.
dtype
if
indexA
.
is_cuda
:
return
torch_sparse
.
spspmm_cuda
.
spspmm
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
)
A
=
to_scipy
(
indexA
,
valueA
,
m
,
k
)
B
=
to_scipy
(
indexB
,
valueB
,
k
,
n
)
C
=
A
.
dot
(
B
).
tocoo
().
tocsr
().
tocoo
()
# Force coalesce.
indexC
,
valueC
=
from_scipy
(
C
)
return
indexC
,
valueC
def
lift
(
indexA
,
valueA
,
indexB
,
n
):
# pragma: no cover
idxA
=
indexA
[
0
]
*
n
+
indexA
[
1
]
idxB
=
indexB
[
0
]
*
n
+
indexB
[
1
]
max_value
=
max
(
idxA
.
max
().
item
(),
idxB
.
max
().
item
())
+
1
valueB
=
valueA
.
new_zeros
(
max_value
)
A
=
SparseTensor
(
row
=
indexA
[
0
],
col
=
indexA
[
1
],
value
=
valueA
,
sparse_sizes
=
torch
.
Size
([
m
,
k
]),
is_sorted
=
not
coalesced
)
B
=
SparseTensor
(
row
=
indexB
[
0
],
col
=
indexB
[
1
],
value
=
valueB
,
sparse_sizes
=
torch
.
Size
([
k
,
n
]),
is_sorted
=
not
coalesced
)
valueB
[
idxA
]
=
valueA
value
B
=
valueB
[
idxB
]
C
=
matmul
(
A
,
B
)
row
,
col
,
value
=
C
.
coo
()
return
value
B
return
torch
.
stack
([
row
,
col
],
dim
=
0
),
value
torch_sparse/transpose.py
View file @
43b284f1
import
torch
from
torch_sparse
import
to_scipy
,
from_scipy
,
coalesce
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.tensor
import
SparseTensor
...
...
@@ -51,14 +50,14 @@ def transpose(index, value, m, n, coalesced=True):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
if
value
.
dim
()
==
1
and
not
value
.
is_cuda
:
mat
=
to_scipy
(
index
,
value
,
m
,
n
).
tocsc
()
(
col
,
row
),
value
=
from_scipy
(
mat
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
return
index
,
value
row
,
col
=
index
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)
row
,
col
=
col
,
row
if
coalesced
:
index
,
value
=
coalesce
(
index
,
value
,
n
,
m
)
return
index
,
value
sparse_sizes
=
torch
.
Size
([
n
,
m
])
storage
=
SparseStorage
(
row
=
row
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
is_sorted
=
False
)
storage
=
storage
.
coalesce
()
row
,
col
,
value
=
storage
.
row
(),
storage
.
col
(),
storage
.
value
()
return
torch
.
stack
([
row
,
col
],
dim
=
0
),
value
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment