Commit c3155ab6 authored by rusty1s's avatar rusty1s
Browse files

coalesce

parent ac04c1b6
......@@ -24,11 +24,11 @@ All included operations work on varying data types and are implemented both for
## Installation
Ensure that at least PyTorch 0.4.0 is installed and verify that `cuda/bin` and `cuda/install` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
Ensure that at least PyTorch 0.4.1 is installed and verify that `cuda/bin` and `cuda/install` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
```
$ python -c "import torch; print(torch.__version__)"
>>> 0.4.0
>>> 0.4.1
$ echo $PATH
>>> /usr/local/cuda/bin:...
......@@ -40,7 +40,7 @@ $ echo $CPATH
Then run:
```
pip install torch-sparse
pip install torch-scatter torch-sparse
```
If you are running into any installation problems, please follow these [instructions](https://rusty1s.github.io/pytorch_geometric/build/html/notes/installation.html) first before creating an [issue](https://github.com/rusty1s/pytorch_sparse/issues).
......
import torch
from torch_sparse import coalesce
def test_coalesce():
row = torch.tensor([1, 0, 1, 0, 2, 1])
col = torch.tensor([0, 1, 1, 1, 0, 0])
index = torch.stack([row, col], dim=0)
value = torch.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]])
index, value = coalesce(index, value, torch.Size([4, 2]))
assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]]
assert value.tolist() == [[6, 8], [7, 9], [3, 4], [5, 6]]
from .coalesce import coalesce
from .sparse import sparse_coo_tensor, to_value
from .matmul import spspmm
__all__ = [
'coalesce',
'sparse_coo_tensor',
'to_value',
'spspmm',
......
import torch
import torch_scatter
def coalesce(index, value, size, op='add', fill_value=0):
m, n = size
row, col = index
index = row * n + col
unique, inv = torch.unique(index, sorted=True, return_inverse=True)
perm = torch.arange(index.size(0), dtype=index.dtype, device=index.device)
perm = index.new_empty(inv.max().item() + 1).scatter_(0, inv, perm)
index = torch.stack([row[perm], col[perm]], dim=0)
if value is not None:
scatter = getattr(torch_scatter, 'scatter_{}'.format(op))
value = scatter(
value, inv, dim=0, dim_size=perm.size(0), fill_value=fill_value)
return index, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment