Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
6244606f
Commit
6244606f
authored
Jun 16, 2019
by
rusty1s
Browse files
additional dimension arg
parent
e36a72ac
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
4 deletions
+8
-4
README.md
README.md
+2
-1
test/test_spmm.py
test/test_spmm.py
+1
-1
test/test_spspmm_spmm.py
test/test_spspmm_spmm.py
+1
-1
torch_sparse/spmm.py
torch_sparse/spmm.py
+4
-1
No files found.
README.md
View file @
6244606f
...
...
@@ -153,6 +153,7 @@ Matrix product of a sparse matrix with a dense matrix.
*
**index**
*(LongTensor)*
- The index tensor of sparse matrix.
*
**value**
*(Tensor)*
- The value tensor of sparse matrix.
*
**m**
*(int)*
- The first dimension of sparse matrix.
*
**n**
*(int)*
- The second dimension of sparse matrix.
*
**matrix**
*(Tensor)*
- The dense matrix.
### Returns
...
...
@@ -169,7 +170,7 @@ index = torch.tensor([[0, 0, 1, 2, 2],
value
=
torch
.
Tensor
([
1
,
2
,
4
,
1
,
3
])
matrix
=
torch
.
Tensor
([[
1
,
4
],
[
2
,
5
],
[
3
,
6
]])
out
=
spmm
(
index
,
value
,
3
,
matrix
)
out
=
spmm
(
index
,
value
,
3
,
3
,
matrix
)
```
```
...
...
test/test_spmm.py
View file @
6244606f
...
...
@@ -15,5 +15,5 @@ def test_spmm(dtype, device):
value
=
tensor
([
1
,
2
,
4
,
1
,
3
],
dtype
,
device
)
x
=
tensor
([[
1
,
4
],
[
2
,
5
],
[
3
,
6
]],
dtype
,
device
)
out
=
spmm
(
index
,
value
,
3
,
x
)
out
=
spmm
(
index
,
value
,
3
,
3
,
x
)
assert
out
.
tolist
()
==
[[
7
,
16
],
[
8
,
20
],
[
7
,
19
]]
test/test_spspmm_spmm.py
View file @
6244606f
...
...
@@ -18,5 +18,5 @@ def test_spmm_spspmm(dtype, device):
value
=
value
.
requires_grad_
(
True
)
out_index
,
out_value
=
spspmm
(
index
,
value
,
index
,
value
,
3
,
3
,
3
)
out
=
spmm
(
out_index
,
out_value
,
3
,
x
)
out
=
spmm
(
out_index
,
out_value
,
3
,
3
,
x
)
assert
out
.
size
()
==
(
3
,
2
)
torch_sparse/spmm.py
View file @
6244606f
from
torch_scatter
import
scatter_add
def
spmm
(
index
,
value
,
m
,
matrix
):
def
spmm
(
index
,
value
,
m
,
n
,
matrix
):
"""Matrix product of sparse matrix with dense matrix.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix.
matrix (:class:`Tensor`): The dense matrix.
:rtype: :class:`Tensor`
"""
assert
n
==
matrix
.
size
(
0
)
row
,
col
=
index
matrix
=
matrix
if
matrix
.
dim
()
>
1
else
matrix
.
unsqueeze
(
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment