# Stochastic Training for Graph Convolutional Networks

* Paper: [Control Variate](https://arxiv.org/abs/1710.10568)
* Paper: [Skip Connection](https://arxiv.org/abs/1809.05343)
* Author's code: [https://github.com/thu-ml/stochastic_gcn](https://github.com/thu-ml/stochastic_gcn)

### Dependencies

- MXNet nightly build

```bash
pip install mxnet --pre
```

### Neighbor Sampling & Skip Connection
cora: test accuracy ~83% with `--num-neighbors 2`, ~84% by training on the full graph
```
DGLBACKEND=mxnet python3 train.py --model gcn_ns --dataset cora --self-loop --num-neighbors 2 --batch-size 1000 --test-batch-size 5000
```

citeseer: test accuracy ~69% with `--num-neighbors 2`, ~70% by training on the full graph
```
DGLBACKEND=mxnet python3 train.py --model gcn_ns --dataset citeseer --self-loop --num-neighbors 2 --batch-size 1000 --test-batch-size 5000
```

pubmed: test accuracy ~78% with `--num-neighbors 3`, ~77% by training on the full graph
```
DGLBACKEND=mxnet python3 train.py --model gcn_ns --dataset pubmed --self-loop --num-neighbors 3 --batch-size 1000 --test-batch-size 5000
```

reddit: test accuracy ~91% with `--num-neighbors 3` and `--batch-size 1000`, ~93% by training on the full graph
```
DGLBACKEND=mxnet python3 train.py --model gcn_ns --dataset reddit-self-loop --num-neighbors 3 --batch-size 1000 --test-batch-size 5000 --n-hidden 64
```
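
The `--num-neighbors` flag limits how many neighbors each node aggregates from. As a rough illustration of the idea (not the code in `train.py`), the sketch below estimates a node's mean neighbor feature from a random subset of its neighbors. The toy graph, the scalar features, and the helper name `sampled_mean` are all assumptions made for this example.

```python
import random

def sampled_mean(adj, feats, node, num_neighbors, rng):
    """Estimate the mean neighbor feature of `node` from a random subset."""
    neighbors = adj[node]
    # Sample without replacement; use every neighbor if there are too few.
    k = min(num_neighbors, len(neighbors))
    chosen = rng.sample(neighbors, k)
    return sum(feats[u] for u in chosen) / k

# Toy graph (hypothetical): node 0 has four neighbors with scalar features.
adj = {0: [1, 2, 3, 4]}
feats = {1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0}

rng = random.Random(0)
exact = sum(feats[u] for u in adj[0]) / len(adj[0])
estimates = [sampled_mean(adj, feats, 0, 2, rng) for _ in range(10000)]
average_estimate = sum(estimates) / len(estimates)
print(exact, round(average_estimate, 2))
```

Each individual estimate is noisy, but the average over many draws stays close to the exact mean, which is why a small `--num-neighbors` can still train a usable model.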


### Control Variate & Skip Connection
cora: test accuracy ~84% with `--num-neighbors 1`, ~84% by training on the full graph
```
DGLBACKEND=mxnet python3 train.py --model gcn_cv --dataset cora --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
```

citeseer: test accuracy ~69% with `--num-neighbors 1`, ~70% by training on the full graph
```
DGLBACKEND=mxnet python3 train.py --model gcn_cv --dataset citeseer --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
```

pubmed: test accuracy ~79% with `--num-neighbors 1`, ~77% by training on the full graph
```
DGLBACKEND=mxnet python3 train.py --model gcn_cv --dataset pubmed --self-loop --num-neighbors 1 --batch-size 1000000 --test-batch-size 1000000
```

reddit: test accuracy ~93% with `--num-neighbors 1` and `--batch-size 1000`, ~93% by training on the full graph
```
DGLBACKEND=mxnet python3 train.py --model gcn_cv --dataset reddit-self-loop --num-neighbors 1 --batch-size 10000 --test-batch-size 5000 --n-hidden 64
```
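
The control variate approach in the paper linked at the top keeps a stored (historical) copy of each node's activation and samples only the difference between the current and stored values. The following is a minimal sketch of that estimator, assuming mean aggregation and scalar activations. It is not the `gcn_cv` implementation, and `cv_mean`, `h_hist`, and the toy values are hypothetical.

```python
import random

def cv_mean(neighbors, h, h_hist, num_neighbors, rng):
    """Control-variate estimate of the mean of h over `neighbors`."""
    # Exact mean of the stored (historical) activations over all neighbors.
    hist_mean = sum(h_hist[u] for u in neighbors) / len(neighbors)
    # Sampled mean of the difference between current and stored activations.
    k = min(num_neighbors, len(neighbors))
    chosen = rng.sample(neighbors, k)
    delta_mean = sum(h[u] - h_hist[u] for u in chosen) / k
    return hist_mean + delta_mean

neighbors = [1, 2, 3, 4]
h = {1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0}        # current activations
h_hist = {1: 0.9, 2: 2.1, 3: 2.9, 4: 4.1}   # slightly stale stored copies

rng = random.Random(0)
exact = sum(h[u] for u in neighbors) / len(neighbors)
# Error with one sampled neighbor: control variate vs. plain sampling.
cv_errors = [abs(cv_mean(neighbors, h, h_hist, 1, rng) - exact) for _ in range(1000)]
plain_errors = [abs(h[rng.choice(neighbors)] - exact) for _ in range(1000)]
print(max(cv_errors), max(plain_errors))
```

When the stored activations are close to the current ones, the sampled difference is small, so even a single sampled neighbor gives a much tighter estimate than plain sampling in this toy setup.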

### Control Variate & GraphSAGE-mean

Following [Control Variate](https://arxiv.org/abs/1710.10568), we use the GraphSAGE-mean (mean pooling) architecture, with two linear layers and layer normalization in each graph convolution layer.
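
One way to picture such a layer is the pure-Python sketch below. It is only an illustration under stated assumptions: the self/neighbor concatenation, the ReLU between the two linear layers, the placement of layer normalization at the end, and all function names and weights are hypothetical choices for this example, not taken from `train.py`.

```python
import math

def linear(x, weight, bias):
    """Dense layer; `weight` is a list of rows, one row per output unit."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sage_mean_layer(h_self, h_neighbors, w1, b1, w2, b2):
    """Mean-pool neighbors, concatenate with self, two linear layers, layer norm."""
    dim = len(h_self)
    pooled = [sum(h[i] for h in h_neighbors) / len(h_neighbors) for i in range(dim)]
    x = h_self + pooled  # concatenation (assumed ordering: self, then neighbors)
    hidden = [max(0.0, v) for v in linear(x, w1, b1)]  # ReLU (assumption)
    return layer_norm(linear(hidden, w2, b2))

# Toy inputs and weights (hypothetical).
h_self = [1.0, 0.0]
h_neighbors = [[0.0, 2.0], [2.0, 4.0]]
w1 = [[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1, 0, 0], [0, 0, 1]]
b2 = [0.0, 0.0]

out = sage_mean_layer(h_self, h_neighbors, w1, b1, w2, b2)
print(out)
```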

reddit: test accuracy 96.1% with `--num-neighbors 1` and `--batch-size 1000`, ~96.2% reported in [Control Variate](https://arxiv.org/abs/1710.10568) with `--num-neighbors 2` and `--batch-size 1000`
```
DGLBACKEND=mxnet python3 train.py --model graphsage_cv --batch-size 1000 --test-batch-size 5000 --n-epochs 50 --dataset reddit --num-neighbors 1 --n-hidden 128 --dropout 0.2 --weight-decay 0
```

### Run multi-processing training

When training a GNN model with multiple processes, there are two steps.

Step 1: start a separate graph store server that loads the reddit dataset for four workers.
```
python3 examples/mxnet/sampling/run_store_server.py --dataset reddit --num-workers 4
```

Step 2: run four workers to train GraphSAGE on the reddit dataset.
```
python3 ../incubator-mxnet/tools/launch.py -n 4 -s 1 --launcher local python3 multi_process_train.py --model graphsage_cv --batch-size 2500 --test-batch-size 5000 --n-epochs 1 --graph-name reddit --num-neighbors 1 --n-hidden 128 --dropout 0.2 --weight-decay 0
```