[pypi-image]: https://badge.fury.io/py/torch-scatter.svg
[pypi-url]: https://pypi.python.org/pypi/torch-scatter
[testing-image]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/testing.yml/badge.svg
[testing-url]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/testing.yml
[linting-image]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/linting.yml/badge.svg
[linting-url]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/linting.yml
[docs-image]: https://readthedocs.org/projects/pytorch-scatter/badge/?version=latest
[docs-url]: https://pytorch-scatter.readthedocs.io/en/latest/?badge=latest
[coverage-image]: https://codecov.io/gh/rusty1s/pytorch_scatter/branch/master/graph/badge.svg
[coverage-url]: https://codecov.io/github/rusty1s/pytorch_scatter?branch=master

# PyTorch Scatter

[![PyPI Version][pypi-image]][pypi-url]
[![Testing Status][testing-image]][testing-url]
[![Linting Status][linting-image]][linting-url]
[![Docs Status][docs-image]][docs-url]
[![Code Coverage][coverage-image]][coverage-url]

<p align="center">
  <img width="50%" src="https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true" />
</p>

--------------------------------------------------------------------------------

**[Documentation](https://pytorch-scatter.readthedocs.io)**

This package is a small extension library of highly optimized sparse update (scatter and segment) operations for use in [PyTorch](http://pytorch.org/), which are missing from the main package.
Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to this requirement.
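
To make the difference concrete, here is a minimal pure-Python sketch of a `"sum"` reduction over a 1-D "group-index". It only illustrates the semantics on plain lists; the helper names `scatter_sum_ref` and `segment_coo_sum_ref` are hypothetical and are not part of the `torch_scatter` API, which operates on tensors with optimized CPU/GPU kernels.

```python
from typing import List, Optional, Sequence


def scatter_sum_ref(src: Sequence[int], index: Sequence[int],
                    dim_size: Optional[int] = None) -> List[int]:
    """Sum src values into out[index[i]]; index may be in any order."""
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    out = [0] * dim_size
    for value, group in zip(src, index):
        out[group] += value
    return out


def segment_coo_sum_ref(src: Sequence[int], index: Sequence[int],
                        dim_size: Optional[int] = None) -> List[int]:
    """Same reduction, but index must be sorted so each group is one contiguous run."""
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("segment reductions expect a sorted index")
    if dim_size is None:
        dim_size = index[-1] + 1 if index else 0
    out = [0] * dim_size
    start = 0
    while start < len(index):
        # Advance `end` to the first position that belongs to the next group.
        end = start
        while end < len(index) and index[end] == index[start]:
            end += 1
        out[index[start]] = sum(src[start:end])
        start = end
    return out


print(scatter_sum_ref([1, 2, 3, 4, 5], [0, 1, 0, 2, 1]))      # [4, 7, 4]
print(segment_coo_sum_ref([1, 3, 2, 5, 4], [0, 0, 1, 1, 2]))  # [4, 7, 4]
```

Because a sorted index places every group in one contiguous run, the segment variant can reduce whole runs at once instead of performing random-access writes; that is the intuition behind the sorting requirement.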

The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`:

* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html) based on arbitrary indices
* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers
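
The "compressed indices via pointers" used by `segment_csr` can be derived from a sorted index: group `i` spans the slice `ptr[i]:ptr[i + 1]`, so the pointer array has one more entry than there are groups. The following pure-Python sketch (with the hypothetical helper names `index_to_ptr` and `segment_csr_sum_ref`) shows that conversion and the corresponding `"sum"` reduction on plain lists; it illustrates the format only and is not the library's implementation.

```python
from typing import List, Sequence


def index_to_ptr(index: Sequence[int], dim_size: int) -> List[int]:
    """Convert a sorted COO group-index into CSR-style pointers of length dim_size + 1."""
    counts = [0] * dim_size
    for group in index:
        counts[group] += 1
    ptr = [0]
    for count in counts:
        # Running total: ptr[i + 1] is the number of elements in groups 0..i.
        ptr.append(ptr[-1] + count)
    return ptr


def segment_csr_sum_ref(src: Sequence[int], ptr: Sequence[int]) -> List[int]:
    """Sum each contiguous slice src[ptr[i]:ptr[i + 1]]; empty slices give 0."""
    return [sum(src[ptr[i]:ptr[i + 1]]) for i in range(len(ptr) - 1)]


index = [0, 0, 2, 2, 2]
ptr = index_to_ptr(index, dim_size=3)
print(ptr)                                        # [0, 2, 2, 5]
print(segment_csr_sum_ref([1, 2, 3, 4, 5], ptr))  # [3, 0, 12]
```

Note how the empty group `1` shows up as two equal consecutive pointers (`2, 2`) rather than needing any explicit entry.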

In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.
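
As an illustration of how such a composite function can be built from scatter reductions, the sketch below computes a grouped softmax on plain Python lists: a per-group maximum (a scatter max) for numerical stability, exponentiation, then a per-group sum (a scatter sum) for normalization. The name `scatter_softmax_ref` is hypothetical, and this is not necessarily how `torch_scatter` implements `scatter_softmax`.

```python
import math
from typing import List, Optional, Sequence


def scatter_softmax_ref(src: Sequence[float], index: Sequence[int],
                        dim_size: Optional[int] = None) -> List[float]:
    """Softmax over the elements that share the same group index."""
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    # Per-group maximum (a scatter max), subtracted before exp() to avoid overflow.
    group_max = [-math.inf] * dim_size
    for value, group in zip(src, index):
        group_max[group] = max(group_max[group], value)
    exps = [math.exp(value - group_max[group]) for value, group in zip(src, index)]
    # Per-group normalizer (a scatter sum of the exponentials).
    denom = [0.0] * dim_size
    for e, group in zip(exps, index):
        denom[group] += e
    # Each element is normalized only by the elements in its own group.
    return [e / denom[group] for e, group in zip(exps, index)]


print(scatter_softmax_ref([1.0, 2.0, 3.0, 3.0], [0, 0, 1, 1]))
```

The last two entries form their own group with equal values, so each receives exactly `0.5`, while the first two entries sum to one independently of them.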

All included operations are broadcastable, work on varying data types, are implemented for both CPU and GPU with corresponding backward implementations, and are fully traceable.
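
A common case of this broadcasting is a 1-D index combined with a 2-D `src` along the first dimension: the index selects the output row, and every feature column is reduced independently with the same index. The pure-Python sketch below mimics this for a `"sum"` reduction on a list of feature rows; `scatter_rows_sum_ref` is a hypothetical name used only for illustration.

```python
from typing import List, Optional, Sequence


def scatter_rows_sum_ref(src: Sequence[Sequence[int]], index: Sequence[int],
                         dim_size: Optional[int] = None) -> List[List[int]]:
    """Sum rows of a 2-D src into out[index[i]]; the 1-D index is reused for every column."""
    num_features = len(src[0]) if src else 0
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    out = [[0] * num_features for _ in range(dim_size)]
    for row, group in zip(src, index):
        for j, value in enumerate(row):
            out[group][j] += value
    return out


print(scatter_rows_sum_ref([[1, 2], [3, 4], [5, 6]], [0, 1, 0]))  # [[6, 8], [3, 4]]
```

With the library itself, this corresponds (to our understanding) to calling `scatter(src, index, dim=0, reduce="sum")` on tensors of shape `[3, 2]` and `[3]`.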

## Installation

### Anaconda

**Update:** You can now install `pytorch-scatter` via [Anaconda](https://anaconda.org/pyg/pytorch-scatter) for all major OS/PyTorch/CUDA combinations 🤗
Given that you have [`pytorch >= 1.8.0` installed](https://pytorch.org/get-started/locally/), simply run

```
conda install pytorch-scatter -c pyg
```

### Binaries

Alternatively, we provide pip wheels for all major OS/PyTorch/CUDA combinations (see [data.pyg.org/whl](https://data.pyg.org/whl)).

#### PyTorch 1.10.0

To install the binaries for PyTorch 1.10.0, simply run

```
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu102`, or `cu113` depending on your PyTorch installation.

|             | `cpu` | `cu102` | `cu113` |
|-------------|-------|---------|---------|
| **Linux**   | ✅    | ✅      | ✅      |
| **Windows** | ✅    | ✅      | ✅      |
| **macOS**   | ✅    |         |         |

#### PyTorch 1.9.0/1.9.1

To install the binaries for PyTorch 1.9.0 and 1.9.1, simply run

```
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+${CUDA}.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu102`, or `cu111` depending on your PyTorch installation.

|             | `cpu` | `cu102` | `cu111` |
|-------------|-------|---------|---------|
| **Linux**   | ✅    | ✅      | ✅      |
| **Windows** | ✅    | ✅      | ✅      |
| **macOS**   | ✅    |         |         |
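
The wheel index URL can also be assembled programmatically. The sketch below assumes the URL pattern shown above, where patch releases (e.g. 1.9.0 and 1.9.1) share the `torch-<major>.<minor>.0` index, and derives the `${CUDA}` suffix from a CUDA version string such as the value of `torch.version.cuda` (which is `None` for CPU-only builds). Both helper names, `cuda_tag` and `wheel_index_url`, are hypothetical; check the generated URL against the wheel listing before relying on it.

```python
from typing import Optional


def cuda_tag(cuda_version: Optional[str]) -> str:
    """Map a CUDA version string like "11.3" to a wheel tag like "cu113"; None means CPU."""
    if cuda_version is None:
        return "cpu"
    major, minor = cuda_version.split(".")[:2]
    return f"cu{major}{minor}"


def wheel_index_url(torch_version: str, cuda: str) -> str:
    """Build the -f URL, assuming patch releases share the <major>.<minor>.0 index."""
    public = torch_version.split("+")[0]  # drop a local suffix such as "+cu111"
    major, minor = public.split(".")[:2]
    return f"https://data.pyg.org/whl/torch-{major}.{minor}.0+{cuda}.html"


url = wheel_index_url("1.9.1", cuda_tag("11.1"))
print(f"pip install torch-scatter -f {url}")
# pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cu111.html
```

In a live environment you would pass `torch.__version__` and `torch.version.cuda` instead of the literal strings.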

**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1 and PyTorch 1.8.0/1.8.1 (following the same procedure).

### From source

Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:

```
$ python -c "import torch; print(torch.__version__)"
>>> 1.4.0

$ echo $PATH
>>> /usr/local/cuda/bin:...

$ echo $CPATH
>>> /usr/local/cuda/include:...
```
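
The same checks can be scripted, for example as a step in a Docker build. This is a minimal sketch assuming the Unix-style, colon-separated variables and the `/usr/local/cuda` install prefix from the example above; `parse_version` and `check_build_env` are hypothetical helpers, not part of `torch-scatter`.

```python
import re
from typing import List, Mapping, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn "1.10.0+cu113" into (1, 10, 0), ignoring local and pre-release suffixes."""
    numbers = []
    for piece in version.split("+")[0].split(".")[:3]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    return tuple(numbers)


def check_build_env(torch_version: str, env: Mapping[str, str],
                    cuda_root: str = "/usr/local/cuda") -> List[str]:
    """Return human-readable problems; an empty list means all checks passed."""
    problems = []
    if parse_version(torch_version) < (1, 4, 0):
        problems.append(f"PyTorch {torch_version} is older than 1.4.0")
    # Compare entries without trailing slashes so "/usr/local/cuda/bin/" also matches.
    path = [p.rstrip("/") for p in env.get("PATH", "").split(":")]
    cpath = [p.rstrip("/") for p in env.get("CPATH", "").split(":")]
    if f"{cuda_root}/bin" not in path:
        problems.append(f"{cuda_root}/bin is not in $PATH")
    if f"{cuda_root}/include" not in cpath:
        problems.append(f"{cuda_root}/include is not in $CPATH")
    return problems


print(check_build_env("1.4.0", {"PATH": "/usr/local/cuda/bin:/usr/bin",
                                "CPATH": "/usr/local/cuda/include"}))  # []
```

In practice you would call `check_build_env(torch.__version__, os.environ)` and abort the build if the returned list is non-empty.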

Then run:

```
pip install torch-scatter
```

When building inside a Docker container without an NVIDIA driver, PyTorch needs to determine the GPU compute capabilities and may fail to do so.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:

```
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
```
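
If you build images for several GPU generations, the value can be generated from a list of compute capabilities. The sketch below is one possible approach, not a requirement of `torch-scatter`: it formats `(major, minor)` pairs in the space-separated form used above and, as an assumption of this sketch, appends `+PTX` only to the highest entry so that PTX code is also embedded for newer GPUs. `arch_list` is a hypothetical helper name.

```python
from typing import Iterable, Tuple


def arch_list(capabilities: Iterable[Tuple[int, int]], ptx_for_highest: bool = True) -> str:
    """Format compute capabilities as a TORCH_CUDA_ARCH_LIST value, lowest first."""
    entries = [f"{major}.{minor}" for major, minor in sorted(set(capabilities))]
    if ptx_for_highest and entries:
        # PTX for the newest architecture can be JIT-compiled on later GPUs.
        entries[-1] += "+PTX"
    return " ".join(entries)


print(f'export TORCH_CUDA_ARCH_LIST="{arch_list([(7, 5), (6, 0), (6, 1)])}"')
# export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.5+PTX"
```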

## Example

```py
import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

out, argmax = scatter_max(src, index, dim=-1)
```

```
print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])

print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
        [1, 4, 3, 5, 5, 5]])
```
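
To see where these numbers come from, here is a pure-Python re-implementation of the example on nested lists. It is a sketch under the conventions visible in the output above: the output size defaults to the largest index in the whole `index` tensor plus one, empty slots are filled with `0`, and their `argmax` is set to the size of `src` along the reduced dimension (here `5`). Resolving ties in favor of the first occurrence is an assumption of this sketch only, and `scatter_max_ref` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def scatter_max_ref(src: Sequence[Sequence[int]], index: Sequence[Sequence[int]]
                    ) -> Tuple[List[List[int]], List[List[int]]]:
    """Row-wise scatter max along the last dimension of 2-D nested lists."""
    # The output size is shared by all rows: largest index in the whole tensor + 1.
    dim_size = max(max(row) for row in index) + 1
    out, argmax = [], []
    for src_row, index_row in zip(src, index):
        fill = len(src_row)  # "no element" marker used for empty slots
        best = [None] * dim_size
        arg = [fill] * dim_size
        for pos, (value, group) in enumerate(zip(src_row, index_row)):
            if best[group] is None or value > best[group]:
                best[group] = value
                arg[group] = pos
        out.append([0 if b is None else b for b in best])
        argmax.append(arg)
    return out, argmax


src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
out, argmax = scatter_max_ref(src, index)
print(out)     # [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
print(argmax)  # [[5, 5, 3, 4, 0, 1], [1, 4, 3, 5, 5, 5]]
```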

## Running tests

```
pytest
```

## C++ API

`torch-scatter` also offers a C++ API that contains C++ equivalents of the Python operations.

```
mkdir build
cd build
# Pass -DWITH_CUDA=on to cmake to build with CUDA support, if needed
cmake ..
make
make install
```