OpenDAS / torch-sparse
"vscode:/vscode.git/clone" did not exist on "9a7ae77a4eda5b4f819fd22ce9b713fb79993201"
Commit a25d8256
authored Dec 17, 2019 by rusty1s

    more functionality

Parent: 36d045fd
Showing 1 changed file with 150 additions and 20 deletions.
torch_sparse/sparse.py (+150, −20)
+import warnings
 import inspect
 from textwrap import indent

 import torch

 from torch_sparse.storage import SparseStorage
@@ -8,11 +10,15 @@ methods = list(zip(*inspect.getmembers(SparseStorage)))[0]
 methods = [name for name in methods if '__' not in name and name != 'clone']


+def __is_scalar__(x):
+    return isinstance(x, int) or isinstance(x, float)
+
+
 class SparseTensor(object):
     def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
         assert index.dim() == 2 and index.size(0) == 2
         self._storage = SparseStorage(index[0], index[1], value, sparse_size,
                                       is_sorted=is_sorted)

     @classmethod
     def from_storage(self, storage):
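As a usage note (not part of the commit): a minimal sketch of how the constructor in this hunk is meant to be called, assuming `SparseTensor` is imported directly from the `torch_sparse/sparse.py` module touched here; the tensors below are purely illustrative.

    import torch
    from torch_sparse.sparse import SparseTensor  # module changed by this commit

    # A 2 x nnz COO index (row 0: row indices, row 1: column indices) and an
    # optional per-entry value, matching the __init__ signature above.
    index = torch.tensor([[0, 1, 1, 2],
                          [1, 0, 2, 1]])
    value = torch.ones(index.size(1))

    # sparse_size and is_sorted are optional; is_sorted=False lets the
    # underlying SparseStorage sort the entries itself.
    mat = SparseTensor(index, value, sparse_size=(3, 3), is_sorted=False)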
@@ -78,16 +84,29 @@ class SparseTensor(object):
         value_symmetric = (value1 == value2).all() if self.has_value else True
         return index_symmetric and value_symmetric

-    def set_value(self, value, layout):
+    def set_value(self, value, layout=None):
+        if layout is None:
+            layout = 'coo'
+            warnings.warn('`layout` argument unset, using default layout '
+                          '"coo". This may lead to unexpected behaviour.')
+        assert layout in ['coo', 'csr', 'csc']
         if value is not None and layout == 'csc':
             value = value[self._arg_csc_to_csr]
         return self._apply_value(value)

-    def set_value_(self, value, layout):
+    def set_value_(self, value, layout=None):
+        if layout is None:
+            layout = 'coo'
+            warnings.warn('`layout` argument unset, using default layout '
+                          '"coo". This may lead to unexpected behaviour.')
+        assert layout in ['coo', 'csr', 'csc']
         if value is not None and layout == 'csc':
             value = value[self._arg_csc_to_csr]
         return self._apply_value_(value)

+    def set_diag(self, value):
+        raise NotImplementedError
+
     def t(self):
         storage = SparseStorage(
             self._col[self._arg_csr_to_csc],
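A sketch of the new default-layout behaviour added to `set_value`/`set_value_` in this hunk: when `layout` is omitted, it now falls back to `'coo'` and emits a warning. The snippet reuses the `mat` from the constructor sketch above and assumes `nnz()` and the internal `_apply_value` helper behave as the surrounding code implies.

    import warnings
    import torch

    new_value = torch.arange(mat.nnz(), dtype=torch.float)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        mat = mat.set_value(new_value)  # layout unset -> defaults to 'coo', warns
    assert any('layout' in str(w.message) for w in caught)

    # Passing the layout explicitly avoids the warning:
    mat = mat.set_value(new_value, layout='coo')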
@@ -102,22 +121,119 @@ class SparseTensor(object):
         )
         return self.__class__.from_storage(storage)

-    def matmul(self, mat2):
-        raise NotImplementedError
-
     def coalesce(self, reduce='add'):
         raise NotImplementedError

     def is_coalesced(self):
         raise NotImplementedError

-    def add(self, layout=None):  # sub, mul, div
-        # can take scalars, tensors and other sparse matrices
-        # inplace variants can only take scalars or tensors
-        raise NotImplementedError
+    def masked_select(self, mask):
+        raise NotImplementedError
+
+    def index_select(self, index):
+        raise NotImplementedError
+
+    def select(self, dim, index):
+        raise NotImplementedError
+
+    def filter(self, index):
+        assert self.is_symmetric
+        assert index.dtype == torch.long or index.dtype == torch.bool
+        raise NotImplementedError
+
+    def permute(self, index):
+        assert index.dtype == torch.long
+        return self.filter(index)
+
+    def __getitem__(self, idx):
+        # Convert int and slice to index tensor
+        # Filter list into edge and sparse slice
+        raise NotImplementedError
+
+    def __reduce(self, dim, reduce, only_nnz):
+        raise NotImplementedError
+
+    def sum(self, dim):
+        return self.__reduce(dim, reduce='add', only_nnz=True)
+
+    def prod(self, dim):
+        return self.__reduce(dim, reduce='mul', only_nnz=True)
+
+    def min(self, dim, only_nnz=False):
+        return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
+
+    def max(self, dim, only_nnz=False):
+        return self.__reduce(dim, reduce='min', only_nnz=only_nnz)
+
+    def mean(self, dim, only_nnz=False):
+        return self.__reduce(dim, reduce='mean', only_nnz=only_nnz)
+
+    def matmul(self, mat, reduce='add'):
+        assert self.numel() == self.nnz()  # Disallow multi-dimensional value
+        if torch.is_tensor(mat):
+            raise NotImplementedError
+        elif isinstance(mat, self.__class__):
+            assert reduce == 'add'
+            assert mat.numel() == mat.nnz()  # Disallow multi-dimensional value
+            raise NotImplementedError
+        raise ValueError('Argument needs to be of type `torch.tensor` or '
+                         'type `torch_sparse.SparseTensor`.')
+
+    def add(self, other, layout=None):
+        if __is_scalar__(other):
+            if self.has_value:
+                return self.set_value(self._value + other, 'coo')
+            else:
+                return self.set_value(
+                    torch.full((self.nnz(), ), other + 1), 'coo')
+        elif torch.is_tensor(other):
+            if layout is None:
+                layout = 'coo'
+                warnings.warn('`layout` argument unset, using default layout '
+                              '"coo". This may lead to unexpected behaviour.')
+            assert layout in ['coo', 'csr', 'csc']
+            if layout == 'csc':
+                other = other[self._arg_csc_to_csr]
+            if self.has_value:
+                return self.set_value(self._value + other, 'coo')
+            else:
+                return self.set_value(other + 1, 'coo')
+        elif isinstance(other, self.__class__):
+            raise NotImplementedError
+        raise ValueError('Argument needs to be of type `int`, `float`, '
+                         '`torch.tensor` or `torch_sparse.SparseTensor`.')

-    # TODO: Slicing, (sum|max|min|prod|...), standard operators, masing, perm
+    def add_(self, other, layout=None):
+        if isinstance(other, int) or isinstance(other, float):
+            raise NotImplementedError
+        elif torch.is_tensor(other):
+            raise NotImplementedError
+        raise ValueError('Argument needs to be a scalar or of type '
+                         '`torch.tensor`.')
+
+    def __add__(self, other):
+        return self.add(other)
+
+    def __radd__(self, other):
+        return self.add(other)
+
+    def sub(self, layout=None):
+        raise NotImplementedError
+
+    def sub_(self, layout=None):
+        raise NotImplementedError
+
+    def mul(self, layout=None):
+        raise NotImplementedError
+
+    def mul_(self, layout=None):
+        raise NotImplementedError
+
+    def div(self, layout=None):
+        raise NotImplementedError
+
+    def div_(self, layout=None):
+        raise NotImplementedError

     def to_dense(self, dtype=None):
         dtype = dtype or self.dtype
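Of the methods introduced in this hunk, only scalar/tensor addition is actually wired up; the reductions, `matmul`, and sparse-sparse arithmetic still raise `NotImplementedError` at this commit. A small sketch of the working path, continuing the `mat` from the sketches above:

    # __add__ and __radd__ both dispatch to add(), so scalar addition works
    # from either side and returns a new SparseTensor with an updated value.
    out = mat + 2      # mat.add(2)
    out = 2 + mat      # __radd__ -> mat.add(2), as exercised in __main__ below

    # Element-wise addition with a dense per-entry tensor goes through the
    # same layout handling as set_value(); 'coo' ordering is assumed here.
    out = mat.add(torch.ones(mat.nnz()), layout='coo')

    # Not yet implemented at this commit: mat.sum(0), mat.matmul(...),
    # mat + other_sparse_tensor -- these raise NotImplementedError.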
@@ -125,11 +241,17 @@ class SparseTensor(object):
         mat[self._row, self._col] = self._value if self.has_value else 1
         return mat

-    def to_scipy(self):
+    def to_scipy(self, layout):
         raise NotImplementedError

-    def to_torch_sparse_coo_tensor(self):
-        raise NotImplementedError
+    def to_torch_sparse_coo_tensor(self, dtype=None, requires_grad=False):
+        index, value = self.coo()
+        return torch.sparse_coo_tensor(
+            index,
+            torch.ones_like(self._row, dtype) if value is None else value,
+            self.size(),
+            device=self.device,
+            requires_grad=requires_grad)

     def __repr__(self):
         i = ' ' * 6
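The new `to_torch_sparse_coo_tensor` body above assembles a `torch.sparse_coo_tensor` from the stored COO data. A hedged round-trip sketch, assuming the `coo()`, `size()` and `device` accessors referenced in the diff are available on the tensor from the earlier sketches:

    coo = mat.to_torch_sparse_coo_tensor()
    assert torch.allclose(coo.to_dense(), mat.to_dense())

Note that the no-value fallback passes `dtype` positionally to `torch.ones_like`, which not every PyTorch version accepts as a positional argument, so the sketch sticks to a tensor that carries values.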
@@ -156,7 +278,8 @@ if __name__ == '__main__':
     device = 'cpu'

     # dataset = Reddit('/tmp/Reddit')
-    dataset = Planetoid('/tmp/PubMed', 'PubMed')
+    dataset = Planetoid('/tmp/Cora', 'Cora')
+    # dataset = Planetoid('/tmp/PubMed', 'PubMed')
     data = dataset[0].to(device)

     _bytes = data.edge_index.numel() * 8
@@ -169,8 +292,8 @@ if __name__ == '__main__':
     print(mat1)
     mat1 = mat1.t()

-    mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
-                                   device=device)
+    mat2 = torch.sparse_coo_tensor(
+        data.edge_index, torch.ones(data.num_edges), device=device)
     mat2 = mat2.coalesce()
     mat2 = mat2.t().coalesce()
@@ -182,5 +305,12 @@ if __name__ == '__main__':
     out2 = mat2.to_dense()

     assert torch.allclose(out1, out2)
+
+    mat1 = SparseTensor.from_dense(out1)
+    out = 2 + mat1
+    print(mat1)
+    print(out)
+
+    # mat1[1]
+    # mat1[1, 1]
+    # mat1[..., -1]
+    # mat1[:, -1]
+    # mat1[1:4, 1:4]
+    # mat1[torch.tensor([0, 1, 2])]