[pypi-image]: https://badge.fury.io/py/torch-scatter.svg
[pypi-url]: https://pypi.python.org/pypi/torch-scatter
[testing-image]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/testing.yml/badge.svg
[testing-url]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/testing.yml
[linting-image]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/linting.yml/badge.svg
[linting-url]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/linting.yml
[docs-image]: https://readthedocs.org/projects/pytorch-scatter/badge/?version=latest
[docs-url]: https://pytorch-scatter.readthedocs.io/en/latest/?badge=latest
[coverage-image]: https://codecov.io/gh/rusty1s/pytorch_scatter/branch/master/graph/badge.svg
[coverage-url]: https://codecov.io/github/rusty1s/pytorch_scatter?branch=master

# PyTorch Scatter

[![PyPI Version][pypi-image]][pypi-url]
[![Testing Status][testing-image]][testing-url]
[![Linting Status][linting-image]][linting-url]
[![Docs Status][docs-image]][docs-url]
[![Code Coverage][coverage-image]][coverage-url]

<p align="center">
  <img width="50%" src="https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true" />
</p>

--------------------------------------------------------------------------------

**[Documentation](https://pytorch-scatter.readthedocs.io)**

This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in [PyTorch](http://pytorch.org/) that are missing from the main package.
Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to this requirement.
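
As a rough illustration of this "group-index" idea, the following pure-Python sketch (no PyTorch required) reduces every `src[i]` into slot `index[i]` of the output, with the index in arbitrary order as in a scatter operation. `scatter_reduce_1d` is a hypothetical helper written for illustration, not part of `torch_scatter`; filling empty groups with `0` is an assumption chosen to match the example output further below.

```py
from typing import List, Optional, Sequence


def scatter_reduce_1d(src: Sequence[float], index: Sequence[int],
                      reduce: str = "sum",
                      dim_size: Optional[int] = None) -> List[float]:
    """Reduce src[i] into out[index[i]]; index may be in any order (scatter)."""
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    # Collect the values of each group.
    groups: List[List[float]] = [[] for _ in range(dim_size)]
    for value, group in zip(src, index):
        groups[group].append(value)

    reducers = {
        "sum": sum,
        "mean": lambda g: sum(g) / len(g),
        "min": min,
        "max": max,
    }
    # Empty groups are filled with 0 (assumption mirroring the README example).
    return [reducers[reduce](g) if g else 0 for g in groups]


src = [5.0, 1.0, 7.0, 2.0, 3.0]
index = [0, 1, 0, 1, 2]  # unsorted "group-index"
print(scatter_reduce_1d(src, index, "sum"))   # [12.0, 3.0, 3.0]
print(scatter_reduce_1d(src, index, "mean"))  # [6.0, 1.5, 3.0]
print(scatter_reduce_1d(src, index, "max"))   # [7.0, 2.0, 3.0]
```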

The package provides the following operations, each supporting the reduction types `"sum"|"mean"|"min"|"max"`:

* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html) based on arbitrary indices
* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers
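
Sorted indices can equivalently be stored as pointers (the "compressed" form), where `ptr[g]` and `ptr[g + 1]` delimit the contiguous run of entries that belong to group `g`. The pure-Python sketch below uses hypothetical helpers (`coo_to_csr`, `segment_csr_sum`), not the package's API, to convert a sorted COO index into such pointers and then sum each run; it shows why segment operations can rely on sorted input.

```py
from typing import List, Sequence


def coo_to_csr(index: Sequence[int], num_groups: int) -> List[int]:
    """Turn a *sorted* group index into CSR pointers of length num_groups + 1."""
    assert all(a <= b for a, b in zip(index, index[1:])), "index must be sorted"
    counts = [0] * num_groups
    for group in index:
        counts[group] += 1
    ptr = [0]
    for count in counts:
        ptr.append(ptr[-1] + count)  # running total = start of the next group
    return ptr


def segment_csr_sum(src: Sequence[float], ptr: Sequence[int]) -> List[float]:
    """Sum each contiguous slice src[ptr[g]:ptr[g + 1]]; empty slices give 0.0."""
    return [sum(src[ptr[g]:ptr[g + 1]], 0.0) for g in range(len(ptr) - 1)]


src = [1.0, 2.0, 3.0, 4.0, 5.0]
index = [0, 0, 1, 3, 3]            # sorted COO index; group 2 is empty
ptr = coo_to_csr(index, num_groups=4)
print(ptr)                         # [0, 2, 3, 3, 5]
print(segment_csr_sum(src, ptr))   # [3.0, 3.0, 0.0, 9.0]
```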

In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.
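
To show how such a composite function can be built from the basic reductions, the pure-Python sketch below computes a grouped softmax in four steps: a per-group maximum (for numerical stability), exponentiation, a per-group sum, and normalisation. `scatter_softmax_1d` is a hypothetical helper; the package's `scatter_softmax` may differ in details such as multi-dimensional inputs, the `dim` argument and empty groups.

```py
import math
from typing import List, Sequence


def scatter_softmax_1d(src: Sequence[float], index: Sequence[int]) -> List[float]:
    """Softmax computed independently within each group of index."""
    num_groups = max(index) + 1
    # 1) Scatter-max: per-group maximum, subtracted for numerical stability.
    group_max = [-math.inf] * num_groups
    for value, group in zip(src, index):
        group_max[group] = max(group_max[group], value)
    # 2) Exponentiate the shifted values.
    exp = [math.exp(v - group_max[g]) for v, g in zip(src, index)]
    # 3) Scatter-sum of the exponentials gives each group's normaliser.
    group_sum = [0.0] * num_groups
    for e, group in zip(exp, index):
        group_sum[group] += e
    # 4) Gather the normaliser back to every element and divide.
    return [e / group_sum[g] for e, g in zip(exp, index)]


out = scatter_softmax_1d([1.0, 2.0, 3.0, 3.0], [0, 0, 1, 1])
print([round(x, 4) for x in out])  # [0.2689, 0.7311, 0.5, 0.5]
```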

All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
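
One practical consequence of broadcasting is that a 1-D index can group the rows of a 2-D feature matrix, e.g. features of shape `[N, F]` grouped into `dim_size` buckets along `dim=0`. The pure-Python sketch below illustrates that semantics only; `scatter_sum_rows` is a hypothetical helper with `dim` fixed to `0`, and with the package itself this is assumed to correspond to calling `scatter` with `dim=0` and a 1-D index.

```py
from typing import List, Sequence


def scatter_sum_rows(src: Sequence[Sequence[float]], index: Sequence[int],
                     dim_size: int) -> List[List[float]]:
    """Scatter-sum the rows of an [N, F] matrix with a 1-D index of length N.

    The 1-D index is conceptually broadcast across all F columns, so every
    feature of row i is added to out[index[i]].
    """
    num_features = len(src[0]) if src else 0
    out = [[0.0] * num_features for _ in range(dim_size)]
    for row, group in zip(src, index):
        for f, value in enumerate(row):
            out[group][f] += value
    return out


x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
print(scatter_sum_rows(x, [1, 0, 1], dim_size=2))  # [[2.0, 20.0], [4.0, 40.0]]
```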

## Installation

### Binaries

We provide pip wheels for all major OS/PyTorch/CUDA combinations; see [pytorch-geometric.com/whl](https://pytorch-geometric.com/whl).

#### PyTorch 1.9.0

To install the binaries for PyTorch 1.9.0, simply run:

```
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+${CUDA}.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu102`, or `cu111` depending on your PyTorch installation.

|             | `cpu` | `cu102` | `cu111` |
|-------------|-------|---------|---------|
| **Linux**   | ✅    | ✅      | ✅      |
| **Windows** | ✅    | ✅      | ✅      |
| **macOS**   | ✅    |         |         |

#### PyTorch 1.8.0/1.8.1

To install the binaries for PyTorch 1.8.0 and 1.8.1, simply run:

```
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+${CUDA}.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu101`, `cu102`, or `cu111` depending on your PyTorch installation.

|             | `cpu` | `cu101` | `cu102` | `cu111` |
|-------------|-------|---------|---------|---------|
| **Linux**   | ✅    | ✅      | ✅      | ✅      |
| **Windows** | ✅    | ❌      | ✅      | ✅      |
| **macOS**   | ✅    |         |         |         |

**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0 and PyTorch 1.7.0/1.7.1 (following the same procedure).
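
The PyTorch version and `${CUDA}` tag in the wheel URL can also be assembled programmatically. The sketch below is a hypothetical helper (not shipped with `torch-scatter`) that maps a CUDA version string in the format of `torch.version.cuda` (e.g. `"10.2"`, or `None` for CPU-only builds) to the tag used above. It assumes you pass the version as it appears in the URL, without a local `+cuXXX` suffix, e.g. `1.8.0` for both PyTorch 1.8.0 and 1.8.1.

```py
from typing import Optional


def wheel_index_url(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build the -f URL used by the pip commands above (hypothetical helper).

    cuda_version follows the format of torch.version.cuda, e.g. "10.2" or
    "11.1", or None for CPU-only builds; it is mapped to "cu102", "cu111"
    or "cpu" respectively.
    """
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://pytorch-geometric.com/whl/torch-{torch_version}+{tag}.html"


print(wheel_index_url("1.9.0", "11.1"))
# https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
print(wheel_index_url("1.8.0", None))
# https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
```

With PyTorch installed, the arguments could be taken from `torch.__version__` (after stripping any `+...` suffix) and `torch.version.cuda`; whether a wheel exists for a given combination still has to be checked against the tables above.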

### From source

Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:

```
$ python -c "import torch; print(torch.__version__)"
>>> 1.4.0

$ echo $PATH
>>> /usr/local/cuda/bin:...

$ echo $CPATH
>>> /usr/local/cuda/include:...
```
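
If you prefer to check the `$PATH`/`$CPATH` prerequisites from Python, a small sketch could look like the following. The helper name and the `/usr/local/cuda` default are assumptions taken from the example output above; adjust `cuda_root` to wherever CUDA is installed on your system (the sketch assumes POSIX path conventions).

```py
import os
from typing import List, Mapping, Optional


def missing_cuda_paths(env: Optional[Mapping[str, str]] = None,
                       cuda_root: str = "/usr/local/cuda") -> List[str]:
    """Report which CUDA directories are absent from $PATH / $CPATH."""
    env = os.environ if env is None else env
    required = {
        "PATH": os.path.join(cuda_root, "bin"),
        "CPATH": os.path.join(cuda_root, "include"),
    }
    missing = []
    for var, directory in required.items():
        entries = env.get(var, "").split(os.pathsep)
        if directory not in entries:
            missing.append(f"{directory} not in ${var}")
    return missing


print(missing_cuda_paths({"PATH": "/usr/local/cuda/bin:/usr/bin",
                          "CPATH": "/usr/local/cuda/include"}))  # []
print(missing_cuda_paths({"PATH": "/usr/bin"}))
# ['/usr/local/cuda/bin not in $PATH', '/usr/local/cuda/include not in $CPATH']
print(missing_cuda_paths())  # checks the current environment
```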

Then run:

```
pip install torch-scatter
```

When running in a Docker container without an NVIDIA driver, PyTorch needs to determine the GPU compute capabilities and may fail to do so.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:

```
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
```

## Example

```py
import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

out, argmax = scatter_max(src, index, dim=-1)
```

```
print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])

print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
        [1, 4, 3, 5, 5, 5]])
```
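
To see where these numbers come from, the following pure-Python sketch recomputes the example without PyTorch. It reproduces the semantics only, under two conventions visible in the output above: groups that receive no value are filled with `0`, and their `argmax` entry is set to `src.size(-1)` (here `5`), one past the last valid position. `scatter_max_rows` is a hypothetical helper, not part of `torch_scatter`, and keeping the first position on ties is an assumption of this sketch.

```py
from typing import List, Optional, Sequence, Tuple


def scatter_max_rows(src: Sequence[Sequence[int]],
                     index: Sequence[Sequence[int]]
                     ) -> Tuple[List[List[int]], List[List[int]]]:
    """Row-wise scatter_max along the last dimension (dim=-1)."""
    dim_size = max(max(row) for row in index) + 1  # index.max() + 1
    num_positions = len(src[0])
    out, argmax = [], []
    for src_row, index_row in zip(src, index):
        best: List[Optional[int]] = [None] * dim_size  # running max per group
        arg = [num_positions] * dim_size  # empty groups point past the end
        for pos, (value, group) in enumerate(zip(src_row, index_row)):
            if best[group] is None or value > best[group]:
                best[group], arg[group] = value, pos
        out.append([0 if b is None else b for b in best])  # empty -> 0
        argmax.append(arg)
    return out, argmax


src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
out, argmax = scatter_max_rows(src, index)
print(out)     # [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
print(argmax)  # [[5, 5, 3, 4, 0, 1], [1, 4, 3, 5, 5, 5]]
```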

## Running tests

```
python setup.py test
```

## C++ API

`torch-scatter` also offers a C++ API that contains C++ equivalents of the Python functions.

```
mkdir build
cd build
# Add -DWITH_CUDA=on to enable CUDA support if needed
cmake ..
make
make install
```