Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
a1f64207
Commit
a1f64207
authored
Dec 15, 2019
by
rusty1s
Browse files
storage functionality
parent
b3746aab
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
90 additions
and
60 deletions
+90
-60
torch_sparse/storage.py
torch_sparse/storage.py
+90
-60
No files found.
torch_sparse/storage.py
View file @
a1f64207
import
inspect
import
torch
from
torch
import
Size
from
torch_scatter
import
scatter_add
,
segment_add
...
...
@@ -37,18 +39,34 @@ class SparseStorage(object):
ones
=
torch
.
ones_like
(
row
)
out_deg
=
segment_add
(
ones
,
row
,
dim
=
0
,
dim_size
=
sparse_size
[
0
])
rowptr
=
torch
.
cat
([
row
.
new_zeros
(
1
),
out_deg
.
cumsum
(
0
)],
dim
=
0
)
else
:
assert
rowptr
.
dtype
==
torch
.
long
and
rowptr
.
device
==
row
.
device
assert
rowptr
.
dim
()
==
1
and
rowptr
.
size
(
0
)
==
sparse_size
[
0
]
-
1
if
colptr
is
None
:
ones
=
torch
.
ones_like
(
col
)
if
ones
is
None
else
ones
in_deg
=
scatter_add
(
ones
,
col
,
dim
=
0
,
dim_size
=
sparse_size
[
1
])
colptr
=
torch
.
cat
([
col
.
new_zeros
(
1
),
in_deg
.
cumsum
(
0
)],
dim
=
0
)
else
:
assert
colptr
.
dtype
==
torch
.
long
and
colptr
.
device
==
col
.
device
assert
colptr
.
dim
()
==
1
and
colptr
.
size
(
0
)
==
sparse_size
[
1
]
-
1
if
arg_csr_to_csc
is
None
:
idx
=
sparse_size
[
0
]
*
col
+
row
arg_csr_to_csc
=
idx
.
argsort
()
else
:
assert
arg_csr_to_csc
==
torch
.
long
assert
arg_csr_to_csc
.
device
==
row
.
device
assert
arg_csr_to_csc
.
dim
()
==
1
assert
arg_csr_to_csc
.
size
(
0
)
==
row
.
size
(
0
)
if
arg_cs
r
_to_cs
c
is
None
:
if
arg_cs
c
_to_cs
r
is
None
:
arg_csc_to_csr
=
arg_csr_to_csc
.
argsort
()
else
:
assert
arg_csc_to_csr
==
torch
.
long
assert
arg_csc_to_csr
.
device
==
row
.
device
assert
arg_csc_to_csr
.
dim
()
==
1
assert
arg_csc_to_csr
.
size
(
0
)
==
row
.
size
(
0
)
self
.
__row
=
row
self
.
__col
=
col
...
...
@@ -60,34 +78,34 @@ class SparseStorage(object):
self
.
__arg_csc_to_csr
=
arg_csc_to_csr
@
property
def
row
(
self
):
def
_
row
(
self
):
return
self
.
__row
@
property
def
col
(
self
):
def
_
col
(
self
):
return
self
.
__col
def
index
(
self
):
def
_
index
(
self
):
return
torch
.
stack
([
self
.
__row
,
self
.
__col
],
dim
=
0
)
@
property
def
rowptr
(
self
):
def
_
rowptr
(
self
):
return
self
.
__rowptr
@
property
def
colptr
(
self
):
def
_
colptr
(
self
):
return
self
.
__colptr
@
property
def
arg_csr_to_csc
(
self
):
def
_
arg_csr_to_csc
(
self
):
return
self
.
__arg_csr_to_csc
@
property
def
arg_csc_to_csr
(
self
):
def
_
arg_csc_to_csr
(
self
):
return
self
.
__arg_csc_to_csr
@
property
def
value
(
self
):
def
_
value
(
self
):
return
self
.
__value
@
property
...
...
@@ -99,7 +117,7 @@ class SparseStorage(object):
def
size
(
self
,
dim
=
None
):
size
=
self
.
__sparse_size
size
+=
()
if
self
.
has
_value
is
None
else
self
.
__value
.
size
()[
1
:]
size
+=
()
if
self
.
_
_value
is
None
else
self
.
__value
.
size
()[
1
:]
return
size
if
dim
is
None
else
size
[
dim
]
@
property
...
...
@@ -109,102 +127,125 @@ class SparseStorage(object):
def
sparse_resize_
(
self
,
*
sizes
):
assert
len
(
sizes
)
==
2
self
.
__sparse_size
==
sizes
return
self
def
clone
(
self
):
r
aise
NotImplementedError
r
eturn
self
.
__apply
(
lambda
x
:
x
.
clone
())
def
copy_
(
self
):
r
aise
NotImplementedError
def
__
copy_
_
(
self
):
r
eturn
self
.
clone
()
def
pin_memory
(
self
):
r
aise
NotImplementedError
r
eturn
self
.
__apply
(
lambda
x
:
x
.
pin_memory
())
def
is_pinned
(
self
):
r
aise
NotImplementedError
r
eturn
all
([
x
.
is_pinned
for
x
in
self
.
__attributes
])
def
share_memory_
(
self
):
r
aise
NotImplementedError
r
eturn
self
.
__apply_
(
lambda
x
:
x
.
share_memory_
())
def
is_shared
(
self
):
r
aise
NotImplementedError
r
eturn
all
([
x
.
is_shared
for
x
in
self
.
__attributes
])
@
property
def
device
(
self
):
return
self
.
__row
.
device
def
cpu
(
self
):
pass
return
self
.
__apply
(
lambda
x
:
x
.
cpu
())
def
cuda
(
device
=
None
,
non_blocking
=
False
,
**
kwargs
):
pass
def
cuda
(
self
,
device
=
None
,
non_blocking
=
False
,
**
kwargs
):
return
self
.
__apply
(
lambda
x
:
x
.
cuda
(
device
,
non_blocking
,
**
kwargs
))
@
property
def
is_cuda
(
self
):
pass
return
self
.
__row
.
is_cuda
@
property
def
dtype
(
self
):
pass
return
None
if
self
.
__value
is
None
else
self
.
__value
.
dtype
def
to
(
self
,
*
args
,
**
kwargs
):
if
'device'
in
kwargs
:
out
=
self
.
__apply
(
lambda
x
:
x
.
to
(
kwargs
[
'device'
]))
del
kwargs
[
'device'
]
for
arg
in
args
[:]:
if
isinstance
(
arg
,
str
)
or
isinstance
(
arg
,
torch
.
device
):
out
=
self
.
__apply
(
lambda
x
:
x
.
to
(
arg
))
args
.
remove
(
arg
)
def
type
(
dtype
=
None
,
non_blocking
=
False
,
**
kwargs
):
pass
if
len
(
args
)
>
0
and
len
(
kwargs
)
>
0
:
out
=
self
.
type
(
*
args
,
**
kwargs
)
return
out
def
type
(
self
,
dtype
=
None
,
non_blocking
=
False
,
**
kwargs
):
return
self
.
dtype
if
dtype
is
None
else
self
.
__apply_value
(
lambda
x
:
x
.
type
(
dtype
,
non_blocking
,
**
kwargs
))
def
is_floating_point
(
self
):
pass
return
self
.
__value
is
None
or
torch
.
is_floating_point
(
self
.
__value
)
def
bfloat16
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
bfloat16
())
def
bool
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
bool
())
def
byte
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
byte
())
def
char
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
char
())
def
half
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
half
())
def
float
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
float
())
def
double
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
double
())
def
short
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
short
())
def
int
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
int
())
def
long
(
self
):
pass
return
self
.
__apply_value
(
lambda
x
:
x
.
long
())
###########################################################################
def
__
apply_index
(
self
,
func
):
pass
def
__
keys
(
self
):
return
inspect
.
getfullargspec
(
self
.
__init__
)[
0
][
1
:
-
1
]
def
__apply_index_
(
self
,
func
):
self
.
__row
=
func
(
self
.
__row
)
self
.
__col
=
func
(
self
.
__col
)
self
.
__rowptr
=
func
(
self
.
__rowptr
)
self
.
__colptr
=
func
(
self
.
__colptr
)
self
.
__arg_csr_to_csc
=
func
(
self
.
__arg_csr_to_csc
)
self
.
__arg_csc_to_csr
=
func
(
self
.
__arg_csc_to_csr
)
def
__state
(
self
):
return
{
key
:
getattr
(
self
,
f
'_
{
self
.
__class__
.
__name__
}
__
{
key
}
'
)
for
key
in
self
.
__keys
()
}
def
__apply_value
(
self
,
func
):
pass
state
=
self
.
__state
()
state
[
'value'
]
==
func
(
self
.
__value
)
return
self
.
__class__
(
is_sorted
=
True
,
**
state
)
def
__apply_value_
(
self
,
func
):
self
.
__value
=
func
(
self
.
__value
)
if
self
.
has_value
else
None
self
.
__value
=
None
if
self
.
__value
is
None
else
func
(
self
.
__value
)
return
self
def
__apply
(
self
,
func
):
pass
state
=
{
key
:
func
(
item
)
for
key
,
item
in
self
.
__state
().
items
()}
return
self
.
__class__
(
is_sorted
=
True
,
**
state
)
def
__apply_
(
self
,
func
):
self
.
__apply_index_
(
func
)
self
.
__apply_value_
(
func
)
state
=
self
.
__state
()
del
state
[
'value'
]
for
key
,
item
in
self
.
__state
().
items
():
setattr
(
self
,
f
'_
{
self
.
__class__
.
__name__
}
__
{
key
}
'
,
func
(
item
))
return
self
.
__apply_value_
(
func
)
if
__name__
==
'__main__'
:
...
...
@@ -219,14 +260,3 @@ if __name__ == '__main__':
row
,
col
=
edge_index
storage
=
SparseStorage
(
row
,
col
)
# idx = data.num_nodes * col + row
# perm = idx.argsort()
# row, col = row[perm], col[perm]
# print(row[:20])
# print(col[:20])
# print('--------')
# perm = perm.argsort()
# row, col = row[perm], col[perm]
# print(row[:20])
# print(col[:20])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment