Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
3679da2b
Commit
3679da2b
authored
Jul 17, 2020
by
rusty1s
Browse files
remove warnings filter
parent
8c7560b0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
11 deletions
+10
-11
torch_sparse/storage.py
torch_sparse/storage.py
+6
-7
torch_sparse/tensor.py
torch_sparse/tensor.py
+4
-4
No files found.
torch_sparse/storage.py
View file @
3679da2b
...
@@ -7,9 +7,6 @@ from torch_sparse.utils import Final
...
@@ -7,9 +7,6 @@ from torch_sparse.utils import Final
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
layouts
:
Final
[
List
[
str
]]
=
[
'coo'
,
'csr'
,
'csc'
]
# FIXME: Remove once `/` on `LongTensors` is officially removed from PyTorch.
warnings
.
filterwarnings
(
"ignore"
,
category
=
UserWarning
)
def
get_layout
(
layout
:
Optional
[
str
]
=
None
)
->
str
:
def
get_layout
(
layout
:
Optional
[
str
]
=
None
)
->
str
:
if
layout
is
None
:
if
layout
is
None
:
...
@@ -130,7 +127,9 @@ class SparseStorage(object):
...
@@ -130,7 +127,9 @@ class SparseStorage(object):
if
not
is_sorted
:
if
not
is_sorted
:
idx
=
self
.
_col
.
new_zeros
(
self
.
_col
.
numel
()
+
1
)
idx
=
self
.
_col
.
new_zeros
(
self
.
_col
.
numel
()
+
1
)
idx
[
1
:]
=
self
.
_sparse_sizes
[
1
]
*
self
.
row
()
+
self
.
_col
idx
[
1
:]
=
self
.
row
()
idx
[
1
:]
*=
self
.
_sparse_sizes
[
1
]
idx
[
1
:]
+=
self
.
_col
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
if
(
idx
[
1
:]
<
idx
[:
-
1
]).
any
():
perm
=
idx
[
1
:].
argsort
()
perm
=
idx
[
1
:].
argsort
()
self
.
_row
=
self
.
row
()[
perm
]
self
.
_row
=
self
.
row
()[
perm
]
...
@@ -238,7 +237,7 @@ class SparseStorage(object):
...
@@ -238,7 +237,7 @@ class SparseStorage(object):
rowptr
=
torch
.
cat
([
rowptr
,
rowptr
.
new_full
((
diff_0
,
),
nnz
)])
rowptr
=
torch
.
cat
([
rowptr
,
rowptr
.
new_full
((
diff_0
,
),
nnz
)])
if
rowcount
is
not
None
:
if
rowcount
is
not
None
:
rowcount
=
torch
.
cat
([
rowcount
,
rowcount
.
new_zeros
(
diff_0
)])
rowcount
=
torch
.
cat
([
rowcount
,
rowcount
.
new_zeros
(
diff_0
)])
el
se
:
el
if
diff_0
<
0
:
if
rowptr
is
not
None
:
if
rowptr
is
not
None
:
rowptr
=
rowptr
[:
-
diff_0
]
rowptr
=
rowptr
[:
-
diff_0
]
if
rowcount
is
not
None
:
if
rowcount
is
not
None
:
...
@@ -251,7 +250,7 @@ class SparseStorage(object):
...
@@ -251,7 +250,7 @@ class SparseStorage(object):
colptr
=
torch
.
cat
([
colptr
,
colptr
.
new_full
((
diff_1
,
),
nnz
)])
colptr
=
torch
.
cat
([
colptr
,
colptr
.
new_full
((
diff_1
,
),
nnz
)])
if
colcount
is
not
None
:
if
colcount
is
not
None
:
colcount
=
torch
.
cat
([
colcount
,
colcount
.
new_zeros
(
diff_1
)])
colcount
=
torch
.
cat
([
colcount
,
colcount
.
new_zeros
(
diff_1
)])
el
se
:
el
if
diff_1
<
0
:
if
colptr
is
not
None
:
if
colptr
is
not
None
:
colptr
=
colptr
[:
-
diff_1
]
colptr
=
colptr
[:
-
diff_1
]
if
colcount
is
not
None
:
if
colcount
is
not
None
:
...
@@ -280,7 +279,7 @@ class SparseStorage(object):
...
@@ -280,7 +279,7 @@ class SparseStorage(object):
idx
=
self
.
sparse_size
(
1
)
*
self
.
row
()
+
self
.
col
()
idx
=
self
.
sparse_size
(
1
)
*
self
.
row
()
+
self
.
col
()
row
=
idx
/
num_cols
row
=
idx
/
/
num_cols
col
=
idx
%
num_cols
col
=
idx
%
num_cols
assert
row
.
dtype
==
torch
.
long
and
col
.
dtype
==
torch
.
long
assert
row
.
dtype
==
torch
.
long
and
col
.
dtype
==
torch
.
long
...
...
torch_sparse/tensor.py
View file @
3679da2b
...
@@ -397,8 +397,8 @@ class SparseTensor(object):
...
@@ -397,8 +397,8 @@ class SparseTensor(object):
return
mat
return
mat
def
to_torch_sparse_coo_tensor
(
def
to_torch_sparse_coo_tensor
(
self
,
dtype
:
Optional
[
int
]
=
None
self
,
dtype
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
row
,
col
,
value
=
self
.
coo
()
row
,
col
,
value
=
self
.
coo
()
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
...
@@ -503,8 +503,8 @@ SparseTensor.__repr__ = __repr__
...
@@ -503,8 +503,8 @@ SparseTensor.__repr__ = __repr__
# Scipy Conversions ###########################################################
# Scipy Conversions ###########################################################
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
csr_matrix
,
ScipySparseMatrix
=
Union
[
scipy
.
sparse
.
coo_matrix
,
scipy
.
sparse
.
scipy
.
sparse
.
csc_matrix
]
csr_matrix
,
scipy
.
sparse
.
csc_matrix
]
@
torch
.
jit
.
ignore
@
torch
.
jit
.
ignore
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment