Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
0e541cc9
Unverified
Commit
0e541cc9
authored
Apr 12, 2020
by
bwdeng20
Committed by
GitHub
Apr 12, 2020
Browse files
Merge branch 'master' into master
parents
0090f4ed
056c0bab
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
129 additions
and
50 deletions
+129
-50
torch_sparse/saint.py
torch_sparse/saint.py
+25
-0
torch_sparse/tensor.py
torch_sparse/tensor.py
+104
-50
No files found.
torch_sparse/saint.py
0 → 100644
View file @
0e541cc9
from
typing
import
Tuple
import
torch
from
torch_sparse.tensor
import
SparseTensor
def
saint_subgraph
(
src
:
SparseTensor
,
node_idx
:
torch
.
Tensor
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
row
,
col
,
value
=
src
.
coo
()
rowptr
=
src
.
storage
.
rowptr
()
data
=
torch
.
ops
.
torch_sparse
.
saint_subgraph
(
node_idx
,
rowptr
,
row
,
col
)
row
,
col
,
edge_index
=
data
if
value
is
not
None
:
value
=
value
[
edge_index
]
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
node_idx
.
size
(
0
),
node_idx
.
size
(
0
)),
is_sorted
=
True
)
return
out
,
edge_index
SparseTensor
.
saint_subgraph
=
saint_subgraph
torch_sparse/tensor.py
View file @
0e541cc9
...
@@ -12,17 +12,25 @@ from torch_sparse.utils import is_scalar
...
@@ -12,17 +12,25 @@ from torch_sparse.utils import is_scalar
class
SparseTensor
(
object
):
class
SparseTensor
(
object
):
storage
:
SparseStorage
storage
:
SparseStorage
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
is_sorted
:
bool
=
False
):
is_sorted
:
bool
=
False
):
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
self
.
storage
=
SparseStorage
(
value
=
value
,
sparse_sizes
=
sparse_sizes
,
row
=
row
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
rowptr
=
rowptr
,
csr2csc
=
None
,
csc2csr
=
None
,
col
=
col
,
is_sorted
=
is_sorted
)
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
is_sorted
)
@
classmethod
@
classmethod
def
from_storage
(
self
,
storage
:
SparseStorage
):
def
from_storage
(
self
,
storage
:
SparseStorage
):
...
@@ -45,12 +53,17 @@ class SparseTensor(object):
...
@@ -45,12 +53,17 @@ class SparseTensor(object):
if
has_value
:
if
has_value
:
value
=
mat
[
row
,
col
]
value
=
mat
[
row
,
col
]
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
return
SparseTensor
(
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
row
=
row
,
is_sorted
=
True
)
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
@
classmethod
@
classmethod
def
from_torch_sparse_coo_tensor
(
self
,
mat
:
torch
.
Tensor
,
def
from_torch_sparse_coo_tensor
(
self
,
mat
:
torch
.
Tensor
,
has_value
:
bool
=
True
):
has_value
:
bool
=
True
):
mat
=
mat
.
coalesce
()
mat
=
mat
.
coalesce
()
index
=
mat
.
_indices
()
index
=
mat
.
_indices
()
...
@@ -60,13 +73,20 @@ class SparseTensor(object):
...
@@ -60,13 +73,20 @@ class SparseTensor(object):
if
has_value
:
if
has_value
:
value
=
mat
.
_values
()
value
=
mat
.
_values
()
return
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
return
SparseTensor
(
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
row
=
row
,
is_sorted
=
True
)
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
mat
.
size
(
0
),
mat
.
size
(
1
)),
is_sorted
=
True
)
@
classmethod
@
classmethod
def
eye
(
self
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
def
eye
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
has_value
:
bool
=
True
,
M
:
int
,
N
:
Optional
[
int
]
=
None
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
has_value
:
bool
=
True
,
fill_cache
:
bool
=
False
):
fill_cache
:
bool
=
False
):
N
=
M
if
N
is
None
else
N
N
=
M
if
N
is
None
else
N
...
@@ -84,8 +104,8 @@ class SparseTensor(object):
...
@@ -84,8 +104,8 @@ class SparseTensor(object):
value
:
Optional
[
torch
.
Tensor
]
=
None
value
:
Optional
[
torch
.
Tensor
]
=
None
if
has_value
:
if
has_value
:
if
options
is
not
None
:
if
options
is
not
None
:
value
=
torch
.
ones
(
row
.
numel
(),
dtype
=
options
.
dtype
,
value
=
torch
.
ones
(
device
=
row
.
device
)
row
.
numel
(),
dtype
=
options
.
dtype
,
device
=
row
.
device
)
else
:
else
:
value
=
torch
.
ones
(
row
.
numel
(),
device
=
row
.
device
)
value
=
torch
.
ones
(
row
.
numel
(),
device
=
row
.
device
)
...
@@ -108,9 +128,17 @@ class SparseTensor(object):
...
@@ -108,9 +128,17 @@ class SparseTensor(object):
csr2csc
=
csc2csr
=
row
csr2csc
=
csc2csr
=
row
storage
:
SparseStorage
=
SparseStorage
(
storage
:
SparseStorage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
M
,
N
),
row
=
row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
rowptr
=
rowptr
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
col
=
col
,
value
=
value
,
sparse_sizes
=
(
M
,
N
),
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
)
self
=
SparseTensor
.
__new__
(
SparseTensor
)
self
=
SparseTensor
.
__new__
(
SparseTensor
)
self
.
storage
=
storage
self
.
storage
=
storage
...
@@ -153,12 +181,14 @@ class SparseTensor(object):
...
@@ -153,12 +181,14 @@ class SparseTensor(object):
def
has_value
(
self
)
->
bool
:
def
has_value
(
self
)
->
bool
:
return
self
.
storage
.
has_value
()
return
self
.
storage
.
has_value
()
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value_
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
layout
:
Optional
[
str
]
=
None
):
self
.
storage
.
set_value_
(
value
,
layout
)
self
.
storage
.
set_value_
(
value
,
layout
)
return
self
return
self
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
def
set_value
(
self
,
value
:
Optional
[
torch
.
Tensor
],
layout
:
Optional
[
str
]
=
None
):
layout
:
Optional
[
str
]
=
None
):
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
))
return
self
.
from_storage
(
self
.
storage
.
set_value
(
value
,
layout
))
...
@@ -187,23 +217,31 @@ class SparseTensor(object):
...
@@ -187,23 +217,31 @@ class SparseTensor(object):
# Utility functions #######################################################
# Utility functions #######################################################
def
fill_value_
(
self
,
fill_value
:
float
,
def
fill_value_
(
self
,
fill_value
:
float
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
options
is
not
None
:
if
options
is
not
None
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
device
=
self
.
device
())
device
=
self
.
device
())
else
:
else
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
device
=
self
.
device
())
device
=
self
.
device
())
return
self
.
set_value_
(
value
,
layout
=
'coo'
)
return
self
.
set_value_
(
value
,
layout
=
'coo'
)
def
fill_value
(
self
,
fill_value
:
float
,
def
fill_value
(
self
,
fill_value
:
float
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
options
is
not
None
:
if
options
is
not
None
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
dtype
=
options
.
dtype
,
device
=
self
.
device
())
device
=
self
.
device
())
else
:
else
:
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
value
=
torch
.
full
((
self
.
nnz
(),
),
fill_value
,
device
=
self
.
device
())
device
=
self
.
device
())
return
self
.
set_value
(
value
,
layout
=
'coo'
)
return
self
.
set_value
(
value
,
layout
=
'coo'
)
...
@@ -270,8 +308,13 @@ class SparseTensor(object):
...
@@ -270,8 +308,13 @@ class SparseTensor(object):
N
=
max
(
self
.
size
(
0
),
self
.
size
(
1
))
N
=
max
(
self
.
size
(
0
),
self
.
size
(
1
))
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
out
=
SparseTensor
(
sparse_sizes
=
(
N
,
N
),
is_sorted
=
False
)
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
(
N
,
N
),
is_sorted
=
False
)
out
=
out
.
coalesce
(
reduce
)
out
=
out
.
coalesce
(
reduce
)
return
out
return
out
...
@@ -294,7 +337,8 @@ class SparseTensor(object):
...
@@ -294,7 +337,8 @@ class SparseTensor(object):
else
:
else
:
return
False
return
False
def
requires_grad_
(
self
,
requires_grad
:
bool
=
True
,
def
requires_grad_
(
self
,
requires_grad
:
bool
=
True
,
options
:
Optional
[
torch
.
Tensor
]
=
None
):
options
:
Optional
[
torch
.
Tensor
]
=
None
):
if
requires_grad
and
not
self
.
has_value
():
if
requires_grad
and
not
self
.
has_value
():
self
.
fill_value_
(
1.
,
options
=
options
)
self
.
fill_value_
(
1.
,
options
=
options
)
...
@@ -315,8 +359,8 @@ class SparseTensor(object):
...
@@ -315,8 +359,8 @@ class SparseTensor(object):
if
value
is
not
None
:
if
value
is
not
None
:
return
value
return
value
else
:
else
:
return
torch
.
tensor
(
0.
,
dtype
=
torch
.
float
,
return
torch
.
tensor
(
device
=
self
.
storage
.
col
().
device
)
0.
,
dtype
=
torch
.
float
,
device
=
self
.
storage
.
col
().
device
)
def
device
(
self
):
def
device
(
self
):
return
self
.
storage
.
col
().
device
return
self
.
storage
.
col
().
device
...
@@ -324,7 +368,8 @@ class SparseTensor(object):
...
@@ -324,7 +368,8 @@ class SparseTensor(object):
def
cpu
(
self
):
def
cpu
(
self
):
return
self
.
device_as
(
torch
.
tensor
(
0.
),
non_blocking
=
False
)
return
self
.
device_as
(
torch
.
tensor
(
0.
),
non_blocking
=
False
)
def
cuda
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
def
cuda
(
self
,
options
:
Optional
[
torch
.
Tensor
]
=
None
,
non_blocking
:
bool
=
False
):
non_blocking
:
bool
=
False
):
if
options
is
not
None
:
if
options
is
not
None
:
return
self
.
device_as
(
options
,
non_blocking
)
return
self
.
device_as
(
options
,
non_blocking
)
...
@@ -387,19 +432,19 @@ class SparseTensor(object):
...
@@ -387,19 +432,19 @@ class SparseTensor(object):
row
,
col
,
value
=
self
.
coo
()
row
,
col
,
value
=
self
.
coo
()
if
value
is
not
None
:
if
value
is
not
None
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
value
.
dtype
,
mat
=
torch
.
zeros
(
device
=
self
.
device
())
self
.
sizes
(),
dtype
=
value
.
dtype
,
device
=
self
.
device
())
elif
options
is
not
None
:
elif
options
is
not
None
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
dtype
=
options
.
dtype
,
mat
=
torch
.
zeros
(
device
=
self
.
device
())
self
.
sizes
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
else
:
mat
=
torch
.
zeros
(
self
.
sizes
(),
device
=
self
.
device
())
mat
=
torch
.
zeros
(
self
.
sizes
(),
device
=
self
.
device
())
if
value
is
not
None
:
if
value
is
not
None
:
mat
[
row
,
col
]
=
value
mat
[
row
,
col
]
=
value
else
:
else
:
mat
[
row
,
col
]
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
mat
.
dtype
,
mat
[
row
,
col
]
=
torch
.
ones
(
device
=
mat
.
device
)
self
.
nnz
(),
dtype
=
mat
.
dtype
,
device
=
mat
.
device
)
return
mat
return
mat
...
@@ -409,8 +454,8 @@ class SparseTensor(object):
...
@@ -409,8 +454,8 @@ class SparseTensor(object):
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
if
value
is
None
:
if
value
is
None
:
if
options
is
not
None
:
if
options
is
not
None
:
value
=
torch
.
ones
(
self
.
nnz
(),
dtype
=
options
.
dtype
,
value
=
torch
.
ones
(
device
=
self
.
device
())
self
.
nnz
(),
dtype
=
options
.
dtype
,
device
=
self
.
device
())
else
:
else
:
value
=
torch
.
ones
(
self
.
nnz
(),
device
=
self
.
device
())
value
=
torch
.
ones
(
self
.
nnz
(),
device
=
self
.
device
())
...
@@ -434,7 +479,7 @@ def is_shared(self: SparseTensor) -> bool:
...
@@ -434,7 +479,7 @@ def is_shared(self: SparseTensor) -> bool:
def
to
(
self
,
*
args
:
Optional
[
List
[
Any
]],
def
to
(
self
,
*
args
:
Optional
[
List
[
Any
]],
**
kwargs
:
Optional
[
Dict
[
str
,
Any
]])
->
SparseTensor
:
**
kwargs
:
Optional
[
Dict
[
str
,
Any
]])
->
SparseTensor
:
device
,
dtype
,
non_blocking
,
_
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
device
,
dtype
,
non_blocking
=
torch
.
_C
.
_nn
.
_parse_to
(
*
args
,
**
kwargs
)
if
dtype
is
not
None
:
if
dtype
is
not
None
:
self
=
self
.
type_as
(
torch
.
tensor
(
0.
,
dtype
=
dtype
))
self
=
self
.
type_as
(
torch
.
tensor
(
0.
,
dtype
=
dtype
))
...
@@ -515,8 +560,8 @@ SparseTensor.__repr__ = __repr__
...
@@ -515,8 +560,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ###########################################################
# Scipy Conversions ###########################################################
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
csr_matrix
,
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
scipy
.
sparse
.
csc_matrix
]
csr_matrix
,
scipy
.
sparse
.
csc_matrix
]
@
torch
.
jit
.
ignore
@
torch
.
jit
.
ignore
...
@@ -535,16 +580,25 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
...
@@ -535,16 +580,25 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
value
=
torch
.
from_numpy
(
mat
.
data
)
value
=
torch
.
from_numpy
(
mat
.
data
)
sparse_sizes
=
mat
.
shape
[:
2
]
sparse_sizes
=
mat
.
shape
[:
2
]
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
storage
=
SparseStorage
(
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
row
=
row
,
colptr
=
colptr
,
colcount
=
None
,
csr2csc
=
None
,
rowptr
=
rowptr
,
csc2csr
=
None
,
is_sorted
=
True
)
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
rowcount
=
None
,
colptr
=
colptr
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
)
return
SparseTensor
.
from_storage
(
storage
)
return
SparseTensor
.
from_storage
(
storage
)
@
torch
.
jit
.
ignore
@
torch
.
jit
.
ignore
def
to_scipy
(
self
:
SparseTensor
,
layout
:
Optional
[
str
]
=
None
,
def
to_scipy
(
self
:
SparseTensor
,
layout
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
ScipySparseMatrix
:
dtype
:
Optional
[
torch
.
dtype
]
=
None
)
->
ScipySparseMatrix
:
assert
self
.
dim
()
==
2
assert
self
.
dim
()
==
2
layout
=
get_layout
(
layout
)
layout
=
get_layout
(
layout
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment