# Stochastic Training for Graph Convolutional Networks

* Paper: [Control Variate](https://arxiv.org/abs/1710.10568)
* Paper: [Skip Connection](https://arxiv.org/abs/1809.05343)
* Author's code: [https://github.com/thu-ml/stochastic_gcn](https://github.com/thu-ml/stochastic_gcn)

### Dependencies

- MXNet nightly build

```bash
pip install mxnet --pre
```

### Neighbor Sampling & Skip Connection
cora: test accuracy ~83% with `--num-neighbors 2`, ~84% by training on the full graph
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model gcn_ns --dataset cora --self-loop --num-neighbors 2 --batch-size 1000 --test-batch-size 5000
```

citeseer: test accuracy ~69% with `--num-neighbors 2`, ~70% by training on the full graph
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model gcn_ns --dataset citeseer --self-loop --num-neighbors 2 --batch-size 1000 --test-batch-size 5000
```

pubmed: test accuracy ~78% with `--num-neighbors 3`, ~77% by training on the full graph
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model gcn_ns --dataset pubmed --self-loop --num-neighbors 3 --batch-size 1000 --test-batch-size 5000
```

reddit: test accuracy ~91% with `--num-neighbors 3` and `--batch-size 1000`, ~93% by training on the full graph
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model gcn_ns --dataset reddit-self-loop --num-neighbors 3 --batch-size 1000 --test-batch-size 5000 --n-hidden 64
```
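
For intuition, the sketch below shows in plain numpy (not the DGL sampler actually used by `train.py`) what `--num-neighbors` controls: each mini-batch node aggregates over a small random subset of its neighbors instead of the full neighborhood. The adjacency-list format and function name are illustrative assumptions.

```python
# Illustrative sketch of fixed-size neighbor sampling (NOT the DGL sampler used
# by train.py). For each seed node keep at most `k` randomly chosen neighbors,
# so each graph convolution layer aggregates over a bounded number of inputs.
import numpy as np

def sample_neighbors(adj_list, seed_nodes, k, rng=None):
    """adj_list: dict mapping node id -> list of neighbor ids (assumed format)."""
    rng = rng or np.random.default_rng()
    sampled = {}
    for v in seed_nodes:
        nbrs = adj_list[v]
        if len(nbrs) > k:
            nbrs = rng.choice(nbrs, size=k, replace=False).tolist()
        sampled[v] = list(nbrs)
    return sampled

# Example: with k=2, node 0 aggregates over only 2 of its 4 neighbors.
adj = {0: [1, 2, 3, 4], 1: [0], 2: [0, 3]}
print(sample_neighbors(adj, seed_nodes=[0, 2], k=2))
```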


### Control Variate & Skip Connection
cora: test accuracy ~84% with `--num-neighbors 1`, ~84% by training on the full graph
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model gcn_cv --dataset cora --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
```

citeseer: test accuracy ~69% with `--num-neighbors 1`, ~70% by training on the full graph
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model gcn_cv --dataset citeseer --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
```

pubmed: test accuracy ~79% with `--num-neighbors 1`, ~77% by training on the full graph
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model gcn_cv --dataset pubmed --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
```

reddit: test accuracy ~93% with `--num-neighbors 1` and `--batch-size 1000`, ~93% by training on the full graph
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model gcn_cv --dataset reddit-self-loop --num-neighbors 1 --batch-size 10000 --test-batch-size 5000 --n-hidden 64
```
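
The control-variate estimator keeps a cached copy of each node's historical activations: sampled neighbors contribute only the change relative to their history, while the cheap cached activations are aggregated over all neighbors. Below is a minimal numpy sketch of that idea; the variable names and dense-matrix formulation are illustrative assumptions, not the code in `train.py`.

```python
# Minimal numpy sketch of the control-variate aggregation (illustrative only).
import numpy as np

def cv_aggregate(P_full, P_sampled, H, H_hist):
    """
    P_full:    normalized adjacency over all neighbors (n x n).
    P_sampled: normalized adjacency restricted to the sampled neighbors,
               rescaled to remain an unbiased estimate of P_full (n x n).
    H:         current activations, recomputed only for sampled neighbors.
    H_hist:    cached historical activations from earlier iterations.
    """
    # Sampled neighbors contribute the delta w.r.t. their history; the history
    # itself is propagated over the full graph, which is cheap since it is cached.
    return P_sampled.dot(H - H_hist) + P_full.dot(H_hist)

# After each mini-batch, the history of the recomputed nodes is refreshed with
# their new activations, so H_hist slowly tracks H and the estimator's variance
# shrinks over training.
```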

### Control Variate & GraphSAGE-mean

Following [Control Variate](https://arxiv.org/abs/1710.10568), we use the GraphSAGE-mean architecture (mean pooling of neighbor features), with two linear layers and layer normalization in each graph convolution layer.
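
As a rough illustration of that layer structure (not the exact implementation in `examples/mxnet/sampling/`; the class and argument names are assumptions), one such layer could be written in MXNet Gluon as:

```python
# Illustrative MXNet Gluon sketch of one GraphSAGE-mean layer as described above:
# one linear transform for the node's own features, another for the mean of its
# sampled neighbors' features, followed by layer normalization.
import mxnet as mx
from mxnet.gluon import nn

class SAGEMeanLayer(nn.Block):
    def __init__(self, out_feats, activation=None, **kwargs):
        super(SAGEMeanLayer, self).__init__(**kwargs)
        with self.name_scope():
            self.fc_self = nn.Dense(out_feats)    # linear layer on self features
            self.fc_neigh = nn.Dense(out_feats)   # linear layer on neighbor mean
            self.norm = nn.LayerNorm()            # layer normalization
        self.activation = activation

    def forward(self, h_self, h_neigh_mean):
        # h_self:       (batch, in_feats) features of the nodes being updated
        # h_neigh_mean: (batch, in_feats) mean of their sampled neighbors' features
        h = self.fc_self(h_self) + self.fc_neigh(h_neigh_mean)
        h = self.norm(h)
        if self.activation is not None:
            h = self.activation(h)
        return h
```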

reddit: test accuracy 96.1% with `--num-neighbors 1` and `--batch-size 1000`; [Control Variate](https://arxiv.org/abs/1710.10568) reports ~96.2% with `--num-neighbors 2` and `--batch-size 1000`
```
DGLBACKEND=mxnet python3 examples/mxnet/sampling/train.py --model graphsage_cv --batch-size 1000 --test-batch-size 5000 --n-epochs 50 --dataset reddit --num-neighbors 1 --n-hidden 128 --dropout 0.2 --weight-decay 0
```

### Run multi-processing training

Start the graph store server, which loads the reddit dataset and serves four trainer processes.
```
python3 examples/mxnet/sampling/run_store_server.py --dataset reddit --num-workers 4
```

Then launch four trainer processes to train GraphSAGE on the reddit dataset.
```
python3 ../incubator-mxnet/tools/launch.py -n 4 -s 1 --launcher local python3 examples/mxnet/sampling/multi_process_train.py --model graphsage_cv --batch-size 1000 --test-batch-size 5000 --n-epochs 1 --graph-name reddit --num-neighbors 1 --n-hidden 128 --dropout 0.2 --weight-decay 0
```