Commit 6244606f authored by rusty1s's avatar rusty1s
Browse files

additional dimension arg

parent e36a72ac
......@@ -153,6 +153,7 @@ Matrix product of a sparse matrix with a dense matrix.
* **index** *(LongTensor)* - The index tensor of sparse matrix.
* **value** *(Tensor)* - The value tensor of sparse matrix.
* **m** *(int)* - The first dimension of sparse matrix.
* **n** *(int)* - The second dimension of sparse matrix.
* **matrix** *(Tensor)* - The dense matrix.
### Returns
......@@ -169,7 +170,7 @@ index = torch.tensor([[0, 0, 1, 2, 2],
value = torch.Tensor([1, 2, 4, 1, 3])
matrix = torch.Tensor([[1, 4], [2, 5], [3, 6]])
out = spmm(index, value, 3, matrix)
out = spmm(index, value, 3, 3, matrix)
```
```
......
......@@ -15,5 +15,5 @@ def test_spmm(dtype, device):
value = tensor([1, 2, 4, 1, 3], dtype, device)
x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
out = spmm(index, value, 3, x)
out = spmm(index, value, 3, 3, x)
assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
......@@ -18,5 +18,5 @@ def test_spmm_spspmm(dtype, device):
value = value.requires_grad_(True)
out_index, out_value = spspmm(index, value, index, value, 3, 3, 3)
out = spmm(out_index, out_value, 3, x)
out = spmm(out_index, out_value, 3, 3, x)
assert out.size() == (3, 2)
from torch_scatter import scatter_add
def spmm(index, value, m, matrix):
def spmm(index, value, m, n, matrix):
"""Matrix product of sparse matrix with dense matrix.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix.
matrix (:class:`Tensor`): The dense matrix.
:rtype: :class:`Tensor`
"""
assert n == matrix.size(0)
row, col = index
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment