Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
28f12953
"tests/vscode:/vscode.git/clone" did not exist on "405de769b7faabb19b00a176e0b6dca8a0df3581"
Unverified
Commit
28f12953
authored
Oct 18, 2021
by
Matthias Fey
Committed by
GitHub
Oct 18, 2021
Browse files
Merge pull request #176 from shi27feng/patch-1
Update storage.py
parents
23709f94
9b5d3c79
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
17 deletions
+27
-17
torch_sparse/storage.py
torch_sparse/storage.py
+23
-15
torch_sparse/tensor.py
torch_sparse/tensor.py
+4
-2
No files found.
torch_sparse/storage.py
View file @
28f12953
...
@@ -34,7 +34,8 @@ class SparseStorage(object):
...
@@ -34,7 +34,8 @@ class SparseStorage(object):
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
Optional
[
int
]]]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
rowcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colptr
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
colcount
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -48,26 +49,33 @@ class SparseStorage(object):
...
@@ -48,26 +49,33 @@ class SparseStorage(object):
assert
col
.
dim
()
==
1
assert
col
.
dim
()
==
1
col
=
col
.
contiguous
()
col
=
col
.
contiguous
()
if
sparse_sizes
is
None
:
M
:
int
=
0
if
sparse_sizes
is
None
or
sparse_sizes
[
0
]
is
None
:
if
rowptr
is
not
None
:
if
rowptr
is
not
None
:
M
=
rowptr
.
numel
()
-
1
M
=
rowptr
.
numel
()
-
1
elif
row
is
not
None
and
row
.
numel
()
>
0
:
elif
row
is
not
None
and
row
.
numel
()
>
0
:
M
=
row
.
max
().
item
()
+
1
M
=
int
(
row
.
max
())
+
1
elif
row
is
not
None
and
row
.
numel
()
==
0
:
M
=
0
else
:
else
:
raise
ValueError
_M
=
sparse_sizes
[
0
]
assert
_M
is
not
None
M
=
_M
if
rowptr
is
not
None
:
assert
rowptr
.
numel
()
-
1
==
M
elif
row
is
not
None
and
row
.
numel
()
>
0
:
assert
int
(
row
.
max
())
<
M
N
:
int
=
0
if
sparse_sizes
is
None
or
sparse_sizes
[
1
]
is
None
:
if
col
.
numel
()
>
0
:
if
col
.
numel
()
>
0
:
N
=
col
.
max
()
.
item
(
)
+
1
N
=
int
(
col
.
max
())
+
1
else
:
else
:
N
=
0
_N
=
sparse_sizes
[
1
]
sparse_sizes
=
(
int
(
M
),
int
(
N
))
assert
_N
is
not
None
else
:
N
=
_N
assert
len
(
sparse_sizes
)
==
2
if
row
is
not
None
and
row
.
numel
()
>
0
:
assert
row
.
max
().
item
()
<
sparse_sizes
[
0
]
if
col
.
numel
()
>
0
:
if
col
.
numel
()
>
0
:
assert
col
.
max
().
item
()
<
sparse_sizes
[
1
]
assert
int
(
col
.
max
())
<
N
sparse_sizes
=
(
M
,
N
)
if
row
is
not
None
:
if
row
is
not
None
:
assert
row
.
dtype
==
torch
.
long
assert
row
.
dtype
==
torch
.
long
...
...
torch_sparse/tensor.py
View file @
28f12953
...
@@ -16,7 +16,8 @@ class SparseTensor(object):
...
@@ -16,7 +16,8 @@ class SparseTensor(object):
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
rowptr
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
col
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
value
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
Optional
[
int
]]]
=
None
,
is_sorted
:
bool
=
False
):
is_sorted
:
bool
=
False
):
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
self
.
storage
=
SparseStorage
(
row
=
row
,
rowptr
=
rowptr
,
col
=
col
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
value
=
value
,
sparse_sizes
=
sparse_sizes
,
...
@@ -39,7 +40,8 @@ class SparseTensor(object):
...
@@ -39,7 +40,8 @@ class SparseTensor(object):
@
classmethod
@
classmethod
def
from_edge_index
(
self
,
edge_index
:
torch
.
Tensor
,
def
from_edge_index
(
self
,
edge_index
:
torch
.
Tensor
,
edge_attr
:
Optional
[
torch
.
Tensor
]
=
None
,
edge_attr
:
Optional
[
torch
.
Tensor
]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
int
,
int
]]
=
None
,
sparse_sizes
:
Optional
[
Tuple
[
Optional
[
int
],
Optional
[
int
]]]
=
None
,
is_sorted
:
bool
=
False
):
is_sorted
:
bool
=
False
):
return
SparseTensor
(
row
=
edge_index
[
0
],
rowptr
=
None
,
col
=
edge_index
[
1
],
return
SparseTensor
(
row
=
edge_index
[
0
],
rowptr
=
None
,
col
=
edge_index
[
1
],
value
=
edge_attr
,
sparse_sizes
=
sparse_sizes
,
value
=
edge_attr
,
sparse_sizes
=
sparse_sizes
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment