OpenDAS / torch-sparse · Commits

Commit 3307f0d9, authored Dec 15, 2019 by rusty1s
Parent: 2984f288

    first sparse adj test
Changes: 2 files, 335 additions and 0 deletions (+335 -0)

  torch_sparse/sparse.py   +107 -0
  torch_sparse/storage.py  +228 -0
torch_sparse/sparse.py (new file, mode 100644)
from textwrap import indent

import torch


class SparseTensor(object):
    def __init__(self, index, value=None, sparse_size=None, is_sorted=False):
        assert index.dtype == torch.long
        assert index.dim() == 2 and index.size(0) == 2
        index = index.contiguous()

        if value is not None:
            assert value.size(0) == index.size(1)
            assert index.device == value.device
            value = value.contiguous()

        if sparse_size is None:
            # Infer the sparse size from the maximum row/col index.
            sparse_size = torch.Size((index.max(dim=-1)[0].cpu() + 1).tolist())

        self.__index__ = index
        self.__value__ = value
        self.__sparse_size__ = sparse_size

        if not is_sorted and not self.__is_sorted__():
            self.__sort__()

    def to(self, *args, **kwargs):
        # TODO
        pass

    def size(self, dim=None):
        size = self.__sparse_size__
        size += () if self.__value__ is None else self.__value__.size()[1:]
        return size if dim is None else size[dim]

    def storage(self):
        pass

    @property
    def shape(self):
        return self.size()

    def dim(self):
        return len(self.size())

    @property
    def dtype(self):
        return None if self.__value__ is None else self.__value__.dtype

    @property
    def device(self):
        return self.__index__.device

    def nnz(self):
        return self.__index__.size(1)

    def numel(self):
        return self.__value__.numel() if self.__value__ is not None else self.nnz()

    def clone(self):
        return self.__class__(
            index=self.__index__.clone(),
            value=None if self.__value__ is None else self.__value__.clone(),
            sparse_size=self.__sparse_size__,
            is_sorted=True,
        )

    def sparse_resize_(self, *sizes):
        assert len(sizes) == 2
        self.__sparse_size__ = torch.Size(sizes)

    def __is_sorted__(self):
        # Entries are sorted iff the row-major keys are non-decreasing.
        idx1 = self.size(1) * self.__index__[0] + self.__index__[1]
        idx2 = torch.cat([idx1.new_zeros(1), idx1[:-1]], dim=0)
        return (idx1 >= idx2).all().item()

    def __sort__(self):
        # Sort entries into row-major order via a linearized (row, col) key.
        idx = self.__sparse_size__[1] * self.__index__[0] + self.__index__[1]
        perm = idx.argsort()
        self.__index__ = self.__index__[:, perm]
        self.__value__ = None if self.__value__ is None else self.__value__[perm]

    def __repr__(self):
        i = ' ' * 6
        infos = [f'index={indent(self.__index__.__repr__(), i)[len(i):]}']
        if self.__value__ is not None:
            infos += [f'value={indent(self.__value__.__repr__(), i)[len(i):]}']
        infos += [f'size={tuple(self.size())}, nnz={self.nnz()}']
        infos = ',\n'.join(infos)

        i = ' ' * (len(self.__class__.__name__) + 1)
        return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})'


if __name__ == '__main__':
    index = torch.tensor([
        [0, 0, 1, 1, 2, 2],
        [2, 1, 2, 3, 0, 1],
    ])
    value = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float)

    mat1 = SparseTensor(index, value)
    print(mat1)

    mat2 = torch.sparse_coo_tensor(index, value)
    # print(mat2)
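
The sorting logic in __sort__ and __is_sorted__ relies on linearizing each (row, col) coordinate into a single scalar key, num_cols * row + col, so that a plain argsort yields row-major order. Below is a minimal standalone sketch of that idea with made-up coordinates; it is independent of the SparseTensor class above and only illustrates the key construction.

import torch

# Hypothetical 3 x 4 sparse pattern given as unsorted COO coordinates.
row = torch.tensor([2, 0, 1, 0, 2, 1])
col = torch.tensor([1, 2, 3, 1, 0, 2])
num_cols = 4

# Linearize to a row-major key: entries compare first by row, then by col.
key = num_cols * row + col
perm = key.argsort()

row, col = row[perm], col[perm]
print(row)  # tensor([0, 0, 1, 1, 2, 2])
print(col)  # tensor([1, 2, 2, 3, 0, 1])

# A sorted key sequence is non-decreasing, which is what __is_sorted__ checks.
print(bool((key[perm][1:] >= key[perm][:-1]).all()))  # True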

torch_sparse/storage.py (new file, mode 100644)
import torch
from torch import Size
from torch_scatter import scatter_add, segment_add


class SparseStorage(object):
    def __init__(self, row, col, value=None, sparse_size=None, rowptr=None,
                 colptr=None, arg_csr_to_csc=None, arg_csc_to_csr=None,
                 is_sorted=False):

        assert row.dtype == torch.long and col.dtype == torch.long
        assert row.device == col.device
        assert row.dim() == 1 and col.dim() == 1 and row.numel() == col.numel()

        if not is_sorted:
            # Sort row and col
            rowptr = None
            colptr = None
            arg_csr_to_csc = None
            arg_csc_to_csr = None

        if value is not None:
            assert row.device == value.device and value.size(0) == row.size(0)
            value = value.contiguous()

        if sparse_size is None:
            sparse_size = Size((row[-1].item() + 1, col.max().item() + 1))

        ones = None
        if rowptr is None:
            ones = torch.ones_like(row)
            rowptr = segment_add(ones, row, dim=0, dim_size=sparse_size[0])

        if colptr is None:
            ones = torch.ones_like(col) if ones is None else ones
            colptr = scatter_add(ones, col, dim=0, dim_size=sparse_size[1])

        if arg_csr_to_csc is None:
            # Permutation that reorders CSR (row-major) entries into CSC
            # (column-major) order.
            idx = sparse_size[0] * col + row
            arg_csr_to_csc = idx.argsort()

        if arg_csc_to_csr is None:
            # Inverse permutation back to CSR order.
            arg_csc_to_csr = arg_csr_to_csc.argsort()

        self.__row = row
        self.__col = col
        self.__value = value
        self.__sparse_size = sparse_size
        self.__rowptr = rowptr
        self.__colptr = colptr
        self.__arg_csr_to_csc = arg_csr_to_csc
        self.__arg_csc_to_csr = arg_csc_to_csr

    @property
    def row(self):
        return self.__row

    @property
    def col(self):
        return self.__col

    def index(self):
        return torch.stack([self.__row, self.__col], dim=0)

    @property
    def rowptr(self):
        return self.__rowptr

    @property
    def colptr(self):
        return self.__colptr

    @property
    def arg_csr_to_csc(self):
        return self.__arg_csr_to_csc

    @property
    def arg_csc_to_csr(self):
        return self.__arg_csc_to_csr

    @property
    def value(self):
        return self.__value

    @property
    def has_value(self):
        return self.__value is not None

    def sparse_size(self, dim=None):
        return self.__sparse_size if dim is None else self.__sparse_size[dim]

    def size(self, dim=None):
        size = self.__sparse_size
        size += () if not self.has_value else self.__value.size()[1:]
        return size if dim is None else size[dim]

    @property
    def shape(self):
        return self.size()

    def sparse_resize_(self, *sizes):
        assert len(sizes) == 2
        self.__sparse_size = Size(sizes)

    def clone(self):
        raise NotImplementedError

    def copy_(self):
        raise NotImplementedError

    def pin_memory(self):
        raise NotImplementedError

    def is_pinned(self):
        raise NotImplementedError

    def share_memory_(self):
        raise NotImplementedError

    def is_shared(self):
        raise NotImplementedError

    @property
    def device(self):
        return self.__row.device

    def cpu(self):
        pass

    def cuda(self, device=None, non_blocking=False, **kwargs):
        pass

    @property
    def is_cuda(self):
        pass

    @property
    def dtype(self):
        pass

    def type(self, dtype=None, non_blocking=False, **kwargs):
        pass

    def is_floating_point(self):
        pass

    def bfloat16(self):
        pass

    def bool(self):
        pass

    def byte(self):
        pass

    def char(self):
        pass

    def half(self):
        pass

    def float(self):
        pass

    def double(self):
        pass

    def short(self):
        pass

    def int(self):
        pass

    def long(self):
        pass

    def __apply_index(self, func):
        pass

    def __apply_index_(self, func):
        self.__row = func(self.__row)
        self.__col = func(self.__col)
        self.__rowptr = func(self.__rowptr)
        self.__colptr = func(self.__colptr)
        self.__arg_csr_to_csc = func(self.__arg_csr_to_csc)
        self.__arg_csc_to_csr = func(self.__arg_csc_to_csr)

    def __apply_value(self, func):
        pass

    def __apply_value_(self, func):
        self.__value = func(self.__value) if self.has_value else None

    def __apply(self, func):
        pass

    def __apply_(self, func):
        self.__apply_index_(func)
        self.__apply_value_(func)


if __name__ == '__main__':
    from torch_geometric.datasets import Reddit  # noqa
    import time  # noqa

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    dataset = Reddit('/tmp/Reddit')
    data = dataset[0].to(device)
    edge_index = data.edge_index
    row, col = edge_index
    print(row.size())
    print(row[:20])
    print(col[:20])
    print('--------')

    # storage = SparseStorage(row, col)
    idx = data.num_nodes * col + row
    perm = idx.argsort()
    row, col = row[perm], col[perm]
    print(row[:20])
    print(col[:20])
    print('--------')

    perm = perm.argsort()
    row, col = row[perm], col[perm]
    print(row[:20])
    print(col[:20])
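
The arg_csr_to_csc permutation computed in SparseStorage.__init__ (and exercised by hand in the __main__ block above) uses the same linearization trick, but with num_rows * col + row as the key so the result is column-major; argsort of that permutation recovers the inverse mapping back to row-major. Below is a small self-contained sketch of the same CSR/CSC bookkeeping, built with plain torch.bincount and cumsum rather than the torch_scatter calls used in the commit, and using made-up coordinates.

import torch

num_rows, num_cols = 3, 4

# Row-major (CSR-ordered) COO coordinates of a hypothetical 3 x 4 matrix.
row = torch.tensor([0, 0, 1, 1, 2, 2])
col = torch.tensor([2, 1, 2, 3, 0, 1])
value = torch.arange(1, 7, dtype=torch.float)

# Per-row non-zero counts; their cumulative sum gives a CSR pointer array.
rowcount = torch.bincount(row, minlength=num_rows)
rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)])
print(rowptr)  # tensor([0, 2, 4, 6])

# Column-major key: compare first by column, then by row.
arg_csr_to_csc = (num_rows * col + row).argsort()
arg_csc_to_csr = arg_csr_to_csc.argsort()  # inverse permutation

print(value[arg_csr_to_csc])                  # values in CSC order
print(value[arg_csr_to_csc][arg_csc_to_csr])  # back to the original CSR order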