Commit 46a9c9ab authored by rusty1s's avatar rusty1s
Browse files

added docs

parent 37d18b5c
...@@ -13,14 +13,18 @@ ...@@ -13,14 +13,18 @@
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
This package consists of a small extension library of optimized sparse matrix operations for the use in [PyTorch](http://pytorch.org/), which are missing and or lack autograd support in the main package. [PyTorch](http://pytorch.org/) (<= 0.4.1) completely lacks autograd support and operations such as sparse sparse matrix multiplication, but is heavily working on improvement (*cf.* [this issue](https://github.com/pytorch/pytorch/issues/9674)).
In the meantime, this package consists of a small extension library of optimized sparse matrix operations with autograd support.
This package currently consists of the following methods: This package currently consists of the following methods:
* **[Autograd Sparse Tensor Creation](#autograd-sparse-tensor-creation)** * **[Coalesce](#coalesce)**
* **[Autograd Sparse Tensor Value Extraction](#autograd-sparse-tensor-value-extraction)** * **[Transpose](#transpose)**
* **[Sparse Dense Matrix Multiplication](#sparse-dense-matrix-multiplication)**
* **[Sparse Sparse Matrix Multiplication](#sparse-sparse-matrix-multiplication)** * **[Sparse Sparse Matrix Multiplication](#sparse-sparse-matrix-multiplication)**
All included operations work on varying data types and are implemented both for CPU and GPU. All included operations work on varying data types and are implemented both for CPU and GPU.
To avoid the hazzle of creating [`torch.sparse_coo_tensor`](https://pytorch.org/docs/stable/torch.html?highlight=sparse_coo_tensor#torch.sparse_coo_tensor), this package defines operations on sparse tensors by simply passing `index` and `value` tensors as arguments ([with same shapes as defined in PyTorch](https://pytorch.org/docs/stable/sparse.html)).
Note that only `value` comes with autograd support, as `index` is discrete and therefore not differentiable.
## Installation ## Installation
...@@ -45,60 +49,176 @@ pip install torch-scatter torch-sparse ...@@ -45,60 +49,176 @@ pip install torch-scatter torch-sparse
If you are running into any installation problems, please create an [issue](https://github.com/rusty1s/pytorch_sparse/issues). If you are running into any installation problems, please create an [issue](https://github.com/rusty1s/pytorch_sparse/issues).
## Autograd Sparse Tensor Creation ## Coalesce
``` ```
torch_sparse.sparse_coo_tensor(torch.LongTensor, torch.Tensor, torch.Size) -> torch.SparseTensor torch_sparse.coalesce(index, value, m, n, op="add", fill_value=0) -> (torch.LongTensor, torch.Tensor)
``` ```
Constructs a [`torch.SparseTensor`](https://pytorch.org/docs/stable/sparse.html) with autograd capabilities w.r.t. `value`. Row-wise sorts `value` and removes duplicate entries.
Duplicate entries are removed by scattering them together.
For scattering, any operation of [`torch_scatter`](https://github.com/rusty1s/pytorch_scatter) can be used.
### Parameters
* **index** *(LongTensor)* - The index tensor of sparse matrix.
* **value** *(Tensor)* - The value tensor of sparse matrix.
* **m** *(int)* - First dimension of sparse matrix.
* **n** *(int)* - Second dimension of sparse matrix.
* **op** *(string, optional)* - Scatter operation to use. (default: `"add"`)
* **fill_value** *(int, optional)* - Initial fill value of scatter operation. (default: `0`)
### Returns
* **index** *(LongTensor)* - Coalesced index tensor of sparse matrix.
* **value** *(Tensor)* - Coalesced value tensor of sparse matrix.
### Example
```python ```python
from torch_sparse import sparse_coo_tensor from torch_sparse import coalesce
index = torch.tensor([[1, 0, 1, 0, 2, 1],
[0, 1, 1, 1, 0, 0]])
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
i = torch.tensor([[0, 1, 1], index, value = coalesce(index, value, m=3, n=2)
[2, 0, 2]])
v = torch.Tensor([3, 4, 5], requires_grad=True)
A = sparse_coo_tensor(i, v, torch.Size([2,3]))
``` ```
This method may become obsolete in future PyTorch releases (>= 0.4.1) as reported by this [issue](https://github.com/pytorch/pytorch/issues/9674). ```
print(index)
tensor([[0, 1, 1, 2],
[1, 0, 1, 0]])
print(value)
tensor([[6, 8], [7, 9], [3, 4], [5, 6]])
```
## Autograd Sparse Tensor Value Extraction ## Transpose
``` ```
torch_sparse.to_value(torch.SparseTensor) -> torch.Tensor torch_sparse.transpose(index, value, m, n) -> (torch.LongTensor, torch.Tensor)
``` ```
Wrapper method to support autograd on values of [`torch.SparseTensor`](https://pytorch.org/docs/stable/sparse.html). Transposes dimensions 0 and 1 of a sparse matrix.
### Parameters
* **index** *(LongTensor)* - The index tensor of sparse matrix.
* **value** *(Tensor)* - The value tensor of sparse matrix.
* **m** *(int)* - First dimension of sparse matrix.
* **n** *(int)* - Second dimension of sparse matrix.
### Returns
* **index** *(LongTensor)* - Transposed index tensor of sparse matrix.
* **value** *(Tensor)* - Transposed value tensor of sparse matrix.
### Example
```python ```python
from torch_sparse import to_value from torch_sparse import transpose
index = torch.tensor([[1, 0, 1, 0, 2, 1],
[0, 1, 1, 1, 0, 0]])
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
i = torch.tensor([[0, 1, 1], index, value = transpose(index, value, m=3, n=2)
[2, 0, 2]])
v = torch.Tensor([3, 4, 5], requires_grad=True)
A = torch.sparse_coo_tensor(i, v, torch.Size([2,3]), requires_grad=True)
v = to_value(A)
``` ```
This method may become obsolete in future PyTorch releases (>= 0.4.1) as reported by this [issue](https://github.com/pytorch/pytorch/issues/9674). ```
print(index)
tensor([[0, 0, 1, 1],
[1, 2, 0, 1]])
print(value)
tensor([[7, 9],
[5, 6],
[6, 8],
[3, 4]])
```
## Sparse Sparse Matrix Multiplication ## Sparse Dense Matrix Multiplication
``` ```
torch_sparse.spspmm(torch.SparseTensor, torch.SparseTensor) -> torch.SparseTensor torch_sparse.spmm(index, value, m, matrix) -> torch.Tensor
``` ```
Sparse matrix product of two sparse tensors with autograd support. Matrix product of a sparse matrix with a dense matrix.
### Parameters
* **index** *(LongTensor)* - The index tensor of sparse matrix.
* **value** *(Tensor)* - The value tensor of sparse matrix.
* **m** *(int)* - First dimension of sparse matrix.
* **matrix** *(int)* - Dense matrix.
### Returns
* **out** *(Tensor)* - Dense output matrix.
### Example
```python
from torch_sparse import spmm
index = torch.tensor([[0, 0, 1, 2, 2],
[0, 2, 1, 0, 1]])
value = torch.tensor([1, 2, 4, 1, 3])
matrix = torch.tensor([[1, 4], [2, 5], [3, 6]])
out = spmm(index, value, 3, matrix)
```
```
print(out)
tensor([[7, 16],
[8, 20],
[7, 19]])
``` ```
## Sparse Sparse Matrix Multiplication
```
torch_sparse.spspmm(indexA, valueA, indexB, valueB, m, k, n) -> (torch.LongTensor, torch.Tensor)
```
Matrix product of two sparse tensors.
Both input sparse matrices need to be **coalesced**.
### Parameters
* **indexA** *(LongTensor)* - The index tensor of first sparse matrix.
* **valueA** *(Tensor)* - The value tensor of first sparse matrix.
* **indexB** *(LongTensor)* - The index tensor of second sparse matrix.
* **valueB** *(Tensor)* - The value tensor of second sparse matrix.
* **m** *(int)* - First dimension of first sparse matrix.
* **k** *(int)* - Second dimension of first sparse matrix and first dimension of second sparse matrix.
* **n** *(int)* - Second dimension of second sparse matrix.
### Returns
* **index** *(LongTensor)* - Output index tensor of sparse matrix.
* **value** *(Tensor)* - Output value tensor of sparse matrix.
### Example
```python
from torch_sparse import spspmm from torch_sparse import spspmm
A = torch.sparse_coo_tensor(..., requries_grad=True) indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]])
B = torch.sparse_coo_tensor(..., requries_grad=True) valueA = torch.tensor([1, 2, 3, 4, 5])
indexB = torch.tensor([[0, 2], [1, 0]])
valueB = torch.tensor([2, 4])
C = spspmm(A, B) indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
```
```
print(index)
tensor([[0, 1, 2],
[0, 1, 1]])
print(value)
tensor([8, 6, 8])
``` ```
## Running tests ## Running tests
......
...@@ -3,7 +3,7 @@ import torch_scatter ...@@ -3,7 +3,7 @@ import torch_scatter
def coalesce(index, value, m, n, op='add', fill_value=0): def coalesce(index, value, m, n, op='add', fill_value=0):
"""Row-wise reorders and removes duplicate entries in sparse matrixx.""" """Row-wise reorders and removes duplicate entries in sparse matrix."""
row, col = index row, col = index
...@@ -16,5 +16,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0): ...@@ -16,5 +16,7 @@ def coalesce(index, value, m, n, op='add', fill_value=0):
if value is not None: if value is not None:
op = getattr(torch_scatter, 'scatter_{}'.format(op)) op = getattr(torch_scatter, 'scatter_{}'.format(op))
value = op(value, inv, 0, None, perm.size(0), fill_value) value = op(value, inv, 0, None, perm.size(0), fill_value)
if isinstance(value, tuple):
value = value[0]
return index, value return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment