Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
b38d0b5e
Commit
b38d0b5e
authored
Mar 12, 2022
by
rusty1s
Browse files
fix test
parent
c4f318ee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
48 deletions
+107
-48
test/test_metis.py
test/test_metis.py
+1
-1
torch_sparse/storage.py
torch_sparse/storage.py
+106
-47
No files found.
test/test_metis.py
View file @
b38d0b5e
...
@@ -9,7 +9,7 @@ from .utils import devices
...
@@ -9,7 +9,7 @@ from .utils import devices
try
:
try
:
rowptr
=
torch
.
tensor
([
0
,
1
])
rowptr
=
torch
.
tensor
([
0
,
1
])
col
=
torch
.
tensor
([
0
])
col
=
torch
.
tensor
([
0
])
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
None
,
1
)
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
None
,
1
,
True
)
with_metis
=
True
with_metis
=
True
except
RuntimeError
:
except
RuntimeError
:
with_metis
=
False
with_metis
=
False
...
...
torch_sparse/storage.py
View file @
b38d0b5e
...
@@ -30,19 +30,21 @@ class SparseStorage(object):
...
@@ -30,19 +30,21 @@ class SparseStorage(object):
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
def
__init__
(
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
self
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
col
:
Optional
[
torch
.
Tensor
]
=
None
,
Optional
[
int
]]]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
Optional
[
int
]]]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
csr2csc
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
csc2csr
:
Optional
[
torch
.
Tensor
]
=
None
,
csr2csc
:
Optional
[
torch
.
Tensor
]
=
None
,
is_sorted
:
bool
=
False
,
csc2csr
:
Optional
[
torch
.
Tensor
]
=
None
,
trust_data
:
bool
=
False
):
is_sorted
:
bool
=
False
,
trust_data
:
bool
=
False
,
):
assert
row
is
not
None
or
rowptr
is
not
None
assert
row
is
not
None
or
rowptr
is
not
None
assert
col
is
not
None
assert
col
is
not
None
...
@@ -240,7 +242,8 @@ class SparseStorage(object):
...
@@ -240,7 +242,8 @@ class SparseStorage(object):
csr2csc
=
self
.
_csr2csc
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
def
sparse_sizes
(
self
)
->
Tuple
[
int
,
int
]:
def
sparse_sizes
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
_sparse_sizes
return
self
.
_sparse_sizes
...
@@ -290,7 +293,8 @@ class SparseStorage(object):
...
@@ -290,7 +293,8 @@ class SparseStorage(object):
csr2csc
=
self
.
_csr2csc
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
def
sparse_reshape
(
self
,
num_rows
:
int
,
num_cols
:
int
):
def
sparse_reshape
(
self
,
num_rows
:
int
,
num_cols
:
int
):
assert
num_rows
>
0
or
num_rows
==
-
1
assert
num_rows
>
0
or
num_rows
==
-
1
...
@@ -313,10 +317,20 @@ class SparseStorage(object):
...
@@ -313,10 +317,20 @@ class SparseStorage(object):
col
=
idx
%
num_cols
col
=
idx
%
num_cols
assert
row
.
dtype
==
torch
.
long
and
col
.
dtype
==
torch
.
long
assert
row
.
dtype
==
torch
.
long
and
col
.
dtype
==
torch
.
long
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
self
.
_value
,
return
SparseStorage
(
sparse_sizes
=
(
num_rows
,
num_cols
),
rowcount
=
None
,
row
=
row
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
rowptr
=
None
,
csc2csr
=
None
,
is_sorted
=
True
,
trust_data
=
True
)
col
=
col
,
value
=
self
.
_value
,
sparse_sizes
=
(
num_rows
,
num_cols
),
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
has_rowcount
(
self
)
->
bool
:
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
return
self
.
_rowcount
is
not
None
...
@@ -413,10 +427,20 @@ class SparseStorage(object):
...
@@ -413,10 +427,20 @@ class SparseStorage(object):
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
row
=
row
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
rowptr
=
None
,
csc2csr
=
None
,
is_sorted
=
True
,
trust_data
=
True
)
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
fill_cache_
(
self
):
def
fill_cache_
(
self
):
self
.
row
()
self
.
row
()
...
@@ -466,7 +490,8 @@ class SparseStorage(object):
...
@@ -466,7 +490,8 @@ class SparseStorage(object):
csr2csc
=
self
.
_csr2csc
,
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
def
clone
(
self
):
def
clone
(
self
):
row
=
self
.
_row
row
=
self
.
_row
...
@@ -495,11 +520,20 @@ class SparseStorage(object):
...
@@ -495,11 +520,20 @@ class SparseStorage(object):
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
row
=
row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowptr
=
rowptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
col
=
col
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
type
(
self
,
dtype
:
torch
.
dtype
,
non_blocking
:
bool
=
False
):
def
type
(
self
,
dtype
:
torch
.
dtype
,
non_blocking
:
bool
=
False
):
value
=
self
.
_value
value
=
self
.
_value
...
@@ -508,9 +542,7 @@ class SparseStorage(object):
...
@@ -508,9 +542,7 @@ class SparseStorage(object):
return
self
return
self
else
:
else
:
return
self
.
set_value
(
return
self
.
set_value
(
value
.
to
(
value
.
to
(
dtype
=
dtype
,
non_blocking
=
non_blocking
),
dtype
=
dtype
,
non_blocking
=
non_blocking
),
layout
=
'coo'
)
layout
=
'coo'
)
else
:
else
:
return
self
return
self
...
@@ -548,11 +580,20 @@ class SparseStorage(object):
...
@@ -548,11 +580,20 @@ class SparseStorage(object):
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
to
(
device
,
non_blocking
=
non_blocking
)
csc2csr
=
csc2csr
.
to
(
device
,
non_blocking
=
non_blocking
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
row
=
row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowptr
=
rowptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
col
=
col
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
device_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
def
device_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
return
self
.
to_device
(
device
=
tensor
.
device
,
non_blocking
=
non_blocking
)
return
self
.
to_device
(
device
=
tensor
.
device
,
non_blocking
=
non_blocking
)
...
@@ -587,11 +628,20 @@ class SparseStorage(object):
...
@@ -587,11 +628,20 @@ class SparseStorage(object):
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
cuda
()
csc2csr
=
csc2csr
.
cuda
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
new_col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
row
=
row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowptr
=
rowptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
col
=
new_col
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
pin_memory
(
self
):
def
pin_memory
(
self
):
row
=
self
.
_row
row
=
self
.
_row
...
@@ -620,11 +670,20 @@ class SparseStorage(object):
...
@@ -620,11 +670,20 @@ class SparseStorage(object):
if
csc2csr
is
not
None
:
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
pin_memory
()
csc2csr
=
csc2csr
.
pin_memory
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
return
SparseStorage
(
sparse_sizes
=
self
.
_sparse_sizes
,
row
=
row
,
rowcount
=
rowcount
,
colptr
=
colptr
,
rowptr
=
rowptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
col
=
col
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
is_pinned
(
self
)
->
bool
:
def
is_pinned
(
self
)
->
bool
:
is_pinned
=
True
is_pinned
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment