Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
b5624cb8
Commit
b5624cb8
authored
Dec 18, 2019
by
rusty1s
Browse files
test
parent
9971227c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
2 deletions
+28
-2
torch_sparse/storage.py
torch_sparse/storage.py
+28
-2
No files found.
torch_sparse/storage.py
View file @
b5624cb8
...
@@ -163,6 +163,20 @@ class SparseStorage(object):
...
@@ -163,6 +163,20 @@ class SparseStorage(object):
setattr
(
self
,
f
'_
{
arg
}
'
,
None
)
setattr
(
self
,
f
'_
{
arg
}
'
,
None
)
return
self
return
self
def
clone
(
self
):
return
self
.
apply
(
lambda
x
:
x
.
clone
())
def
__copy__
(
self
):
return
self
.
clone
()
def
__deepcopy__
(
self
,
memo
):
memo
=
memo
.
setdefault
(
'SparseStorage'
,
{})
if
self
.
_cdata
in
memo
:
return
memo
[
self
.
_cdata
]
new_storage
=
self
.
clone
()
memo
[
self
.
_cdata
]
=
new_storage
return
new_storage
def
apply_value_
(
self
,
func
):
def
apply_value_
(
self
,
func
):
self
.
_value
=
optional
(
func
,
self
.
_value
)
self
.
_value
=
optional
(
func
,
self
.
_value
)
return
self
return
self
...
@@ -199,11 +213,13 @@ class SparseStorage(object):
...
@@ -199,11 +213,13 @@ class SparseStorage(object):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
from
torch_geometric.datasets
import
Reddit
# noqa
from
torch_geometric.datasets
import
Reddit
,
Planetoid
# noqa
import
time
# noqa
import
time
# noqa
import
copy
# noqa
device
=
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
device
=
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
dataset
=
Reddit
(
'/tmp/Reddit'
)
# dataset = Reddit('/tmp/Reddit')
dataset
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)
data
=
dataset
[
0
].
to
(
device
)
data
=
dataset
[
0
].
to
(
device
)
edge_index
=
data
.
edge_index
edge_index
=
data
.
edge_index
...
@@ -212,5 +228,15 @@ if __name__ == '__main__':
...
@@ -212,5 +228,15 @@ if __name__ == '__main__':
storage
.
compute_cache_
()
storage
.
compute_cache_
()
print
(
time
.
perf_counter
()
-
t
)
print
(
time
.
perf_counter
()
-
t
)
t
=
time
.
perf_counter
()
t
=
time
.
perf_counter
()
storage
.
clear_cache_
()
storage
.
compute_cache_
()
storage
.
compute_cache_
()
print
(
time
.
perf_counter
()
-
t
)
print
(
time
.
perf_counter
()
-
t
)
print
(
storage
)
storage
=
storage
.
clone
()
print
(
storage
)
# storage = copy.copy(storage)
# print(storage)
# storage = copy.deepcopy(storage)
# print(storage)
storage
.
compute_cache_
()
storage
.
clear_cache_
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment