Commit 6244606f authored by rusty1s's avatar rusty1s
Browse files

additional dimension arg

parent e36a72ac
...@@ -153,6 +153,7 @@ Matrix product of a sparse matrix with a dense matrix. ...@@ -153,6 +153,7 @@ Matrix product of a sparse matrix with a dense matrix.
* **index** *(LongTensor)* - The index tensor of sparse matrix. * **index** *(LongTensor)* - The index tensor of sparse matrix.
* **value** *(Tensor)* - The value tensor of sparse matrix. * **value** *(Tensor)* - The value tensor of sparse matrix.
* **m** *(int)* - The first dimension of sparse matrix. * **m** *(int)* - The first dimension of sparse matrix.
* **n** *(int)* - The second dimension of sparse matrix.
* **matrix** *(Tensor)* - The dense matrix. * **matrix** *(Tensor)* - The dense matrix.
### Returns ### Returns
...@@ -169,7 +170,7 @@ index = torch.tensor([[0, 0, 1, 2, 2], ...@@ -169,7 +170,7 @@ index = torch.tensor([[0, 0, 1, 2, 2],
value = torch.Tensor([1, 2, 4, 1, 3]) value = torch.Tensor([1, 2, 4, 1, 3])
matrix = torch.Tensor([[1, 4], [2, 5], [3, 6]]) matrix = torch.Tensor([[1, 4], [2, 5], [3, 6]])
out = spmm(index, value, 3, matrix) out = spmm(index, value, 3, 3, matrix)
``` ```
``` ```
......
...@@ -15,5 +15,5 @@ def test_spmm(dtype, device): ...@@ -15,5 +15,5 @@ def test_spmm(dtype, device):
value = tensor([1, 2, 4, 1, 3], dtype, device) value = tensor([1, 2, 4, 1, 3], dtype, device)
x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device) x = tensor([[1, 4], [2, 5], [3, 6]], dtype, device)
out = spmm(index, value, 3, x) out = spmm(index, value, 3, 3, x)
assert out.tolist() == [[7, 16], [8, 20], [7, 19]] assert out.tolist() == [[7, 16], [8, 20], [7, 19]]
...@@ -18,5 +18,5 @@ def test_spmm_spspmm(dtype, device): ...@@ -18,5 +18,5 @@ def test_spmm_spspmm(dtype, device):
value = value.requires_grad_(True) value = value.requires_grad_(True)
out_index, out_value = spspmm(index, value, index, value, 3, 3, 3) out_index, out_value = spspmm(index, value, index, value, 3, 3, 3)
out = spmm(out_index, out_value, 3, x) out = spmm(out_index, out_value, 3, 3, x)
assert out.size() == (3, 2) assert out.size() == (3, 2)
from torch_scatter import scatter_add from torch_scatter import scatter_add
def spmm(index, value, m, matrix): def spmm(index, value, m, n, matrix):
"""Matrix product of sparse matrix with dense matrix. """Matrix product of sparse matrix with dense matrix.
Args: Args:
index (:class:`LongTensor`): The index tensor of sparse matrix. index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix. value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix. m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix.
matrix (:class:`Tensor`): The dense matrix. matrix (:class:`Tensor`): The dense matrix.
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
assert n == matrix.size(0)
row, col = index row, col = index
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment