Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
b38d0b5e
Commit
b38d0b5e
authored
Mar 12, 2022
by
rusty1s
Browse files
fix test
parent
c4f318ee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
48 deletions
+107
-48
test/test_metis.py
test/test_metis.py
+1
-1
torch_sparse/storage.py
torch_sparse/storage.py
+106
-47
No files found.
test/test_metis.py
View file @
b38d0b5e
...
...
@@ -9,7 +9,7 @@ from .utils import devices
try
:
rowptr
=
torch
.
tensor
([
0
,
1
])
col
=
torch
.
tensor
([
0
])
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
None
,
1
)
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
None
,
1
,
True
)
with_metis
=
True
except
RuntimeError
:
with_metis
=
False
...
...
torch_sparse/storage.py
View file @
b38d0b5e
...
...
@@ -30,19 +30,21 @@ class SparseStorage(object):
_csr2csc
:
Optional
[
torch
.
Tensor
]
_csc2csr
:
Optional
[
torch
.
Tensor
]
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
Optional
[
int
]]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
csr2csc
:
Optional
[
torch
.
Tensor
]
=
None
,
csc2csr
:
Optional
[
torch
.
Tensor
]
=
None
,
is_sorted
:
bool
=
False
,
trust_data
:
bool
=
False
):
def
__init__
(
self
,
row
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
Optional
[
int
]]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
csr2csc
:
Optional
[
torch
.
Tensor
]
=
None
,
csc2csr
:
Optional
[
torch
.
Tensor
]
=
None
,
is_sorted
:
bool
=
False
,
trust_data
:
bool
=
False
,
):
assert
row
is
not
None
or
rowptr
is
not
None
assert
col
is
not
None
...
...
@@ -240,7 +242,8 @@ class SparseStorage(object):
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
def
sparse_sizes
(
self
)
->
Tuple
[
int
,
int
]:
return
self
.
_sparse_sizes
...
...
@@ -290,7 +293,8 @@ class SparseStorage(object):
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
def
sparse_reshape
(
self
,
num_rows
:
int
,
num_cols
:
int
):
assert
num_rows
>
0
or
num_rows
==
-
1
...
...
@@ -313,10 +317,20 @@ class SparseStorage(object):
col
=
idx
%
num_cols
assert
row
.
dtype
==
torch
.
long
and
col
.
dtype
==
torch
.
long
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
self
.
_value
,
sparse_sizes
=
(
num_rows
,
num_cols
),
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
,
trust_data
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
self
.
_value
,
sparse_sizes
=
(
num_rows
,
num_cols
),
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
has_rowcount
(
self
)
->
bool
:
return
self
.
_rowcount
is
not
None
...
...
@@ -413,10 +427,20 @@ class SparseStorage(object):
ptr
=
torch
.
cat
([
ptr
,
ptr
.
new_full
((
1
,
),
value
.
size
(
0
))])
value
=
segment_csr
(
value
,
ptr
,
reduce
=
reduce
)
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
,
trust_data
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
None
,
colptr
=
None
,
colcount
=
None
,
csr2csc
=
None
,
csc2csr
=
None
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
fill_cache_
(
self
):
self
.
row
()
...
...
@@ -466,7 +490,8 @@ class SparseStorage(object):
csr2csc
=
self
.
_csr2csc
,
csc2csr
=
self
.
_csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
trust_data
=
True
,
)
def
clone
(
self
):
row
=
self
.
_row
...
...
@@ -495,11 +520,20 @@ class SparseStorage(object):
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
clone
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
type
(
self
,
dtype
:
torch
.
dtype
,
non_blocking
:
bool
=
False
):
value
=
self
.
_value
...
...
@@ -508,9 +542,7 @@ class SparseStorage(object):
return
self
else
:
return
self
.
set_value
(
value
.
to
(
dtype
=
dtype
,
non_blocking
=
non_blocking
),
value
.
to
(
dtype
=
dtype
,
non_blocking
=
non_blocking
),
layout
=
'coo'
)
else
:
return
self
...
...
@@ -548,11 +580,20 @@ class SparseStorage(object):
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
to
(
device
,
non_blocking
=
non_blocking
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
device_as
(
self
,
tensor
:
torch
.
Tensor
,
non_blocking
:
bool
=
False
):
return
self
.
to_device
(
device
=
tensor
.
device
,
non_blocking
=
non_blocking
)
...
...
@@ -587,11 +628,20 @@ class SparseStorage(object):
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
cuda
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
new_col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
new_col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
pin_memory
(
self
):
row
=
self
.
_row
...
...
@@ -620,11 +670,20 @@ class SparseStorage(object):
if
csc2csr
is
not
None
:
csc2csr
=
csc2csr
.
pin_memory
()
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
)
return
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
self
.
_sparse_sizes
,
rowcount
=
rowcount
,
colptr
=
colptr
,
colcount
=
colcount
,
csr2csc
=
csr2csc
,
csc2csr
=
csc2csr
,
is_sorted
=
True
,
trust_data
=
True
,
)
def
is_pinned
(
self
)
->
bool
:
is_pinned
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment