Commit a44bec9e authored by rusty1s

description

parent 0ac78d87
--------------------------------------------------------------------------------
*PyGAS* is the practical realization of our *<ins>G</ins>NN<ins>A</ins>uto<ins>S</ins>cale* (GAS) framework, which scales arbitrary message-passing GNNs to large graphs.
*PyGAS* allows for training and inference of GNNs with a constant GPU memory footprint, without dropping any input data, in contrast to related scalability approaches.
Our approach is based on historical node embeddings and provably maintains the expressive power of the underlying message-passing implementation.
*PyGAS* is implemented in [PyTorch](https://pytorch.org/) and utilizes the [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric) (PyG) library.
It provides an easy-to-use interface to convert common and custom GNNs from PyG into their scalable variants:
```python
from torch.nn import ModuleList
from torch_geometric.nn import GCNConv
from torch_geometric_autoscale import ScalableGNN

class GNN(ScalableGNN):
    def __init__(self, num_nodes, in_channels, hidden_channels, out_channels,
                 num_layers):
        # `num_nodes` and `hidden_channels` define the shape of the
        # historical embedding buffers maintained by `ScalableGNN`.
        super().__init__(num_nodes, hidden_channels, num_layers)

        self.convs = ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels))

    def forward(self, x, adj, batch_size, n_id):
        for conv, history in zip(self.convs[:-1], self.histories):
            x = conv(x, adj).relu_()
            # Push fresh in-batch embeddings to the history and pull
            # (stale) historical embeddings for out-of-batch nodes:
            x = self.push_and_pull(history, x, batch_size, n_id)
        return self.convs[-1](x, adj)
```
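
To make the mechanism behind `push_and_pull` concrete, here is a minimal, library-independent sketch of the historical-embedding idea in plain PyTorch; the `history` buffer and the split of `n_id` into in-batch and out-of-batch nodes are illustrative assumptions, not the internals of `torch_geometric_autoscale`:

```python
import torch

num_nodes, hidden_channels = 100, 16

# One historical embedding per node, kept outside the mini-batch
# computation (illustrative stand-in for a history buffer).
history = torch.zeros(num_nodes, hidden_channels)

n_id = torch.tensor([0, 4, 7, 12, 31, 58])  # nodes touched by the mini-batch
batch_size = 3                              # first 3 entries are in-batch nodes

x = torch.randn(n_id.numel(), hidden_channels)  # embeddings after one GNN layer

# Push: store the freshly computed embeddings of the in-batch nodes.
history[n_id[:batch_size]] = x[:batch_size]

# Pull: replace out-of-batch node embeddings with their (stale) histories,
# so no neighbor information is dropped.
x[batch_size:] = history[n_id[batch_size:]]
```

In the full framework, these buffers are held outside GPU memory, which is what keeps the GPU memory footprint constant while still using all neighbor information.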
## Installation
## Project Structure