Commit c2321536 authored by rusty1s

docs

parent 53abd36b
@@ -16,9 +16,9 @@
This package consists of a small extension library of optimized sparse matrix operations for use in [PyTorch](http://pytorch.org/), which are missing and/or lack autograd support in the main package.
This package currently consists of the following methods:
* **[Autograd Sparse Tensor Creation](#autograd-sparse-tensor-creation)**
* **[Autograd Sparse Tensor Value Extraction](#autograd-sparse-tensor-value-extraction)**
* **[Sparse Sparse Matrix Multiplication](#sparse-sparse-matrix-multiplication)**
All included operations work on varying data types and are implemented both for CPU and GPU.
@@ -47,10 +47,60 @@ If you are running into any installation problems, please follow these [instruct
## Autograd Sparse Tensor Creation
```
torch_sparse.sparse_coo_tensor(torch.LongTensor, torch.Tensor, torch.Size) -> torch.SparseTensor
```
Constructs a [`torch.SparseTensor`](https://pytorch.org/docs/stable/sparse.html) with autograd capabilities w.r.t. `value`.
```python
import torch
from torch_sparse import sparse_coo_tensor

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)
A = sparse_coo_tensor(i, v, torch.Size([2, 3]))
```
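For a quick sanity check, the dense form of `A` from the snippet above can be inspected with `to_dense()` (a minimal sketch, assuming the snippet above has been run):

```python
print(A.to_dense())
# tensor([[0., 0., 3.],
#         [4., 0., 5.]])
```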
This method may become obsolete in future PyTorch releases (>= 0.4.1); see this [issue](https://github.com/pytorch/pytorch/issues/9674).
## Autograd Sparse Tensor Value Extraction
```
torch_sparse.to_value(torch.SparseTensor) -> torch.Tensor
```
Wrapper method to extract the values of a sparse tensor with autograd support.
```python
import torch
from torch_sparse import to_value

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)
A = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), requires_grad=True)
value = to_value(A)
```
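Combined with `sparse_coo_tensor` from the section above, gradients flow all the way back to the original value tensor. A minimal sketch, assuming both wrappers behave as documented (the printed gradient simply follows from the sum reduction):

```python
import torch
from torch_sparse import sparse_coo_tensor, to_value

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)

A = sparse_coo_tensor(i, v, torch.Size([2, 3]))
to_value(A).sum().backward()  # reduce to a scalar and backpropagate
print(v.grad)                 # expected: tensor([1., 1., 1.])
```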
This method may become obsolete in future PyTorch releases (>= 0.4.1); see this [issue](https://github.com/pytorch/pytorch/issues/9674).
## Sparse Sparse Matrix Multiplication
```
torch_sparse.spspmm(torch.SparseTensor, torch.SparseTensor) -> torch.SparseTensor
```
Sparse matrix product of two sparse tensors with autograd support.
```python
import torch
from torch_sparse import spspmm

A = torch.sparse_coo_tensor(..., requires_grad=True)
B = torch.sparse_coo_tensor(..., requires_grad=True)
C = spspmm(A, B)
```
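Since the arguments above are elided, here is a concrete sketch with made-up indices and values (illustrative only), building both operands via the package's own `sparse_coo_tensor` as the test suite does:

```python
import torch
from torch_sparse import sparse_coo_tensor, spspmm

indexA = torch.tensor([[0, 0, 1], [0, 2, 1]])
valueA = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
A = sparse_coo_tensor(indexA, valueA, torch.Size([2, 3]))  # 2 x 3

indexB = torch.tensor([[0, 2], [1, 0]])
valueB = torch.tensor([2.0, 4.0], requires_grad=True)
B = sparse_coo_tensor(indexB, valueB, torch.Size([3, 2]))  # 3 x 2

C = spspmm(A, B)     # sparse-sparse product, 2 x 2
print(C.to_dense())  # expected: tensor([[8., 2.], [0., 0.]])
```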
## Running tests
```
...
@@ -2,7 +2,7 @@ from itertools import product
import pytest
import torch
from torch_sparse import sparse_coo_tensor, spspmm, to_value
from .utils import dtypes, devices, tensor
@@ -30,13 +30,13 @@ def test_spspmm(test, dtype, device):
    indexA = torch.tensor(test['indexA'], device=device)
    valueA = tensor(test['valueA'], dtype, device, requires_grad=True)
    sizeA = torch.Size(test['sizeA'])
    A = sparse_coo_tensor(indexA, valueA, sizeA)
    denseA = A.detach().to_dense().requires_grad_()

    indexB = torch.tensor(test['indexB'], device=device)
    valueB = tensor(test['valueB'], dtype, device, requires_grad=True)
    sizeB = torch.Size(test['sizeB'])
    B = sparse_coo_tensor(indexB, valueB, sizeB)
    denseB = B.detach().to_dense().requires_grad_()

    C = spspmm(A, B)
...
from .sparse import sparse_coo_tensor, to_value
from .matmul import spspmm

__all__ = [
    'sparse_coo_tensor',
    'to_value',
    'spspmm',
]
@@ -6,6 +6,8 @@ if torch.cuda.is_available():
class SpSpMM(torch.autograd.Function):
    """Sparse matrix product of two sparse tensors with autograd support."""

    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
...
import torch
class SparseCooTensor(torch.autograd.Function):
"""Constructs Sparse matrix with autograd capabilities w.r.t. to value."""
    @staticmethod
    def forward(ctx, index, value, size):
        ctx.size = size
@@ -26,10 +28,12 @@ class _SparseTensor(torch.autograd.Function):
        return None, grad_in, None

sparse_coo_tensor = SparseCooTensor.apply

class ToValue(torch.autograd.Function):
"""Extract values of sparse tensors with autograd support."""
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
...