Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
2317ff66
"...text-generation-inference.git" did not exist on "90b226db291769a45ecbccaa4f7384bc6b9bff8a"
Commit
2317ff66
authored
May 09, 2020
by
Mario Geiger
Browse files
test_spspmm_2
parent
9f034684
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
1 deletion
+33
-1
test/test_spspmm.py
test/test_spspmm.py
+33
-1
No files found.
test/test_spspmm.py
View file @
2317ff66
...
@@ -2,7 +2,7 @@ from itertools import product
...
@@ -2,7 +2,7 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch_sparse
import
spspmm
,
SparseTensor
from
torch_sparse
import
spspmm
,
SparseTensor
,
transpose
from
.utils
import
grad_dtypes
,
devices
,
tensor
from
.utils
import
grad_dtypes
,
devices
,
tensor
...
@@ -19,6 +19,38 @@ def test_spspmm(dtype, device):
...
@@ -19,6 +19,38 @@ def test_spspmm(dtype, device):
assert
valueC
.
tolist
()
==
[
8
,
6
,
8
]
assert
valueC
.
tolist
()
==
[
8
,
6
,
8
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_spspmm_2
(
dtype
,
device
):
row
=
torch
.
tensor
(
[
0
,
1
,
1
,
1
,
2
,
3
,
4
,
5
,
5
,
6
,
6
,
7
,
7
,
7
,
8
,
8
,
9
,
9
],
device
=
device
)
col
=
torch
.
tensor
(
[
0
,
5
,
10
,
15
,
1
,
2
,
3
,
7
,
13
,
6
,
9
,
5
,
10
,
15
,
11
,
14
,
5
,
15
],
device
=
device
)
value
=
torch
.
tensor
(
[
1
,
3
**-
0.5
,
3
**-
0.5
,
3
**-
0.5
,
1
,
1
,
1
,
-
2
**-
0.5
,
-
2
**-
0.5
,
-
2
**-
0.5
,
-
2
**-
0.5
,
6
**-
0.5
,
-
6
**
0.5
/
3
,
6
**-
0.5
,
-
2
**-
0.5
,
-
2
**-
0.5
,
2
**-
0.5
,
-
2
**-
0.5
],
dtype
=
dtype
,
device
=
device
)
index
=
torch
.
stack
([
row
,
col
])
m
=
value
.
new_zeros
(
10
,
16
)
m
[
index
[
0
],
index
[
1
]]
=
value
index_t
,
value_t
=
transpose
(
index
,
value
,
10
,
16
)
index
,
value
=
spspmm
(
index
,
value
,
index_t
,
value_t
,
10
,
16
,
10
)
mask
=
value
.
abs
()
>
1e-4
index
,
value
=
index
[:,
mask
],
value
[
mask
]
assert
index
.
tolist
()
==
[[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
],
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]]
assert
value
.
tolist
()
==
[
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
def
test_sparse_tensor_spspmm
(
dtype
,
device
):
def
test_sparse_tensor_spspmm
(
dtype
,
device
):
x
=
SparseTensor
(
x
=
SparseTensor
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment