Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
bb963e32
Commit
bb963e32
authored
Dec 19, 2019
by
rusty1s
Browse files
narrow implementation
parent
195053ef
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
105 additions
and
70 deletions
+105
-70
torch_sparse/narrow.py
torch_sparse/narrow.py
+27
-19
torch_sparse/storage.py
torch_sparse/storage.py
+9
-6
torch_sparse/tensor.py
torch_sparse/tensor.py
+62
-41
torch_sparse/transpose.py
torch_sparse/transpose.py
+7
-4
No files found.
torch_sparse/narrow.py
View file @
bb963e32
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
def
narrow
(
src
,
dim
,
start
,
length
):
def
narrow
(
src
,
dim
,
start
,
length
):
if
dim
==
0
:
if
dim
==
0
:
col
,
rowptr
,
value
=
src
.
c
sr
()
(
row
,
col
)
,
value
=
src
.
c
oo
()
rowptr
=
rowptr
.
narrow
(
0
,
start
=
start
,
length
=
length
)
rowptr
,
_
,
_
=
src
.
csr
(
)
row_start
,
row_end
=
rowptr
[
0
]
rowptr
=
rowptr
.
narrow
(
0
,
start
=
start
,
length
=
length
+
1
)
row_length
=
rowptr
[
-
1
]
-
row_start
row_start
=
rowptr
[
0
]
rowptr
=
rowptr
-
row_start
row_length
=
rowptr
[
-
1
]
row
=
row
.
narrow
(
0
,
row_start
,
row_length
)
-
start
col
=
col
.
narrow
(
0
,
row_start
,
row_length
)
col
=
col
.
narrow
(
0
,
row_start
,
row_length
)
row
=
self
.
_row
.
narrow
(
0
,
row_start
,
row_length
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
if
src
.
has_value
():
value
=
value
.
narrow
(
0
,
row_start
,
row_length
)
sparse_size
=
torch
.
Size
([
length
,
src
.
sparse_size
(
1
)])
storage
=
src
.
_storage
.
__class__
(
index
,
value
,
sparse_size
,
rowptr
,
is_sorted
=
True
)
elif
dim
==
1
:
# This is faster than accessing `csc()` in analogy to thr `dim=0` case.
(
row
,
col
),
value
=
src
.
coo
()
mask
=
(
col
>=
start
)
&
(
col
<
start
+
length
)
index
=
torch
.
stack
([
row
,
col
-
start
],
dim
=
0
)[:,
mask
]
if
src
.
has_value
():
value
=
value
[
mask
]
sparse_size
=
torch
.
Size
([
src
.
sparse_size
(
0
),
length
])
elif
dim
==
0
:
storage
=
src
.
_storage
.
__class__
(
index
,
value
,
sparse_size
,
is_sorted
=
True
)
else
:
else
:
storage
=
src
.
_storage
.
apply_value
(
lambda
x
:
x
.
narrow
(
dim
-
1
,
start
,
length
))
return
src
.
__class__
.
from_storage
(
storage
)
pass
if
__name__
==
'__main__'
:
device
=
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
row
=
torch
.
tensor
([
0
,
0
,
1
,
1
],
device
=
device
)
col
=
torch
.
tensor
([
1
,
2
,
0
,
2
],
device
=
device
)
sparse_mat
=
SparseTensor
(
torch
.
stack
([
row
,
col
],
dim
=
0
))
print
(
sparse_mat
)
print
(
sparse_mat
.
to_dense
())
torch_sparse/storage.py
View file @
bb963e32
...
@@ -166,6 +166,12 @@ class SparseStorage(object):
...
@@ -166,6 +166,12 @@ class SparseStorage(object):
def
coalesce
(
self
):
def
coalesce
(
self
):
raise
NotImplementedError
raise
NotImplementedError
def
cached_keys
(
self
):
return
[
key
for
key
in
self
.
cache_keys
if
getattr
(
self
,
f
'_
{
key
}
'
,
None
)
is
not
None
]
def
fill_cache_
(
self
,
*
args
):
def
fill_cache_
(
self
,
*
args
):
for
arg
in
args
or
self
.
cache_keys
:
for
arg
in
args
or
self
.
cache_keys
:
getattr
(
self
,
arg
)
getattr
(
self
,
arg
)
...
@@ -206,8 +212,8 @@ class SparseStorage(object):
...
@@ -206,8 +212,8 @@ class SparseStorage(object):
def
apply_
(
self
,
func
):
def
apply_
(
self
,
func
):
self
.
_index
=
func
(
self
.
_index
)
self
.
_index
=
func
(
self
.
_index
)
self
.
_value
=
optional
(
func
,
self
.
_value
)
self
.
_value
=
optional
(
func
,
self
.
_value
)
for
key
in
self
.
cache_keys
:
for
key
in
self
.
cache
d
_keys
()
:
setattr
(
self
,
f
'_
{
key
}
'
,
optional
(
func
,
getattr
(
self
,
f
'_
{
key
}
'
))
)
setattr
(
self
,
f
'_
{
key
}
'
,
func
,
getattr
(
self
,
f
'_
{
key
}
'
))
return
self
return
self
def
apply
(
self
,
func
):
def
apply
(
self
,
func
):
...
@@ -226,10 +232,7 @@ class SparseStorage(object):
...
@@ -226,10 +232,7 @@ class SparseStorage(object):
data
=
[
func
(
self
.
index
)]
data
=
[
func
(
self
.
index
)]
if
self
.
has_value
():
if
self
.
has_value
():
data
+=
[
func
(
self
.
value
)]
data
+=
[
func
(
self
.
value
)]
data
+=
[
data
+=
[
func
(
getattr
(
self
,
f
'_
{
key
}
'
))
for
key
in
self
.
cached_keys
()]
func
(
getattr
(
self
,
f
'_
{
key
}
'
))
for
key
in
self
.
cache_keys
if
getattr
(
self
,
f
'_
{
key
}
'
)
]
return
data
return
data
...
...
torch_sparse/tensor.py
View file @
bb963e32
...
@@ -4,7 +4,9 @@ import torch
...
@@ -4,7 +4,9 @@ import torch
import
scipy.sparse
import
scipy.sparse
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.transpose
import
t
from
torch_sparse.transpose
import
t
from
torch_sparse.narrow
import
narrow
class
SparseTensor
(
object
):
class
SparseTensor
(
object
):
...
@@ -77,8 +79,10 @@ class SparseTensor(object):
...
@@ -77,8 +79,10 @@ class SparseTensor(object):
return
self
.
_storage
.
is_coalesced
()
return
self
.
_storage
.
is_coalesced
()
def
coalesce
(
self
):
def
coalesce
(
self
):
storage
=
self
.
_storage
.
coalesce
()
return
self
.
__class__
.
from_storage
(
self
.
_storage
.
coalesce
())
return
self
.
__class__
.
from_storage
(
storage
)
def
cached_keys
(
self
):
return
self
.
_storage
.
cached_keys
()
def
fill_cache_
(
self
,
*
args
):
def
fill_cache_
(
self
,
*
args
):
self
.
_storage
.
fill_cache_
(
*
args
)
self
.
_storage
.
fill_cache_
(
*
args
)
...
@@ -139,7 +143,6 @@ class SparseTensor(object):
...
@@ -139,7 +143,6 @@ class SparseTensor(object):
def
detach
(
self
):
def
detach
(
self
):
storage
=
self
.
_storage
.
apply
(
lambda
x
:
x
.
detach
())
storage
=
self
.
_storage
.
apply
(
lambda
x
:
x
.
detach
())
print
(
"AWDAwd"
)
return
self
.
__class__
.
from_storage
(
storage
)
return
self
.
__class__
.
from_storage
(
storage
)
def
pin_memory
(
self
):
def
pin_memory
(
self
):
...
@@ -265,27 +268,29 @@ class SparseTensor(object):
...
@@ -265,27 +268,29 @@ class SparseTensor(object):
requires_grad
=
requires_grad
)
requires_grad
=
requires_grad
)
def
to_scipy
(
self
,
dtype
=
None
,
layout
=
'coo'
):
def
to_scipy
(
self
,
dtype
=
None
,
layout
=
'coo'
):
assert
self
.
dim
()
==
2
assert
layout
in
self
.
_storage
.
layouts
assert
layout
in
self
.
_storage
.
layouts
self
=
self
.
detach
().
cpu
()
if
not
self
.
has_value
():
ones
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
).
numpy
()
if
self
.
has_value
():
value
=
self
.
_storage
.
value
.
numpy
()
assert
value
.
ndim
==
1
else
:
value
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
dtype
).
numpy
()
if
layout
==
'coo'
:
if
layout
==
'coo'
:
(
row
,
col
),
_
=
self
.
coo
()
(
row
,
col
),
value
=
self
.
coo
()
row
,
col
=
row
.
numpy
(),
col
.
numpy
()
row
=
row
.
detach
().
cpu
().
numpy
()
col
=
col
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
coo_matrix
((
value
,
(
row
,
col
)),
self
.
size
())
return
scipy
.
sparse
.
coo_matrix
((
value
,
(
row
,
col
)),
self
.
size
())
elif
layout
==
'csr'
:
elif
layout
==
'csr'
:
rowptr
,
col
,
_
=
self
.
csr
()
rowptr
,
col
,
value
=
self
.
csr
()
rowptr
,
col
=
rowptr
.
numpy
(),
col
.
numpy
()
rowptr
=
rowptr
.
detach
().
cpu
().
numpy
()
col
=
col
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
csr_matrix
((
value
,
col
,
rowptr
),
self
.
size
())
return
scipy
.
sparse
.
csr_matrix
((
value
,
col
,
rowptr
),
self
.
size
())
elif
layout
==
'csc'
:
elif
layout
==
'csc'
:
colptr
,
row
,
_
=
self
.
csc
()
colptr
,
row
,
value
=
self
.
csc
()
colptr
,
row
=
colptr
.
numpy
(),
row
.
numpy
()
colptr
=
colptr
.
detach
().
cpu
().
numpy
()
row
=
row
.
detach
().
cpu
().
numpy
()
value
=
value
.
detach
().
cpu
().
numpy
()
if
self
.
has_value
()
else
ones
return
scipy
.
sparse
.
csc_matrix
((
value
,
row
,
colptr
),
self
.
size
())
return
scipy
.
sparse
.
csc_matrix
((
value
,
row
,
colptr
),
self
.
size
())
# String Reputation #######################################################
# String Reputation #######################################################
...
@@ -312,6 +317,7 @@ class SparseTensor(object):
...
@@ -312,6 +317,7 @@ class SparseTensor(object):
# Bindings ####################################################################
# Bindings ####################################################################
SparseTensor
.
t
=
t
SparseTensor
.
t
=
t
SparseTensor
.
narrow
=
narrow
# def set_diag(self, value):
# def set_diag(self, value):
# raise NotImplementedError
# raise NotImplementedError
...
@@ -434,33 +440,48 @@ if __name__ == '__main__':
...
@@ -434,33 +440,48 @@ if __name__ == '__main__':
dataset
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)
dataset
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)
data
=
dataset
[
0
].
to
(
device
)
data
=
dataset
[
0
].
to
(
device
)
value
=
torch
.
ones
((
data
.
num_edges
,
),
device
=
device
)
value
=
torch
.
randn
((
data
.
num_edges
,
),
device
=
device
)
mat1
=
SparseTensor
(
data
.
edge_index
,
value
)
mat1
=
SparseTensor
(
data
.
edge_index
,
value
)
print
(
mat1
)
# print(mat1)
print
(
id
(
mat1
))
mat1
=
mat1
.
long
()
# # print(mat1.to_dense().size())
print
(
id
(
mat1
))
# print(mat1.to_torch_sparse_coo_tensor().to_dense().size())
mat1
=
mat1
.
long
()
# print(mat1.to_scipy(layout='coo').todense().shape)
print
(
id
(
mat1
))
# print(mat1.to_scipy(layout='csr').todense().shape)
mat1
=
mat1
.
to
(
torch
.
bool
)
# print(mat1.to_scipy(layout='csc').todense().shape)
print
(
mat1
)
print
(
mat1
.
is_pinned
())
# print(mat1.is_quadratic())
# print(mat1.is_symmetric())
print
(
mat1
.
to_dense
().
size
())
# print(mat1.cached_keys())
mat2
=
mat1
.
to_torch_sparse_coo_tensor
()
# mat1 = mat1.t()
print
(
mat2
)
# print(mat1.cached_keys())
# mat1 = mat1.t()
print
(
mat1
.
to_scipy
(
layout
=
'coo'
).
todense
().
shape
)
# print(mat1.cached_keys())
print
(
mat1
.
to_scipy
(
layout
=
'csr'
).
todense
().
shape
)
# print('-------- NARROW ----------')
print
(
mat1
.
to_scipy
(
layout
=
'csc'
).
todense
().
shape
)
t
=
time
.
perf_counter
()
print
(
mat1
.
is_quadratic
())
for
_
in
range
(
100
):
print
(
mat1
.
is_symmetric
())
out
=
mat1
.
narrow
(
dim
=
0
,
start
=
10
,
length
=
10
)
# out._storage.colptr
mat1
=
mat1
.
t
()
print
(
time
.
perf_counter
()
-
t
)
print
(
mat1
)
print
(
out
)
print
(
out
.
cached_keys
())
t
=
time
.
perf_counter
()
for
_
in
range
(
100
):
out
=
mat1
.
narrow
(
dim
=
1
,
start
=
10
,
length
=
2000
)
# out._storage.colptr
print
(
time
.
perf_counter
()
-
t
)
print
(
out
)
print
(
out
.
cached_keys
())
# mat1 = mat1.narrow(0, start=10, length=10)
# mat1._storage._value = torch.randn(mat1.nnz(), 20)
# print(mat1.coo()[1].size())
# mat1 = mat1.narrow(2, start=10, length=10)
# print(mat1.coo()[1].size())
# mat1 = mat1.t()
# mat1 = mat1.t()
# mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
# mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
...
...
torch_sparse/transpose.py
View file @
bb963e32
...
@@ -29,14 +29,17 @@ def transpose(index, value, m, n, coalesced=True):
...
@@ -29,14 +29,17 @@ def transpose(index, value, m, n, coalesced=True):
def
t
(
mat
):
def
t
(
mat
):
((
row
,
col
),
value
),
perm
=
mat
.
coo
(),
mat
.
_storage
.
csr_to_csc
(
row
,
col
),
value
=
mat
.
coo
()
csr_to_csc
=
mat
.
_storage
.
csr_to_csc
storage
=
mat
.
_storage
.
__class__
(
storage
=
mat
.
_storage
.
__class__
(
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)[:,
perm
],
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)[:,
csr_to_csc
],
value
=
value
[
perm
]
if
mat
.
has_value
()
else
None
,
value
=
value
[
csr_to_csc
]
if
mat
.
has_value
()
else
None
,
sparse_size
=
mat
.
sparse_size
()[::
-
1
],
sparse_size
=
mat
.
sparse_size
()[::
-
1
],
rowptr
=
mat
.
_storage
.
_colptr
,
rowptr
=
mat
.
_storage
.
_colptr
,
colptr
=
mat
.
_storage
.
_rowptr
,
colptr
=
mat
.
_storage
.
_rowptr
,
csr_to_csc
=
mat
.
_storage
.
_csc_to_csr
,
csr_to_csc
=
mat
.
_storage
.
_csc_to_csr
,
csc_to_csr
=
perm
,
csc_to_csr
=
csr_to_csc
,
is_sorted
=
True
)
is_sorted
=
True
)
return
mat
.
__class__
.
from_storage
(
storage
)
return
mat
.
__class__
.
from_storage
(
storage
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment