Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
0e541cc9
Unverified
Commit
0e541cc9
authored
Apr 12, 2020
by
bwdeng20
Committed by
GitHub
Apr 12, 2020
Browse files
Merge branch 'master' into master
parents
0090f4ed
056c0bab
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
129 additions
and
50 deletions
+129
-50
torch_sparse/saint.py
torch_sparse/saint.py
+25
-0
torch_sparse/tensor.py
torch_sparse/tensor.py
+104
-50
No files found.
torch_sparse/saint.py
0 → 100644
View file @
0e541cc9
from
typing
import
Tuple
import
torch
from
torch_sparse.tensor
import
SparseTensor
def
saint_subgraph
(
src
:
SparseTensor
,
node_idx
:
torch
.
Tensor
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
row
,
col
,
value
=
src
.
coo
()
rowptr
=
src
.
storage
.
rowptr
()
data
=
torch
.
ops
.
torch_sparse
.
saint_subgraph
(
node_idx
,
rowptr
,
row
,
col
)
row
,
col
,
edge_index
=
data
if
value
is
not
None
:
value
=
value
[
edge_index
]
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
node_idx
.
size
(
0
),
node_idx
.
size
(
0
)),
is_sorted
=
True
)
return
out
,
edge_index
SparseTensor
.
saint_subgraph
=
saint_subgraph
torch_sparse/tensor.py
View file @
0e541cc9
...
...
@@ -12,17 +12,25 @@ from torch_sparse.utils import is_scalar
class
SparseTensor
(
object
):
storage
:
SparseStorage
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
is_sorted
:
bool
=
False
):
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
is_sorted
)
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
is_sorted
)
@
classmethod
def
from_storage
(
self
,
storage
:
SparseStorage
):
...
...
@@ -45,12 +53,17 @@ class SparseTensor(object):
if
has_value
:
value
=
mat
[
row
,
col
]
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
@
classmethod
def
from_torch_sparse_coo_tensor
(
self
,
mat
:
torch
.
Tensor
,
def
from_torch_sparse_coo_tensor
(
self
,
mat
:
torch
.
Tensor
,
has_value
:
bool
=
True
):
mat
=
mat
.
coalesce
()
index
=
mat
.
_indices
()
...
...
@@ -60,13 +73,20 @@ class SparseTensor(object):
if
has_value
:
value
=
mat
.
_values
()
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
@
classmethod
def
eye
(
self
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
has_value
:
bool
=
True
,
def
eye
(
self
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
has_value
:
bool
=
True
,
fill_cache
:
bool
=
False
):
N
=
M
if
N
is
None
else
N
...
...
@@ -84,8 +104,8 @@ class SparseTensor(object):
value
:
Optional
[
torch
.
Tensor
]
=
None
if
has_value
:
if
options
is
not
None
:
value
=
torch
.
ones
(
row
.
numel
(),
dtype
=
options
.
dtype
,
device
=
row
.
device
)
value
=
torch
.
ones
(
row
.
numel
(),
dtype
=
options
.
dtype
,
device
=
row
.
device
)
else
:
value
=
torch
.
ones
(
row
.
numel
(),
device
=
row
.
device
)
...
...
@@ -108,9 +128,17 @@ class SparseTensor(object):
csr2csc
=
csc2csr
=
row
storage
:
SparseStorage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
M
,
N
),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
M
,
N
),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
self
=
SparseTensor
.
__new__
(
SparseTensor
)
self
.
storage
=
storage
...
...
@@ -153,12 +181,14 @@ class SparseTensor(object):
def
has_value
(
self
)
->
bool
:
return
self
.
storage
.
has_value
()
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
self
.
storage
.
set_value_
(
value
,
layout
)
return
self
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
))
...
...
@@ -187,23 +217,31 @@ class SparseTensor(object):
# Utility functions #######################################################
def
fill_value_
(
self
,
fill_value
:
float
,
def
fill_value_
(
self
,
fill_value
:
float
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
options
is
not
None
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
device
=
self
.
device
())
return
self
.
set_value_
(
value
,
layout
=
'coo'
)
def
fill_value
(
self
,
fill_value
:
float
,
def
fill_value
(
self
,
fill_value
:
float
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
options
is
not
None
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
device
=
self
.
device
())
return
self
.
set_value
(
value
,
layout
=
'coo'
)
...
...
@@ -270,8 +308,13 @@ class SparseTensor(object):
N
=
max
(
self
.
size
(
0
),
self
.
size
(
1
))
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
N
,
N
),
is_sorted
=
False
)
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
N
,
N
),
is_sorted
=
False
)
out
=
out
.
coalesce
(
reduce
)
return
out
...
...
@@ -294,7 +337,8 @@ class SparseTensor(object):
else
:
return
False
def
requires_grad_
(
self
,
requires_grad
:
bool
=
True
,
def
requires_grad_
(
self
,
requires_grad
:
bool
=
True
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
requires_grad
and
not
self
.
has_value
():
self
.
fill_value_
(
1.
,
options
=
options
)
...
...
@@ -315,8 +359,8 @@ class SparseTensor(object):
if
value
is
not
None
:
return
value
else
:
return
torch
.
tensor
(
0.
,
dtype
=
torch
.
float
,
device
=
self
.
storage
.
col
().
device
)
return
torch
.
tensor
(
0.
,
dtype
=
torch
.
float
,
device
=
self
.
storage
.
col
().
device
)
def
device
(
self
):
return
self
.
storage
.
col
().
device
...
...
@@ -324,7 +368,8 @@ class SparseTensor(object):
def
cpu
(
self
):
return
self
.
device_as
(
torch
.
tensor
(
0.
),
non_blocking
=
False
)
def
cuda
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
def
cuda
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
non_blocking
:
bool
=
False
):
if
options
is
not
None
:
return
self
.
device_as
(
options
,
non_blocking
)
...
...
@@ -387,19 +432,19 @@ class SparseTensor(object):
row
,
col
,
value
=
self
.
coo
()
if
value
is
not
None
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
value
.
dtype
,
device
=
self
.
device
())
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
value
.
dtype
,
device
=
self
.
device
())
elif
options
is
not
None
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
device
=
self
.
device
())
if
value
is
not
None
:
mat
[
row
,
col
]
=
value
else
:
mat
[
row
,
col
]
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
mat
.
dtype
,
device
=
mat
.
device
)
mat
[
row
,
col
]
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
mat
.
dtype
,
device
=
mat
.
device
)
return
mat
...
...
@@ -409,8 +454,8 @@ class SparseTensor(object):
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
if
value
is
None
:
if
options
is
not
None
:
value
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
value
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
value
=
torch
.
ones
(
self
.
nnz
(),
device
=
self
.
device
())
...
...
@@ -434,7 +479,7 @@ def is_shared(self: SparseTensor) -> bool:
def
to
(
self
,
*
args
:
Optional
[
List
[
Any
]],
**
kwargs
:
Optional
[
Dict
[
str
,
Any
]])
->
SparseTensor
:
device
,
dtype
,
non_blocking
,
_
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
device
,
dtype
,
non_blocking
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
if
dtype
is
not
None
:
self
=
self
.
type_as
(
torch
.
tensor
(
0.
,
dtype
=
dtype
))
...
...
@@ -515,8 +560,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ###########################################################
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
csr_matrix
,
scipy
.
sparse
.
csc_matrix
]
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
csr_matrix
,
scipy
.
sparse
.
csc_matrix
]
@
torch
.
jit
.
ignore
...
...
@@ -535,16 +580,25 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
value
=
torch
.
from_numpy
(
mat
.
data
)
sparse_sizes
=
mat
.
shape
[:
2
]
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
colptr
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
colptr
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
SparseTensor
.
from_storage
(
storage
)
@
torch
.
jit
.
ignore
def
to_scipy
(
self
:
SparseTensor
,
layout
:
Optional
[
str
]
=
None
,
def
to_scipy
(
self
:
SparseTensor
,
layout
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
ScipySparseMatrix
:
assert
self
.
dim
()
==
2
layout
=
get_layout
(
layout
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment