Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
195053ef
Commit
195053ef
authored
Dec 18, 2019
by
rusty1s
Browse files
transpose
parent
76bf1e8a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
19 deletions
+36
-19
torch_sparse/storage.py
torch_sparse/storage.py
+1
-0
torch_sparse/tensor.py
torch_sparse/tensor.py
+21
-19
torch_sparse/transpose.py
torch_sparse/transpose.py
+14
-0
No files found.
torch_sparse/storage.py
View file @
195053ef
...
@@ -36,6 +36,7 @@ class SparseStorage(object):
...
@@ -36,6 +36,7 @@ class SparseStorage(object):
assert
index
.
dtype
==
torch
.
long
assert
index
.
dtype
==
torch
.
long
assert
index
.
dim
()
==
2
and
index
.
size
(
0
)
==
2
assert
index
.
dim
()
==
2
and
index
.
size
(
0
)
==
2
index
=
index
.
contiguous
()
if
value
is
not
None
:
if
value
is
not
None
:
assert
value
.
device
==
index
.
device
assert
value
.
device
==
index
.
device
...
...
torch_sparse/tensor.py
View file @
195053ef
...
@@ -4,6 +4,7 @@ import torch
...
@@ -4,6 +4,7 @@ import torch
import
scipy.sparse
import
scipy.sparse
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.storage
import
SparseStorage
from
torch_sparse.transpose
import
t
class
SparseTensor
(
object
):
class
SparseTensor
(
object
):
...
@@ -128,9 +129,9 @@ class SparseTensor(object):
...
@@ -128,9 +129,9 @@ class SparseTensor(object):
rowptr
,
col
,
val1
=
self
.
csr
()
rowptr
,
col
,
val1
=
self
.
csr
()
colptr
,
row
,
val2
=
self
.
csc
()
colptr
,
row
,
val2
=
self
.
csc
()
index_sym
metric
=
(
rowptr
==
colptr
).
all
()
and
(
col
==
row
).
all
()
index_sym
=
(
rowptr
==
colptr
).
all
()
and
(
col
==
row
).
all
()
value_sym
metric
=
(
val1
==
val2
).
all
()
if
self
.
has_value
()
else
True
value_sym
=
(
val1
==
val2
).
all
()
.
item
()
if
self
.
has_value
()
else
True
return
index_sym
metric
and
value_sym
metric
return
index_sym
.
item
()
and
value_sym
def
detach_
(
self
):
def
detach_
(
self
):
self
.
_storage
.
apply_
(
lambda
x
:
x
.
detach_
())
self
.
_storage
.
apply_
(
lambda
x
:
x
.
detach_
())
...
@@ -310,23 +311,11 @@ class SparseTensor(object):
...
@@ -310,23 +311,11 @@ class SparseTensor(object):
# Bindings ####################################################################
# Bindings ####################################################################
SparseTensor
.
t
=
t
# def set_diag(self, value):
# def set_diag(self, value):
# raise NotImplementedError
# raise NotImplementedError
# def t(self):
# storage = SparseStorage(
# self._col[self._arg_csr_to_csc],
# self._row[self._arg_csr_to_csc],
# self._value[self._arg_csr_to_csc] if self.has_value else None,
# self.sparse_size()[::-1],
# self._colptr,
# self._rowptr,
# self._arg_csc_to_csr,
# self._arg_csr_to_csc,
# is_sorted=True,
# )
# return self.__class__.from_storage(storage)
#
# def masked_select(self, mask):
# def masked_select(self, mask):
# raise NotImplementedError
# raise NotImplementedError
...
@@ -446,10 +435,17 @@ if __name__ == '__main__':
...
@@ -446,10 +435,17 @@ if __name__ == '__main__':
data
=
dataset
[
0
].
to
(
device
)
data
=
dataset
[
0
].
to
(
device
)
value
=
torch
.
ones
((
data
.
num_edges
,
),
device
=
device
)
value
=
torch
.
ones
((
data
.
num_edges
,
),
device
=
device
)
value
=
None
mat1
=
SparseTensor
(
data
.
edge_index
,
value
)
mat1
=
SparseTensor
(
data
.
edge_index
,
value
)
print
(
mat1
)
print
(
mat1
)
print
(
id
(
mat1
))
mat1
=
mat1
.
long
()
print
(
id
(
mat1
))
mat1
=
mat1
.
long
()
print
(
id
(
mat1
))
mat1
=
mat1
.
to
(
torch
.
bool
)
print
(
mat1
)
print
(
mat1
.
is_pinned
())
print
(
mat1
.
to_dense
().
size
())
print
(
mat1
.
to_dense
().
size
())
...
@@ -459,9 +455,15 @@ if __name__ == '__main__':
...
@@ -459,9 +455,15 @@ if __name__ == '__main__':
print
(
mat1
.
to_scipy
(
layout
=
'coo'
).
todense
().
shape
)
print
(
mat1
.
to_scipy
(
layout
=
'coo'
).
todense
().
shape
)
print
(
mat1
.
to_scipy
(
layout
=
'csr'
).
todense
().
shape
)
print
(
mat1
.
to_scipy
(
layout
=
'csr'
).
todense
().
shape
)
print
(
mat1
.
to_scipy
(
layout
=
'csc'
).
todense
().
shape
)
print
(
mat1
.
to_scipy
(
layout
=
'csc'
).
todense
().
shape
)
print
(
mat1
.
is_quadratic
())
print
(
mat1
.
is_symmetric
())
mat1
=
mat1
.
t
()
print
(
mat1
)
# mat1 = mat1.t()
# mat1 = mat1.t()
#
mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
# mat2 = torch.sparse_coo_tensor(data.edge_index, torch.ones(data.num_edges),
# device=device)
# device=device)
# mat2 = mat2.coalesce()
# mat2 = mat2.coalesce()
# mat2 = mat2.t().coalesce()
# mat2 = mat2.t().coalesce()
...
...
torch_sparse/transpose.py
View file @
195053ef
...
@@ -26,3 +26,17 @@ def transpose(index, value, m, n, coalesced=True):
...
@@ -26,3 +26,17 @@ def transpose(index, value, m, n, coalesced=True):
if
coalesced
:
if
coalesced
:
index
,
value
=
coalesce
(
index
,
value
,
n
,
m
)
index
,
value
=
coalesce
(
index
,
value
,
n
,
m
)
return
index
,
value
return
index
,
value
def
t
(
mat
):
((
row
,
col
),
value
),
perm
=
mat
.
coo
(),
mat
.
_storage
.
csr_to_csc
storage
=
mat
.
_storage
.
__class__
(
index
=
torch
.
stack
([
col
,
row
],
dim
=
0
)[:,
perm
],
value
=
value
[
perm
]
if
mat
.
has_value
()
else
None
,
sparse_size
=
mat
.
sparse_size
()[::
-
1
],
rowptr
=
mat
.
_storage
.
_colptr
,
colptr
=
mat
.
_storage
.
_rowptr
,
csr_to_csc
=
mat
.
_storage
.
_csc_to_csr
,
csc_to_csr
=
perm
,
is_sorted
=
True
)
return
mat
.
__class__
.
from_storage
(
storage
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment