[pypi-image]: https://badge.fury.io/py/torch-scatter.svg
[pypi-url]: https://pypi.python.org/pypi/torch-scatter
[build-image]: https://travis-ci.org/rusty1s/pytorch_scatter.svg?branch=master
[build-url]: https://travis-ci.org/rusty1s/pytorch_scatter
[docs-image]: https://readthedocs.org/projects/pytorch-scatter/badge/?version=latest
[docs-url]: https://pytorch-scatter.readthedocs.io/en/latest/?badge=latest
[coverage-image]: https://codecov.io/gh/rusty1s/pytorch_scatter/branch/master/graph/badge.svg
[coverage-url]: https://codecov.io/github/rusty1s/pytorch_scatter?branch=master

# PyTorch Scatter

[![PyPI Version][pypi-image]][pypi-url]
[![Build Status][build-image]][build-url]
[![Docs Status][docs-image]][docs-url]
[![Code Coverage][coverage-image]][coverage-url]

<p align="center">
  <img width="50%" src="https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true" />
</p>

--------------------------------------------------------------------------------

**[Documentation](https://pytorch-scatter.readthedocs.io)**

This package is a small extension library of highly optimized sparse update (scatter and segment) operations for use in [PyTorch](http://pytorch.org/) that are missing from the main package.
Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
Segment operations require the "group-index" tensor to be sorted, whereas scatter operations have no such requirement.
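
The difference between the two families can be sketched in plain Python. This is only an illustration for 1-D lists, not the library's implementation; the names `scatter_sum` and `segment_sum` below are hypothetical helpers written for this example.

```python
def scatter_sum(src, index, dim_size=None):
    """Sum the values of `src` into groups given by `index` (any order)."""
    size = dim_size if dim_size is not None else max(index) + 1
    out = [0] * size
    for value, group in zip(src, index):
        out[group] += value  # random-access write into the output
    return out


def segment_sum(src, index):
    """Sum the values of `src` into groups given by a sorted `index`."""
    out = []
    for value, group in zip(src, index):
        if group < len(out) - 1:
            raise ValueError("segment operations expect a sorted index")
        while len(out) <= group:
            out.append(0)  # open the next (possibly empty) segment
        out[-1] += value  # only the most recent segment is ever written
    return out


print(scatter_sum([1, 2, 3, 4], [1, 0, 1, 2]))  # index in arbitrary order
print(segment_sum([2, 1, 3, 4], [0, 1, 1, 2]))  # index must be sorted
```

Because a sorted index keeps every group contiguous, the segment variant only ever touches the last output slot, which is the property that makes the sortedness requirement worthwhile.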

The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`:

* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html) based on arbitrary indices
* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers
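
To illustrate how "sorted indices" and "compressed indices via pointers" describe the same grouping, here is a minimal pure-Python sketch. The helpers `coo_to_csr_ptr` and `segment_csr_sum` are hypothetical names for this example and do not reflect how the library is implemented.

```python
def coo_to_csr_ptr(index, num_groups):
    """Compress a sorted group index into num_groups + 1 boundary pointers."""
    counts = [0] * num_groups
    for group in index:
        counts[group] += 1
    ptr = [0]
    for count in counts:
        ptr.append(ptr[-1] + count)  # running total marks where each group ends
    return ptr


def segment_csr_sum(src, ptr):
    """Sum each slice src[ptr[i]:ptr[i + 1]]; empty slices yield 0."""
    return [sum(src[start:end]) for start, end in zip(ptr, ptr[1:])]


index = [0, 0, 1, 3]  # sorted, COO-style: one group id per element
ptr = coo_to_csr_ptr(index, num_groups=4)
print(ptr)
print(segment_csr_sum([5, 1, 7, 2], ptr))
```

The pointer form stores one entry per group boundary instead of one entry per element, so empty groups (group 2 here) show up as two equal consecutive pointers.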

In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.
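
As a rough idea of how such a composite can be assembled from scatter reductions, the following pure-Python sketch computes a per-group softmax from a scatter-max followed by a scatter-sum. It is an assumption-laden illustration for 1-D lists (the helper name `scatter_softmax_reference` is hypothetical), not the library's actual code.

```python
import math


def scatter_softmax_reference(src, index):
    """Softmax over the elements of `src` that share the same group in `index`."""
    num_groups = max(index) + 1

    # Step 1: scatter-max, used to shift values for numerical stability.
    group_max = [-math.inf] * num_groups
    for value, group in zip(src, index):
        group_max[group] = max(group_max[group], value)

    # Step 2: exponentiate the shifted values.
    exps = [math.exp(value - group_max[group]) for value, group in zip(src, index)]

    # Step 3: scatter-sum of the exponentials gives each group's normalizer.
    group_sum = [0.0] * num_groups
    for e, group in zip(exps, index):
        group_sum[group] += e

    return [e / group_sum[group] for e, group in zip(exps, index)]


print(scatter_softmax_reference([1.0, 1.0, 2.0], [0, 0, 1]))
```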

All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.

## Installation

### Binaries

We provide pip wheels for all major OS/PyTorch/CUDA combinations; see [here](https://pytorch-geometric.com/whl).

#### PyTorch 1.7.0

To install the binaries for PyTorch 1.7.0, simply run

```
pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.7.0.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu92`, `cu101`, `cu102`, or `cu110` depending on your PyTorch installation.

|             | `cpu` | `cu92` | `cu101` | `cu102` | `cu110` |
|-------------|-------|--------|---------|---------|---------|
| **Linux**   | ✅    | ✅     | ✅      | ✅      | ✅      |
| **Windows** | ✅    | ❌     | ✅      | ✅      | ✅      |
| **macOS**   | ✅    |        |         |         |         |
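
If you are unsure which `${CUDA}` tag matches your setup, the mapping from a CUDA version string to a tag can be sketched as below. The helpers `cuda_tag` and `install_command` are hypothetical and written only for this example; they format strings and do not check whether a wheel exists for a given combination (see the table above for that).

```python
from typing import Optional


def cuda_tag(cuda_version: Optional[str]) -> str:
    """Map a CUDA version such as "10.2" to a wheel tag such as "cu102"."""
    if cuda_version is None:
        return "cpu"  # CPU-only PyTorch builds report no CUDA version
    major, minor = cuda_version.split(".")[:2]
    return f"cu{major}{minor}"


def install_command(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the pip command shown above for a PyTorch/CUDA pair."""
    tag = cuda_tag(cuda_version)
    return (
        f"pip install torch-scatter==latest+{tag} "
        f"-f https://pytorch-geometric.com/whl/torch-{torch_version}.html"
    )


print(install_command("1.7.0", "10.2"))
print(install_command("1.7.0", None))
```

With PyTorch installed, `torch.version.cuda` provides the value to pass as `cuda_version`; it is `None` for CPU-only builds.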

#### PyTorch 1.6.0

To install the binaries for PyTorch 1.6.0, simply run

```
pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.6.0.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu92`, `cu101`, or `cu102` depending on your PyTorch installation.

|             | `cpu` | `cu92` | `cu101` | `cu102` |
|-------------|-------|--------|---------|---------|
| **Linux**   | ✅    | ✅     | ✅      | ✅      |
| **Windows** | ✅    | ❌     | ✅      | ✅      |
| **macOS**   | ✅    |        |         |         |

**Note:** Binaries of older versions are also provided for PyTorch 1.4.0 and PyTorch 1.5.0 (following the same procedure).

### From source

Ensure that at least PyTorch 1.5.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:

```
$ python -c "import torch; print(torch.__version__)"
>>> 1.5.0

$ echo $PATH
>>> /usr/local/cuda/bin:...

$ echo $CPATH
>>> /usr/local/cuda/include:...
```

Then run:

```
pip install torch-scatter
```

When running in a Docker container without an NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:

```
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
```

## Example

```py
import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

out, argmax = scatter_max(src, index, dim=-1)
```

```
print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])

print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
        [1, 4, 3, 5, 5, 5]])
```
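
To see where these numbers come from, the example can be traced with a pure-Python reference that runs without PyTorch. This is a hedged sketch for 2-D nested lists reduced along the last dimension; `scatter_max_rows` is a hypothetical helper, and the fill values for untouched slots (`0` in `out`, the row length `5` in `argmax`) are read off the printed output above rather than taken from the library's specification.

```python
def scatter_max_rows(src, index, dim_size=None):
    """Row-wise scatter-max over nested lists, returning (out, argmax)."""
    if dim_size is None:
        dim_size = max(max(row) for row in index) + 1
    out, argmax = [], []
    for src_row, index_row in zip(src, index):
        empty = len(src_row)  # marker for "no element landed in this slot"
        out_row = [0] * dim_size
        arg_row = [empty] * dim_size
        for pos, (value, group) in enumerate(zip(src_row, index_row)):
            # Take the value if the slot is still empty or the value is larger.
            if arg_row[group] == empty or value > out_row[group]:
                out_row[group] = value
                arg_row[group] = pos
        out.append(out_row)
        argmax.append(arg_row)
    return out, argmax


src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
out, argmax = scatter_max_rows(src, index)
print(out)
print(argmax)
```

In the first row, for instance, positions 0 and 2 both map to group 4, so `out[0][4]` is `max(2, 1) = 2` and `argmax[0][4]` is `0`, the position of that maximum in `src`.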

## Running tests

```
python setup.py test
```

## C++ API

`torch-scatter` also offers a C++ API that contains C++ equivalents of the Python functions.

```
mkdir build
cd build
# Add -DWITH_CUDA=on for CUDA support if needed
cmake ..
make
make install
```