Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
ee2d323d
Commit
ee2d323d
authored
Mar 11, 2019
by
rusty1s
Browse files
fixed 'index derivative is not defined' message
parent
c1cd9753
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
7 deletions
+36
-7
test/test_spmm.py
test/test_spmm.py
+12
-6
test/test_spspmm_spmm.py
test/test_spspmm_spmm.py
+22
-0
torch_sparse/spspmm.py
torch_sparse/spspmm.py
+2
-1
No files found.
test/test_spmm.py
View file @
ee2d323d
from
itertools
import
product
import
pytest
import
torch
from
torch_sparse
import
spmm
from
.utils
import
dtypes
,
devices
,
tensor
def
test_spmm
():
row
=
torch
.
tensor
([
0
,
0
,
1
,
2
,
2
])
col
=
torch
.
tensor
([
0
,
2
,
1
,
0
,
1
])
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_spmm
(
dtype
,
device
):
row
=
torch
.
tensor
([
0
,
0
,
1
,
2
,
2
],
device
=
device
)
col
=
torch
.
tensor
([
0
,
2
,
1
,
0
,
1
],
device
=
device
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
value
=
torch
.
tensor
([
1
,
2
,
4
,
1
,
3
])
value
=
tensor
([
1
,
2
,
4
,
1
,
3
],
dtype
,
device
)
x
=
tensor
([[
1
,
4
],
[
2
,
5
],
[
3
,
6
]],
dtype
,
device
)
matrix
=
torch
.
tensor
([[
1
,
4
],
[
2
,
5
],
[
3
,
6
]])
out
=
spmm
(
index
,
value
,
3
,
matrix
)
out
=
spmm
(
index
,
value
,
3
,
x
)
assert
out
.
tolist
()
==
[[
7
,
16
],
[
8
,
20
],
[
7
,
19
]]
test/test_spspmm_spmm.py
0 → 100644
View file @
ee2d323d
from
itertools
import
product
import
pytest
import
torch
from
torch_sparse
import
spspmm
,
spmm
from
.utils
import
dtypes
,
devices
,
tensor
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_spmm_spspmm
(
dtype
,
device
):
row
=
torch
.
tensor
([
0
,
0
,
1
,
2
,
2
],
device
=
device
)
col
=
torch
.
tensor
([
0
,
2
,
1
,
0
,
1
],
device
=
device
)
index
=
torch
.
stack
([
row
,
col
],
dim
=
0
)
value
=
tensor
([
1
,
2
,
4
,
1
,
3
],
dtype
,
device
)
x
=
tensor
([[
1
,
4
],
[
2
,
5
],
[
3
,
6
]],
dtype
,
device
)
value
=
value
.
requires_grad_
(
True
)
out_index
,
out_value
=
spspmm
(
index
,
value
,
index
,
value
,
3
,
3
,
3
)
out
=
spmm
(
out_index
,
out_value
,
3
,
x
)
assert
out
.
size
()
==
(
3
,
2
)
torch_sparse/spspmm.py
View file @
ee2d323d
...
...
@@ -23,7 +23,8 @@ def spspmm(indexA, valueA, indexB, valueB, m, k, n):
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
return
SpSpMM
.
apply
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
)
index
,
value
=
SpSpMM
.
apply
(
indexA
,
valueA
,
indexB
,
valueB
,
m
,
k
,
n
)
return
index
.
detach
(),
value
class
SpSpMM
(
torch
.
autograd
.
Function
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment