Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
f3469f1a
"vscode:/vscode.git/clone" did not exist on "66a67d2efbab894196a733426b05f2b08da6fd79"
Commit
f3469f1a
authored
Dec 22, 2019
by
rusty1s
Browse files
test storage
parent
f3b7fb50
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
141 additions
and
48 deletions
+141
-48
test/test_add.py
test/test_add.py
+0
-1
test/test_storage.py
test/test_storage.py
+133
-0
test/utils.py
test/utils.py
+2
-2
torch_sparse/storage.py
torch_sparse/storage.py
+6
-45
No files found.
test/test_add.py
View file @
f3469f1a
...
@@ -10,7 +10,6 @@ from .utils import dtypes, devices, tensor
...
@@ -10,7 +10,6 @@ from .utils import dtypes, devices, tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_sparse_add
(
dtype
,
device
):
def
test_sparse_add
(
dtype
,
device
):
print
()
index
=
tensor
([[
0
,
0
,
1
],
[
0
,
1
,
2
]],
torch
.
long
,
device
)
index
=
tensor
([[
0
,
0
,
1
],
[
0
,
1
,
2
]],
torch
.
long
,
device
)
mat1
=
SparseTensor
(
index
)
mat1
=
SparseTensor
(
index
)
...
...
test/test_storage.py
0 → 100644
View file @
f3469f1a
import
copy
from
itertools
import
product
import
pytest
import
torch
from
torch_sparse.storage
import
SparseStorage
from
.utils
import
dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_storage
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
storage
=
SparseStorage
(
index
)
assert
storage
.
index
.
tolist
()
==
index
.
tolist
()
assert
storage
.
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
is
None
assert
storage
.
sparse_size
()
==
(
2
,
2
)
index
=
tensor
([[
0
,
0
,
1
,
1
],
[
1
,
0
,
1
,
0
]],
torch
.
long
,
device
)
value
=
tensor
([
2
,
1
,
4
,
3
],
dtype
,
device
)
storage
=
SparseStorage
(
index
,
value
)
assert
storage
.
index
.
tolist
()
==
[[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]]
assert
storage
.
row
.
tolist
()
==
[
0
,
0
,
1
,
1
]
assert
storage
.
col
.
tolist
()
==
[
0
,
1
,
0
,
1
]
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
assert
storage
.
sparse_size
()
==
(
2
,
2
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_caching
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
1
,
1
],
[
0
,
1
,
0
,
1
]],
torch
.
long
,
device
)
storage
=
SparseStorage
(
index
)
assert
storage
.
_index
.
tolist
()
==
index
.
tolist
()
assert
storage
.
_value
is
None
assert
storage
.
_rowcount
is
None
assert
storage
.
_rowptr
is
None
assert
storage
.
_colcount
is
None
assert
storage
.
_colptr
is
None
assert
storage
.
_csr2csc
is
None
assert
storage
.
cached_keys
()
==
[]
storage
.
fill_cache_
()
assert
storage
.
_rowcount
.
tolist
()
==
[
2
,
2
]
assert
storage
.
_rowptr
.
tolist
()
==
[
0
,
2
,
4
]
assert
storage
.
_colcount
.
tolist
()
==
[
2
,
2
]
assert
storage
.
_colptr
.
tolist
()
==
[
0
,
2
,
4
]
assert
storage
.
_csr2csc
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
_csc2csr
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
cached_keys
()
==
[
'rowcount'
,
'rowptr'
,
'colcount'
,
'colptr'
,
'csr2csc'
,
'csc2csr'
]
storage
=
SparseStorage
(
index
,
storage
.
value
,
storage
.
sparse_size
(),
storage
.
rowcount
,
storage
.
rowptr
,
storage
.
colcount
,
storage
.
colptr
,
storage
.
csr2csc
,
storage
.
csc2csr
)
assert
storage
.
_rowcount
.
tolist
()
==
[
2
,
2
]
assert
storage
.
_rowptr
.
tolist
()
==
[
0
,
2
,
4
]
assert
storage
.
_colcount
.
tolist
()
==
[
2
,
2
]
assert
storage
.
_colptr
.
tolist
()
==
[
0
,
2
,
4
]
assert
storage
.
_csr2csc
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
_csc2csr
.
tolist
()
==
[
0
,
2
,
1
,
3
]
assert
storage
.
cached_keys
()
==
[
'rowcount'
,
'rowptr'
,
'colcount'
,
'colptr'
,
'csr2csc'
,
'csc2csr'
]
storage
.
clear_cache_
()
assert
storage
.
_rowcount
is
None
assert
storage
.
_rowptr
is
None
assert
storage
.
_colcount
is
None
assert
storage
.
_colptr
is
None
assert
storage
.
_csr2csc
is
None
assert
storage
.
cached_keys
()
==
[]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_utility
(
dtype
,
device
):
index
=
tensor
([[
0
,
0
,
1
,
1
],
[
1
,
0
,
1
,
0
]],
torch
.
long
,
device
)
value
=
tensor
([
1
,
2
,
3
,
4
],
dtype
,
device
)
storage
=
SparseStorage
(
index
,
value
)
assert
storage
.
has_value
()
storage
.
set_value_
(
value
,
layout
=
'csc'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
3
,
2
,
4
]
storage
.
set_value_
(
value
,
layout
=
'coo'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
storage
=
storage
.
set_value
(
value
,
layout
=
'csc'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
3
,
2
,
4
]
storage
=
storage
.
set_value
(
value
,
layout
=
'coo'
)
assert
storage
.
value
.
tolist
()
==
[
1
,
2
,
3
,
4
]
storage
.
sparse_resize_
(
3
,
3
)
assert
storage
.
sparse_size
()
==
(
3
,
3
)
new_storage
=
copy
.
copy
(
storage
)
assert
new_storage
!=
storage
assert
new_storage
.
index
.
data_ptr
()
==
storage
.
index
.
data_ptr
()
new_storage
=
storage
.
clone
()
assert
new_storage
!=
storage
assert
new_storage
.
index
.
data_ptr
()
!=
storage
.
index
.
data_ptr
()
new_storage
=
copy
.
deepcopy
(
storage
)
assert
new_storage
!=
storage
assert
new_storage
.
index
.
data_ptr
()
!=
storage
.
index
.
data_ptr
()
storage
.
apply_value_
(
lambda
x
:
x
+
1
)
assert
storage
.
value
.
tolist
()
==
[
2
,
3
,
4
,
5
]
storage
=
storage
.
apply_value
(
lambda
x
:
x
+
1
)
assert
storage
.
value
.
tolist
()
==
[
3
,
4
,
5
,
6
]
storage
.
apply_
(
lambda
x
:
x
.
to
(
torch
.
long
))
assert
storage
.
index
.
dtype
==
torch
.
long
assert
storage
.
value
.
dtype
==
torch
.
long
storage
=
storage
.
apply
(
lambda
x
:
x
.
to
(
torch
.
long
))
assert
storage
.
index
.
dtype
==
torch
.
long
assert
storage
.
value
.
dtype
==
torch
.
long
storage
.
clear_cache_
()
assert
storage
.
map
(
lambda
x
:
x
.
numel
())
==
[
8
,
4
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_coalesce
(
dtype
,
device
):
pass
test/utils.py
View file @
f3469f1a
...
@@ -3,8 +3,8 @@ import torch
...
@@ -3,8 +3,8 @@ import torch
dtypes
=
[
torch
.
float
]
dtypes
=
[
torch
.
float
]
devices
=
[
torch
.
device
(
'cpu'
)]
devices
=
[
torch
.
device
(
'cpu'
)]
if
torch
.
cuda
.
is_available
():
#
if torch.cuda.is_available():
devices
+=
[
torch
.
device
(
'cuda:{}'
.
format
(
torch
.
cuda
.
current_device
()))]
#
devices += [torch.device('cuda:{}'.format(torch.cuda.current_device()))]
def
tensor
(
x
,
dtype
,
device
):
def
tensor
(
x
,
dtype
,
device
):
...
...
torch_sparse/storage.py
View file @
f3469f1a
...
@@ -38,17 +38,9 @@ class SparseStorage(object):
...
@@ -38,17 +38,9 @@ class SparseStorage(object):
'rowcount'
,
'rowptr'
,
'colcount'
,
'colptr'
,
'csr2csc'
,
'csc2csr'
'rowcount'
,
'rowptr'
,
'colcount'
,
'colptr'
,
'csr2csc'
,
'csc2csr'
]
]
def
__init__
(
self
,
def
__init__
(
self
,
index
,
value
=
None
,
sparse_size
=
None
,
rowcount
=
None
,
index
,
rowptr
=
None
,
colcount
=
None
,
colptr
=
None
,
csr2csc
=
None
,
value
=
None
,
csc2csr
=
None
,
is_sorted
=
False
):
sparse_size
=
None
,
rowcount
=
None
,
rowptr
=
None
,
colcount
=
None
,
colptr
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
False
):
assert
index
.
dtype
==
torch
.
long
assert
index
.
dtype
==
torch
.
long
assert
index
.
dim
()
==
2
and
index
.
size
(
0
)
==
2
assert
index
.
dim
()
==
2
and
index
.
size
(
0
)
==
2
...
@@ -97,7 +89,7 @@ class SparseStorage(object):
...
@@ -97,7 +89,7 @@ class SparseStorage(object):
if
not
is_sorted
:
if
not
is_sorted
:
idx
=
sparse_size
[
1
]
*
index
[
0
]
+
index
[
1
]
idx
=
sparse_size
[
1
]
*
index
[
0
]
+
index
[
1
]
# Only sort if necessary...
# Only sort if necessary...
if
(
idx
<
=
torch
.
cat
([
idx
.
new_zeros
(
1
),
idx
[:
-
1
]],
dim
=
0
)).
any
():
if
(
idx
<
torch
.
cat
([
idx
.
new_zeros
(
1
),
idx
[:
-
1
]],
dim
=
0
)).
any
():
perm
=
idx
.
argsort
()
perm
=
idx
.
argsort
()
index
=
index
[:,
perm
]
index
=
index
[:,
perm
]
value
=
None
if
value
is
None
else
value
[
perm
]
value
=
None
if
value
is
None
else
value
[
perm
]
...
@@ -164,7 +156,7 @@ class SparseStorage(object):
...
@@ -164,7 +156,7 @@ class SparseStorage(object):
def
sparse_resize_
(
self
,
*
sizes
):
def
sparse_resize_
(
self
,
*
sizes
):
assert
len
(
sizes
)
==
2
assert
len
(
sizes
)
==
2
self
.
_sparse_size
=
=
sizes
self
.
_sparse_size
=
sizes
return
self
return
self
@
cached_property
@
cached_property
...
@@ -269,7 +261,7 @@ class SparseStorage(object):
...
@@ -269,7 +261,7 @@ class SparseStorage(object):
self
.
_index
=
func
(
self
.
_index
)
self
.
_index
=
func
(
self
.
_index
)
self
.
_value
=
optional
(
func
,
self
.
_value
)
self
.
_value
=
optional
(
func
,
self
.
_value
)
for
key
in
self
.
cached_keys
():
for
key
in
self
.
cached_keys
():
setattr
(
self
,
f
'_
{
key
}
'
,
func
,
getattr
(
self
,
f
'_
{
key
}
'
))
setattr
(
self
,
f
'_
{
key
}
'
,
func
(
getattr
(
self
,
f
'_
{
key
}
'
))
)
return
self
return
self
def
apply
(
self
,
func
):
def
apply
(
self
,
func
):
...
@@ -292,34 +284,3 @@ class SparseStorage(object):
...
@@ -292,34 +284,3 @@ class SparseStorage(object):
data
+=
[
func
(
self
.
value
)]
data
+=
[
func
(
self
.
value
)]
data
+=
[
func
(
getattr
(
self
,
f
'_
{
key
}
'
))
for
key
in
self
.
cached_keys
()]
data
+=
[
func
(
getattr
(
self
,
f
'_
{
key
}
'
))
for
key
in
self
.
cached_keys
()]
return
data
return
data
if
__name__
==
'__main__'
:
from
torch_geometric.datasets
import
Reddit
,
Planetoid
# noqa
import
time
# noqa
import
copy
# noqa
device
=
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
# dataset = Reddit('/tmp/Reddit')
dataset
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)
data
=
dataset
[
0
].
to
(
device
)
edge_index
=
data
.
edge_index
storage
=
SparseStorage
(
edge_index
,
is_sorted
=
True
)
t
=
time
.
perf_counter
()
storage
.
fill_cache_
()
print
(
time
.
perf_counter
()
-
t
)
t
=
time
.
perf_counter
()
storage
.
clear_cache_
()
storage
.
fill_cache_
()
print
(
time
.
perf_counter
()
-
t
)
print
(
storage
)
# storage = storage.clone()
# print(storage)
storage
=
copy
.
copy
(
storage
)
print
(
storage
)
print
(
id
(
storage
))
storage
=
copy
.
deepcopy
(
storage
)
print
(
storage
)
storage
.
fill_cache_
()
storage
.
clear_cache_
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment