Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
2593c925
Unverified
Commit
2593c925
authored
Aug 17, 2023
by
Andrei Ivanov
Committed by
GitHub
Aug 18, 2023
Browse files
Improving sparse tests. (#6168)
Co-authored-by:
Hongzhi (Steve), Chen
<
chenhongzhi.nkcs@gmail.com
>
parent
5a417414
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
6 deletions
+18
-6
tests/python/pytorch/sparse/test_matmul.py
tests/python/pytorch/sparse/test_matmul.py
+10
-4
tests/python/pytorch/sparse/test_sparse_matrix.py
tests/python/pytorch/sparse/test_sparse_matrix.py
+8
-2
No files found.
tests/python/pytorch/sparse/test_matmul.py
View file @
2593c925
import
sy
s
import
warning
s
import
backend
as
F
import
pytest
...
...
@@ -19,6 +19,12 @@ from .utils import (
)
def
_torch_sparse_mm
(
torch_A1
,
torch_A2
):
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
return
torch
.
sparse
.
mm
(
torch_A1
,
torch_A2
)
@
pytest
.
mark
.
parametrize
(
"create_func"
,
[
rand_coo
,
rand_csr
,
rand_csc
])
@
pytest
.
mark
.
parametrize
(
"shape"
,
[(
2
,
7
),
(
5
,
2
)])
@
pytest
.
mark
.
parametrize
(
"nnz"
,
[
1
,
10
])
...
...
@@ -98,7 +104,7 @@ def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
torch_A1
=
sparse_matrix_to_torch_sparse
(
A1
)
torch_A2
=
sparse_matrix_to_torch_sparse
(
A2
)
torch_A3
=
torch
.
sparse
.
mm
(
torch_A1
,
torch_A2
)
torch_A3
=
_
torch
_
sparse
_
mm
(
torch_A1
,
torch_A2
)
torch_A3_grad
=
sparse_matrix_to_torch_sparse
(
A3
,
grad
)
torch_A3
.
backward
(
torch_A3_grad
)
...
...
@@ -161,7 +167,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz):
torch_A
=
sparse_matrix_to_torch_sparse
(
A
)
torch_D
=
sparse_matrix_to_torch_sparse
(
D
)
torch_B
=
torch
.
sparse
.
mm
(
torch_A
,
torch_D
)
torch_B
=
_
torch
_
sparse
_
mm
(
torch_A
,
torch_D
)
torch_B_grad
=
sparse_matrix_to_torch_sparse
(
B
,
grad
)
torch_B
.
backward
(
torch_B_grad
)
...
...
@@ -194,7 +200,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz):
torch_A
=
sparse_matrix_to_torch_sparse
(
A
)
torch_D
=
sparse_matrix_to_torch_sparse
(
D
)
torch_B
=
torch
.
sparse
.
mm
(
torch_D
,
torch_A
)
torch_B
=
_
torch
_
sparse
_
mm
(
torch_D
,
torch_A
)
torch_B_grad
=
sparse_matrix_to_torch_sparse
(
B
,
grad
)
torch_B
.
backward
(
torch_B_grad
)
...
...
tests/python/pytorch/sparse/test_sparse_matrix.py
View file @
2593c925
import
sys
import
unittest
import
warnings
import
backend
as
F
import
pytest
...
...
@@ -19,6 +19,12 @@ from dgl.sparse import (
)
def
_torch_sparse_csr_tensor
(
indptr
,
indices
,
val
,
torch_sparse_shape
):
with
warnings
.
catch_warnings
():
warnings
.
simplefilter
(
"ignore"
,
category
=
UserWarning
)
return
torch
.
sparse_csr_tensor
(
indptr
,
indices
,
val
,
torch_sparse_shape
)
@
pytest
.
mark
.
parametrize
(
"dense_dim"
,
[
None
,
4
])
@
pytest
.
mark
.
parametrize
(
"row"
,
[(
0
,
0
,
1
,
2
),
(
0
,
1
,
2
,
4
)])
@
pytest
.
mark
.
parametrize
(
"col"
,
[(
0
,
1
,
2
,
2
),
(
1
,
3
,
3
,
4
)])
...
...
@@ -580,7 +586,7 @@ def test_torch_sparse_csr_conversion(indptr, indices, shape):
torch_sparse_shape
=
shape
val_shape
=
(
indices
.
shape
[
0
],)
val
=
torch
.
randn
(
val_shape
).
to
(
dev
)
torch_sparse_csr
=
torch
.
sparse_csr_tensor
(
torch_sparse_csr
=
_
torch
_
sparse_csr_tensor
(
indptr
,
indices
,
val
,
torch_sparse_shape
)
spmat
=
from_torch_sparse
(
torch_sparse_csr
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment