Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
fa763bac
Commit
fa763bac
authored
Jan 25, 2020
by
rusty1s
Browse files
add implementation cpu
parent
f00ca88b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
152 additions
and
46 deletions
+152
-46
test/test_add.py
test/test_add.py
+52
-6
torch_sparse/add.py
torch_sparse/add.py
+100
-40
No files found.
test/test_add.py
View file @
fa763bac
import
time
from
itertools
import
product
from
itertools
import
product
from
scipy.io
import
loadmat
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.add
import
add
from
torch_sparse.add
import
sparse_
add
from
.utils
import
dtypes
,
devices
,
tensor
from
.utils
import
dtypes
,
devices
,
tensor
devices
=
[
'cpu'
]
dtypes
=
[
torch
.
float
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_sparse_add
(
dtype
,
device
):
def
test_sparse_add
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
1
],
[
0
,
1
,
2
]],
torch
.
long
,
device
)
name
=
(
'DIMACS10'
,
'citationCiteseer'
)[
1
]
mat1
=
SparseTensor
(
index
)
mat_scipy
=
loadmat
(
f
'benchmark/
{
name
}
.mat'
)[
'Problem'
][
0
][
0
][
2
].
tocsr
()
mat
=
SparseTensor
.
from_scipy
(
mat_scipy
)
mat1
=
mat
[:,
0
:
100000
]
mat2
=
mat
[:,
100000
:
200000
]
print
(
mat1
.
shape
)
print
(
mat2
.
shape
)
# 0.0159 to beat
t
=
time
.
perf_counter
()
mat
=
sparse_add
(
mat1
,
mat2
)
print
(
time
.
perf_counter
()
-
t
)
print
(
mat
.
nnz
())
mat1
=
mat_scipy
[:,
0
:
100000
]
mat2
=
mat_scipy
[:,
100000
:
200000
]
t
=
time
.
perf_counter
()
mat
=
mat1
+
mat2
print
(
time
.
perf_counter
()
-
t
)
print
(
mat
.
nnz
)
# mat1 + mat2
# mat1 = mat1.tocoo()
# mat2 = mat2.tocoo()
# row1, col1 = mat1.row, mat1.col
# row2, col2 = mat2.row, mat2.col
# idx1 = row1 * 100000 + col1
# idx2 = row2 * 100000 + col2
# t = time.perf_counter()
# np.union1d(idx1, idx2)
# print(time.perf_counter() - t)
# index = tensor([[0, 0, 1], [0, 1, 2]], torch.long, device)
# mat1 = SparseTensor(index)
# print()
# print(mat1.to_dense())
index
=
tensor
([[
0
,
0
,
1
,
2
],
[
0
,
1
,
1
,
0
]],
torch
.
long
,
device
)
# index = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], torch.long, device)
mat2
=
SparseTensor
(
index
)
# mat2 = SparseTensor(index)
# print(mat2.to_dense())
add
(
mat1
,
mat2
)
#
add(mat1, mat2)
torch_sparse/add.py
View file @
fa763bac
import
torch
import
torch
from
torch_scatter
import
gather_csr
def
union
(
mat1
,
mat2
):
def
sparse_add
(
matA
,
matB
):
offset
=
mat1
.
nnz
()
+
1
nnzA
,
nnzB
=
matA
.
nnz
(),
matB
.
nnz
()
value1
=
torch
.
ones
(
mat1
.
nnz
(),
dtype
=
torch
.
long
,
device
=
mat2
.
device
)
valA
=
torch
.
full
((
nnzA
,
),
1
,
dtype
=
torch
.
uint8
,
device
=
matA
.
device
)
value2
=
value1
.
new_full
((
mat2
.
nnz
(),
),
offset
)
valB
=
torch
.
full
((
nnzB
,
),
2
,
dtype
=
torch
.
uint8
,
device
=
matB
.
device
)
size
=
max
(
mat1
.
size
(
0
),
mat2
.
size
(
0
)),
max
(
mat1
.
size
(
1
),
mat2
.
size
(
1
))
if
not
mat1
.
is_cuda
:
if
matA
.
is_cuda
:
mat1
=
mat1
.
set_value
(
value1
,
layout
=
'coo'
).
to_scipy
(
layout
=
'csr'
)
pass
mat1
.
resize
(
*
size
)
else
:
matA_
=
matA
.
set_value
(
valA
,
layout
=
'csr'
).
to_scipy
(
layout
=
'csr'
)
matB_
=
matB
.
set_value
(
valB
,
layout
=
'csr'
).
to_scipy
(
layout
=
'csr'
)
matC_
=
matA_
+
matB_
rowptr
=
torch
.
from_numpy
(
matC_
.
indptr
).
to
(
torch
.
long
)
matC_
=
matC_
.
tocoo
()
row
=
torch
.
from_numpy
(
matC_
.
row
).
to
(
torch
.
long
)
col
=
torch
.
from_numpy
(
matC_
.
col
).
to
(
torch
.
long
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
valC_
=
torch
.
from_numpy
(
matC_
.
data
)
mat2
=
mat2
.
set_value
(
value2
,
layout
=
'coo'
).
to_scipy
(
layout
=
'csr'
)
value
=
None
mat2
.
resize
(
*
size
)
if
matA
.
has_value
()
or
matB
.
has_value
():
maskA
,
maskB
=
valC_
!=
2
,
valC_
>=
2
out
=
mat1
+
mat2
size
=
matA
.
size
()
if
matA
.
dim
()
>=
matB
.
dim
()
else
matA
.
size
()
rowptr
=
torch
.
from_numpy
(
out
.
indptr
).
to
(
torch
.
long
)
size
=
(
valC_
.
size
(
0
),
)
+
size
[
2
:]
out
=
out
.
tocoo
()
row
=
torch
.
from_numpy
(
out
.
row
).
to
(
torch
.
long
)
value
=
torch
.
zeros
(
size
,
dtype
=
matA
.
dtype
,
device
=
matA
.
device
)
col
=
torch
.
from_numpy
(
out
.
col
).
to
(
torch
.
long
)
value
[
maskA
]
+=
matA
.
storage
.
value
if
matA
.
has_value
()
else
1
value
=
torch
.
from_numpy
(
out
.
data
)
value
[
maskB
]
+=
matB
.
storage
.
value
if
matB
.
has_value
()
else
1
else
:
raise
NotImplementedError
mask1
=
value
%
offset
>
0
storage
=
matA
.
storage
.
__class__
(
index
,
value
,
matA
.
sparse_size
(),
mask2
=
value
>=
offset
rowptr
=
rowptr
,
is_sorted
=
True
)
return
rowptr
,
torch
.
stack
([
row
,
col
],
dim
=
0
),
mask1
,
mask2
return
matA
.
__class__
.
from_storage
(
storage
)
def
add
(
src
,
other
):
def
add
(
src
,
other
):
...
@@ -35,19 +43,21 @@ def add(src, other):
...
@@ -35,19 +43,21 @@ def add(src, other):
elif
torch
.
is_tensor
(
other
):
elif
torch
.
is_tensor
(
other
):
(
row
,
col
),
value
=
src
.
coo
()
(
row
,
col
),
value
=
src
.
coo
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
val
=
other
.
squeeze
(
1
).
repeat_interleave
(
other
=
gather_csr
(
other
.
squeeze
(
1
),
src
.
storage
.
rowptr
)
row
,
0
)
+
(
value
if
src
.
has_value
()
else
1
)
value
=
other
.
add_
(
src
.
storage
.
value
if
src
.
has_value
()
else
1
)
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
return
src
.
set_value
(
value
,
layout
=
'csr'
)
val
=
other
.
squeeze
(
0
)[
col
]
+
(
value
if
src
.
has_value
()
else
1
)
else
:
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
other
=
other
.
squeeze
(
0
)[
col
]
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
value
=
other
.
add_
(
src
.
storage
.
value
if
src
.
has_value
()
else
1
)
f
'
{
other
.
size
()
}
.'
)
return
src
.
set_value
(
value
,
layout
=
'coo'
)
return
src
.
set_value
(
val
,
layout
=
'coo'
)
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
'
{
other
.
size
()
}
.'
)
elif
isinstance
(
other
,
src
.
__class__
):
elif
isinstance
(
other
,
src
.
__class__
):
rowptr
,
index
,
src_offset
,
other_offset
=
union
(
src
,
other
)
raise
NotImplementedError
raise
NotImplementedError
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
...
@@ -55,21 +65,71 @@ def add(src, other):
...
@@ -55,21 +65,71 @@ def add(src, other):
def
add_
(
src
,
other
):
def
add_
(
src
,
other
):
pass
if
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
):
return
add_nnz_
(
src
,
other
)
elif
torch
.
is_tensor
(
other
):
(
row
,
col
),
value
=
src
.
coo
()
if
other
.
size
(
0
)
==
src
.
size
(
0
)
and
other
.
size
(
1
)
==
1
:
# Row-wise...
other
=
gather_csr
(
other
.
squeeze
(
1
),
src
.
storage
.
rowptr
)
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
add_
(
other
)
else
:
value
=
other
.
add_
(
1
)
return
src
.
set_value_
(
value
,
layout
=
'csr'
)
if
other
.
size
(
0
)
==
1
and
other
.
size
(
1
)
==
src
.
size
(
1
):
# Col-wise...
other
=
other
.
squeeze
(
0
)[
col
]
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
add_
(
other
)
else
:
value
=
other
.
add_
(
1
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
raise
ValueError
(
f
'Size mismatch: Expected size (
{
src
.
size
(
0
)
}
, 1,'
f
' ...) or (1,
{
src
.
size
(
1
)
}
, ...), but got size '
f
'
{
other
.
size
()
}
.'
)
elif
isinstance
(
other
,
src
.
__class__
):
raise
NotImplementedError
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float`, '
'`torch.tensor` or `torch_sparse.SparseTensor`.'
)
def
add_nnz
(
src
,
other
,
layout
=
None
):
def
add_nnz
(
src
,
other
,
layout
=
None
):
if
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
):
if
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
):
return
src
.
set_value
(
src
.
storage
.
value
+
if
src
.
has_value
():
other
if
src
.
has_value
()
else
torch
.
full
((
value
=
src
.
storage
.
value
+
other
src
.
nnz
(),
),
1
+
other
,
device
=
src
.
device
))
else
:
elif
torch
.
is_tensor
(
other
):
value
=
torch
.
full
((
src
.
nnz
(),
),
1
+
other
,
device
=
src
.
device
)
return
src
.
set_value
(
src
.
storage
.
value
+
return
src
.
set_value
(
value
,
layout
=
'coo'
)
other
if
src
.
has_value
()
else
other
+
1
)
if
torch
.
is_tensor
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
+
other
else
:
value
=
other
+
1
return
src
.
set_value
(
value
,
layout
=
'coo'
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.'
)
'`torch.tensor`.'
)
def
add_nnz_
(
src
,
other
,
layout
=
None
):
def
add_nnz_
(
src
,
other
,
layout
=
None
):
pass
if
isinstance
(
other
,
int
)
or
isinstance
(
other
,
float
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
add_
(
other
)
else
:
value
=
torch
.
full
((
src
.
nnz
(),
),
1
+
other
,
device
=
src
.
device
)
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
if
torch
.
is_tensor
(
other
):
if
src
.
has_value
():
value
=
src
.
storage
.
value
.
add_
(
other
)
else
:
value
=
other
+
1
# No inplace operation possible.
return
src
.
set_value_
(
value
,
layout
=
'coo'
)
raise
ValueError
(
'Argument `other` needs to be of type `int`, `float` or '
'`torch.tensor`.'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment