Commit 01f55729 authored by rusty1s's avatar rusty1s
Browse files

docs

parent 46a9c9ab
...@@ -63,15 +63,15 @@ For scattering, any operation of [`torch_scatter`](https://github.com/rusty1s/py ...@@ -63,15 +63,15 @@ For scattering, any operation of [`torch_scatter`](https://github.com/rusty1s/py
* **index** *(LongTensor)* - The index tensor of sparse matrix. * **index** *(LongTensor)* - The index tensor of sparse matrix.
* **value** *(Tensor)* - The value tensor of sparse matrix. * **value** *(Tensor)* - The value tensor of sparse matrix.
* **m** *(int)* - First dimension of sparse matrix. * **m** *(int)* - The first dimension of sparse matrix.
* **n** *(int)* - Second dimension of sparse matrix. * **n** *(int)* - The second dimension of sparse matrix.
* **op** *(string, optional)* - Scatter operation to use. (default: `"add"`) * **op** *(string, optional)* - The scatter operation to use. (default: `"add"`)
* **fill_value** *(int, optional)* - Initial fill value of scatter operation. (default: `0`) * **fill_value** *(int, optional)* - The initial fill value of scatter operation. (default: `0`)
### Returns ### Returns
* **index** *(LongTensor)* - Coalesced index tensor of sparse matrix. * **index** *(LongTensor)* - The coalesced index tensor of sparse matrix.
* **value** *(Tensor)* - Coalesced value tensor of sparse matrix. * **value** *(Tensor)* - The coalesced value tensor of sparse matrix.
### Example ### Example
...@@ -105,13 +105,13 @@ Transposes dimensions 0 and 1 of a sparse matrix. ...@@ -105,13 +105,13 @@ Transposes dimensions 0 and 1 of a sparse matrix.
* **index** *(LongTensor)* - The index tensor of sparse matrix. * **index** *(LongTensor)* - The index tensor of sparse matrix.
* **value** *(Tensor)* - The value tensor of sparse matrix. * **value** *(Tensor)* - The value tensor of sparse matrix.
* **m** *(int)* - First dimension of sparse matrix. * **m** *(int)* - The first dimension of sparse matrix.
* **n** *(int)* - Second dimension of sparse matrix. * **n** *(int)* - The second dimension of sparse matrix.
### Returns ### Returns
* **index** *(LongTensor)* - Transposed index tensor of sparse matrix. * **index** *(LongTensor)* - The transposed index tensor of sparse matrix.
* **value** *(Tensor)* - Transposed value tensor of sparse matrix. * **value** *(Tensor)* - The transposed value tensor of sparse matrix.
### Example ### Example
...@@ -122,7 +122,7 @@ index = torch.tensor([[1, 0, 1, 0, 2, 1], ...@@ -122,7 +122,7 @@ index = torch.tensor([[1, 0, 1, 0, 2, 1],
[0, 1, 1, 1, 0, 0]]) [0, 1, 1, 1, 0, 0]])
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]]) value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
index, value = transpose(index, value, m=3, n=2) index, value = transpose(index, value, 3, 2)
``` ```
``` ```
...@@ -148,12 +148,12 @@ Matrix product of a sparse matrix with a dense matrix. ...@@ -148,12 +148,12 @@ Matrix product of a sparse matrix with a dense matrix.
* **index** *(LongTensor)* - The index tensor of sparse matrix. * **index** *(LongTensor)* - The index tensor of sparse matrix.
* **value** *(Tensor)* - The value tensor of sparse matrix. * **value** *(Tensor)* - The value tensor of sparse matrix.
* **m** *(int)* - First dimension of sparse matrix. * **m** *(int)* - The first dimension of sparse matrix.
* **matrix** *(int)* - Dense matrix. * **matrix** *(Tensor)* - The dense matrix.
### Returns ### Returns
* **out** *(Tensor)* - Dense output matrix. * **out** *(Tensor)* - The dense output matrix.
### Example ### Example
...@@ -190,14 +190,14 @@ Both input sparse matrices need to be **coalesced**. ...@@ -190,14 +190,14 @@ Both input sparse matrices need to be **coalesced**.
* **valueA** *(Tensor)* - The value tensor of first sparse matrix. * **valueA** *(Tensor)* - The value tensor of first sparse matrix.
* **indexB** *(LongTensor)* - The index tensor of second sparse matrix. * **indexB** *(LongTensor)* - The index tensor of second sparse matrix.
* **valueB** *(Tensor)* - The value tensor of second sparse matrix. * **valueB** *(Tensor)* - The value tensor of second sparse matrix.
* **m** *(int)* - First dimension of first sparse matrix. * **m** *(int)* - The first dimension of first sparse matrix.
* **k** *(int)* - Second dimension of first sparse matrix and first dimension of second sparse matrix. * **k** *(int)* - The second dimension of first sparse matrix and first dimension of second sparse matrix.
* **n** *(int)* - Second dimension of second sparse matrix. * **n** *(int)* - The second dimension of second sparse matrix.
### Returns ### Returns
* **index** *(LongTensor)* - Output index tensor of sparse matrix. * **index** *(LongTensor)* - The output index tensor of sparse matrix.
* **value** *(Tensor)* - Output value tensor of sparse matrix. * **value** *(Tensor)* - The output value tensor of sparse matrix.
### Example ### Example
......
...@@ -3,7 +3,23 @@ import torch_scatter ...@@ -3,7 +3,23 @@ import torch_scatter
def coalesce(index, value, m, n, op='add', fill_value=0): def coalesce(index, value, m, n, op='add', fill_value=0):
"""Row-wise reorders and removes duplicate entries in sparse matrix.""" """Row-wise sorts :obj:`value` and removes duplicate entries. Duplicate
entries are removed by scattering them together. For scattering, any
operation of `"torch_scatter"<https://github.com/rusty1s/pytorch_scatter>`_
can be used.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix.
op (string, optional): The scatter operation to use. (default:
:obj:`"add"`)
fill_value (int, optional): The initial fill value of scatter
operation. (default: :obj:`0`)
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
row, col = index row, col = index
......
...@@ -2,7 +2,16 @@ from torch_scatter import scatter_add ...@@ -2,7 +2,16 @@ from torch_scatter import scatter_add
def spmm(index, value, m, matrix): def spmm(index, value, m, matrix):
"""Matrix product of sparse matrix with dense matrix.""" """Matrix product of sparse matrix with dense matrix.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix.
matrix (:class:`Tensor`): The dense matrix.
:rtype: :class:`Tensor`
"""
row, col = index row, col = index
matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1) matrix = matrix if matrix.dim() > 1 else matrix.unsqueeze(-1)
......
...@@ -8,7 +8,21 @@ if torch.cuda.is_available(): ...@@ -8,7 +8,21 @@ if torch.cuda.is_available():
class SpSpMM(torch.autograd.Function): class SpSpMM(torch.autograd.Function):
"""Sparse matrix product of two sparse matrices with autograd support.""" """Matrix product of two sparse tensors. Both input sparse matrices need to
be coalesced.
Args:
indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
valueA (:class:`Tensor`): The value tensor of first sparse matrix.
indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
valueB (:class:`Tensor`): The value tensor of second sparse matrix.
m (int): The first dimension of first sparse matrix.
k (int): The second dimension of first sparse matrix and first
dimension of second sparse matrix.
n (int): The second dimension of second sparse matrix.
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
@staticmethod @staticmethod
def forward(ctx, indexA, valueA, indexB, valueB, m, k, n): def forward(ctx, indexA, valueA, indexB, valueB, m, k, n):
......
...@@ -3,7 +3,16 @@ from torch_sparse import coalesce ...@@ -3,7 +3,16 @@ from torch_sparse import coalesce
def transpose(index, value, m, n): def transpose(index, value, m, n):
"""Transpose of sparse matrix.""" """Transposes dimensions 0 and 1 of a sparse matrix.
Args:
index (:class:`LongTensor`): The index tensor of sparse matrix.
value (:class:`Tensor`): The value tensor of sparse matrix.
m (int): The first dimension of sparse matrix.
n (int): The second dimension of sparse matrix.
:rtype: (:class:`LongTensor`, :class:`Tensor`)
"""
row, col = index row, col = index
index = torch.stack([col, row], dim=0) index = torch.stack([col, row], dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment