<h1 align="center">PyGAS: Auto-Scaling GNNs in PyG</h1>

<img width="100%" src="https://raw.githubusercontent.com/rusty1s/pyg_autoscale/master/figures/overview.png?token=ABU7ZAXZ7WT3RIOSYHIDIVDAEI3SY" />

--------------------------------------------------------------------------------

*PyGAS* is the practical realization of our *<ins>G</ins>NN<ins>A</ins>uto<ins>S</ins>cale* (GAS) framework, which scales arbitrary message-passing GNNs to large graphs.
GAS prunes entire sub-trees of the computation graph by utilizing historical embeddings from prior training iterations, leading to constant GPU memory consumption with respect to the number of input nodes without dropping any data.
As a result, our approach is provably able to maintain the expressive power of the original GNN.
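
To make the pruning concrete, here is a minimal sketch in plain Python (not part of *PyGAS*; the toy graph and the helper `k_hop_nodes` are hypothetical). It compares the nodes a 3-layer GNN would have to touch for one mini-batch with and without historical embeddings, assuming that histories let each layer stop at the direct neighbors of the batch:

```python
from typing import Dict, List, Set


def k_hop_nodes(adj: Dict[int, List[int]], batch: Set[int], num_hops: int) -> Set[int]:
    """Return all nodes within `num_hops` hops of `batch` (batch included)."""
    visited = set(batch)
    frontier = set(batch)
    for _ in range(num_hops):
        # Expand by one hop and keep only nodes not seen before.
        frontier = {v for u in frontier for v in adj[u]} - visited
        visited |= frontier
    return visited


# Hypothetical toy graph: a path 0 - 1 - 2 - 3 - 4 - 5.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}
batch = {0}
num_layers = 3

# Without histories, each layer adds one more hop to the receptive field.
full_receptive_field = k_hop_nodes(adj, batch, num_layers)

# With histories, out-of-batch neighbors are read from stored embeddings,
# so only the 1-hop neighborhood of the batch is needed.
with_histories = k_hop_nodes(adj, batch, 1)
```

On real graphs the full receptive field typically grows much faster with depth than on this path graph, which is where the memory savings come from.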

*PyGAS* is implemented in [PyTorch](https://pytorch.org/) and utilizes the [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric) (PyG) library.
It provides an easy-to-use interface to convert common and custom GNNs from PyG into their scalable variants:

```python
from torch.nn import ModuleList
from torch_geometric.nn import GCNConv
from torch_geometric_autoscale import ScalableGNN


class GNN(ScalableGNN):
    def __init__(self, num_nodes, in_channels, hidden_channels, out_channels, num_layers):
        super(GNN, self).__init__(num_nodes, hidden_channels, num_layers)

        self.convs = ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels))

    def forward(self, x, adj_t, batch_size, n_id):
        for conv, history in zip(self.convs[:-1], self.histories):
            x = conv(x, adj_t).relu_()
            x = self.push_and_pull(history, x, batch_size, n_id)
        return self.convs[-1](x, adj_t)
```
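
The role of `push_and_pull` in the loop above can be illustrated with a toy version in plain Python. This is a sketch under assumptions, not the library's implementation: `ToyHistory` and `toy_push_and_pull` are hypothetical names, and it assumes that the first `batch_size` rows of `x` belong to the in-batch nodes while the remaining rows belong to their out-of-batch neighbors, with `n_id` holding the global node index of each row.

```python
from typing import List

Embedding = List[float]


class ToyHistory:
    """Hypothetical stand-in for a per-layer history: one stored embedding per node."""

    def __init__(self, num_nodes: int, dim: int) -> None:
        self.emb: List[Embedding] = [[0.0] * dim for _ in range(num_nodes)]

    def push(self, x: List[Embedding], n_id: List[int]) -> None:
        # Overwrite the stored embeddings of the given nodes.
        for row, node in zip(x, n_id):
            self.emb[node] = list(row)

    def pull(self, n_id: List[int]) -> List[Embedding]:
        # Read (possibly stale) embeddings of the given nodes.
        return [list(self.emb[node]) for node in n_id]


def toy_push_and_pull(history: ToyHistory, x: List[Embedding],
                      batch_size: int, n_id: List[int]) -> List[Embedding]:
    # Push: store the freshly computed in-batch embeddings.
    history.push(x[:batch_size], n_id[:batch_size])
    # Pull: replace out-of-batch rows with their historical embeddings.
    return x[:batch_size] + history.pull(n_id[batch_size:])


history = ToyHistory(num_nodes=4, dim=2)
history.push([[9.0, 9.0]], [3])  # pretend node 3 was computed in an earlier batch

x = [[1.0, 1.0], [2.0, 2.0], [0.5, 0.5]]  # rows for nodes 0, 1 (in-batch) and 3
out = toy_push_and_pull(history, x, batch_size=2, n_id=[0, 1, 3])
```

After the call, the rows for nodes 0 and 1 are unchanged and stored in the history, while the row for node 3 has been replaced by the embedding pushed earlier.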

## Installation

* Install [**PyTorch >= 1.7.0**](https://pytorch.org/get-started/locally/)
* Install [**PyTorch Geometric**](https://github.com/rusty1s/pytorch_geometric#pytorch-170171) from **master**:

```
pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
pip install git+https://github.com/rusty1s/pytorch_geometric.git
```

where `${CUDA}` should be replaced by one of `cpu`, `cu92`, `cu101`, `cu102`, or `cu110`, depending on your PyTorch installation.
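
If you are unsure which value to use, the following sketch derives it from a CUDA version string. The helper `cuda_tag` is hypothetical and not part of any package; it assumes you pass in the value of `torch.version.cuda`, which is a string such as `"10.2"` for CUDA builds and `None` for CPU-only builds.

```python
from typing import Optional

# Wheel index URL pattern taken from the install commands above.
WHEEL_INDEX = "https://pytorch-geometric.com/whl/torch-1.7.0+{tag}.html"


def cuda_tag(cuda_version: Optional[str]) -> str:
    """Map a CUDA version string (or None for CPU-only) to a wheel tag."""
    if cuda_version is None:
        return "cpu"
    # "10.2" -> "cu102", "9.2" -> "cu92"
    return "cu" + cuda_version.replace(".", "")


# In practice: tag = cuda_tag(torch.version.cuda) after `import torch`.
tag = cuda_tag("10.2")
index_url = WHEEL_INDEX.format(tag=tag)
```

Check that the resulting tag is one of the values listed above before using it, since wheels are only provided for those builds.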

Then, run:

```
python setup.py install
```

## Project Structure

* **`torch_geometric_autoscale/`** contains the source code of *PyGAS*
* **`examples/`** contains examples to demonstrate how to apply GAS in practice
* **`small_benchmark/`** includes experiments to evaluate GAS performance on *small-scale* graphs
* **`large_benchmark/`** includes experiments to evaluate GAS performance on *large-scale* graphs

We use [**Hydra**](https://hydra.cc/) to manage hyperparameter configurations.