Unverified commit aef96dfa authored by K, committed by GitHub

[Model] Refine GraphSAINT (#3328)



* Start of Jiahang Li's experiments on GraphSAINT.

* a nightly build

* a nightly build

Check the basic code pipeline. Next, check the details of the samplers, the GCN layer (forward propagation) and the loss (backward propagation).

* a nightly build

* Implement GraphSAINT with torch DataLoader

There are still some bugs with sampling in the training procedure.

* Test validity

Succeeded in validating the ppi_node experiments; the other setups have not been tested yet.
1. Online sampling performs well on the ppi_node experiments.
2. Sampling is a bit slow because of the `dgl.subgraph` operations; the next step is to speed this part up by parallelizing the conversion.
3. Figure out why the offline+online sampling method performs poorly, which does not make sense.
4. Run experiments on the other setups.

* Implement SAINT with torch DataLoader

Use a torch DataLoader to speed up SAINT sampling in the experiments. Apart from the very large Amazon dataset, we have run experiments on the other four datasets: ppi, flickr, reddit and yelp. Preliminary results show that both runtimes and metrics reach a reasonable level. The next step is to use a more accurate profiler (line_profiler) to measure where time is spent, and to adjust num_workers to speed up sampling on certain datasets.
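For reference, a minimal line_profiler sketch of the kind of measurement described above; the profiled target and the helper below are illustrative, not the actual setup used in these experiments:

```python
from line_profiler import LineProfiler

from sampler import SAINTNodeSampler  # illustrative target; any of the SAINT samplers works

def profile_presampling(build_sampler):
    # Report per-line timings for the sampler's __sample__ method while the
    # pre-sampling loop runs inside the sampler's __init__.
    lp = LineProfiler()
    lp.add_function(SAINTNodeSampler.__sample__)
    lp(build_sampler)()  # build_sampler() should construct the sampler
    lp.print_stats()
```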

* a nightly build

* Update .gitignore

* reorganize code

Reorganize some code and comments.

* a nightly build

* Update .gitignore

* fix bugs

Fix the bugs that prevented fully offline sampling and the author's version from working.

* reorganize files and code

Reorganize files and code, then run some experiments to test the performance of offline sampling and online sampling.

* do some experiments and update README

* a nightly build

* a nightly build

* Update README.md

* delete unnecessary files

* Update README.md

* a nightly update

1. handle the directory named 'graphsaintdata'
2. control moving the graph between GPU and CPU for the large dataset ('amazon')
3. remove the parameter 'train'
4. refine the comments of the sampler
5. update README.md, including dataset info, dependency info, etc.

* a nightly update

explain the config differences in the TEST part
remove a sampling time variant
make 'online' an argument
change 'norm' to 'sampler'
explain the parameters in README.md

* Update README.md

* a nightly build

* make online an argument
* refine README.md
* refine the `collate_fn` code in sampler.py: in the training phase only one subgraph is returned, so there is no need to check whether the number of subgraphs is larger than 1

* Update sampler.py

Check that the problem on flickr is caused by overfitting.

* a nightly update

Fix the overfitting problem on the `flickr` dataset. We need to restrict the number of subgraphs (and thus the number of iterations) used in each training epoch; otherwise the model may already be overfitting by the time we validate at the end of each epoch. The limit follows a formula specified by the author (see the sketch below).

* Set up a new flag `full` specifying whether the number of subgraphs used in the training phase equals the number of pre-sampled subgraphs

* Modify code and comments related to the new flag

* Add a new parameter `node_budget` to the base class `SAINTSampler` to compute the formula
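A minimal sketch of that limit, assuming only the formula quoted above; the helper name and the numbers are illustrative:

```python
import math

def num_subgraphs_per_epoch(num_train_nodes, node_budget):
    # ceil(number of training nodes / expected nodes per subgraph);
    # e.g. 44,625 training nodes with node_budget=8000 give 6 subgraphs per epoch.
    return math.ceil(num_train_nodes / node_budget)
```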

* set `gpu` as a command line argument

* Update README.md

* Finish the experiments on Flickr after adding the new flag `full`

* a nightly update

* use half of the edges in the original graph for sampling
* test dgl.random.choice with and without replacement on half of the edges
~ next: test whether moving the probability computation out of __getitem__ speeds up sampling, and try to implement the author's sampling method

* employ Cython to implement edge sampling per edge

* employ Cython to implement edge sampling per edge
* run experiments to test the time consumed and the performance
** the time consumed decreased to approximately 480s, while the performance dropped by about 5 points
* deprecate the Cython implementation

* Revert "employ cython to implement edge sampling for per edge"

* This reverts commit 4ba4f092
* Deprecate the Cython implementation
* Keep the half-edges mechanism

* a nightly update

* delete unnecessary comments
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent f9fd7fd7
...@@ -161,3 +161,4 @@ cscope.*
config.cmake
.ycm_extra_conf.py
**.png
...@@ -6,94 +6,163 @@ Paper link: https://arxiv.org/abs/1907.04931
Author's code: https://github.com/GraphSAINT/GraphSAINT

Contributors: Jiahang Li ([@ljh1064126026](https://github.com/ljh1064126026)), Tang Liu ([@lt610](https://github.com/lt610))

## Dependencies
- Python 3.7.10
- PyTorch 1.8.1
- NumPy 1.19.2
- Scikit-learn 0.23.2
- DGL 0.7.1

## Dataset
All datasets used are provided by the author's [code](https://github.com/GraphSAINT/GraphSAINT). They are available on [Google Drive](https://drive.google.com/drive/folders/1zycmmDES39zVlbVCYs88JTJ1Wm5FbfLz) (alternatively, [Baidu Wangpan (code: f1ao)](https://pan.baidu.com/s/1SOb0SiSAXavwAcNqkttwcg#list/path=%2F)). Dataset summary ("m" stands for multi-label binary classification, "s" for single-label classification):

| Dataset | Nodes | Edges | Degree | Feature | Classes |
| :-: | :-: | :-: | :-: | :-: | :-: |
| PPI | 14,755 | 225,270 | 15 | 50 | 121 (m) |
| Flickr | 89,250 | 899,756 | 10 | 500 | 7 (s) |
| Reddit | 232,965 | 11,606,919 | 50 | 602 | 41 (s) |
| Yelp | 716,847 | 6,977,410 | 10 | 300 | 100 (m) |
| Amazon | 1,598,960 | 132,169,734 | 83 | 200 | 107 (m) |

Note that the PPI dataset here is different from DGL's built-in variant.
## Config
- The config file is `config.py`, which contains the best configs for the experiments below (see the loading sketch after the table).
- Please refer to `sampler.py` for explanations of some key parameters.

### Parameters

| **aggr** | **arch** | **dataset** | **dropout** |
| --- | --- | --- | --- |
| how to aggregate each node's embedding with its neighbors' embeddings, either 'concat' or 'mean'. The neighbors' embeddings are generated by GCN-style message passing | network architecture, e.g. '1-1-0' means three layers: the first two perform message passing on the graph and aggregate each node's embedding with its neighbors', while the last layer only updates each node's own embedding. The message passing mechanism comes from GCN | the name of the dataset, one of 'ppi', 'flickr', 'reddit', 'yelp', 'amazon' | the dropout rate of the model used in train_sampling.py |
| **edge_budget** | **gpu** | **length** | **log_dir** |
| the expected number of edges in each subgraph, as specified in the paper | -1 means CPU; otherwise the device is 'cuda:{gpu}', e.g. gpu=0 uses 'cuda:0' | the length of each random walk | the directory storing logs |
| **lr** | **n_epochs** | **n_hidden** | **no_batch_norm** |
| learning rate | number of training epochs | hidden dimension | True to disable batch normalization in each layer |
| **node_budget** | **num_subg** | **num_roots** | **sampler** |
| the expected number of nodes in each subgraph, as specified in the paper | the expected number of pre-sampled subgraphs | the number of roots used to generate random walks | which sampler to use: 'node', 'edge' or 'rw', corresponding to the node, edge and random walk samplers |
| **use_val** | **val_every** | **num_workers_sampler** | **num_subg_sampler** |
| True to test with the best model stored by the early-stopping mechanism | validate every 'val_every' epochs | number of worker processes of the internal dataloader in SAINTSampler, which pre-samples subgraphs | the maximum number of pre-sampled subgraphs |
| **batch_size_sampler** | **num_workers** | | |
| batch size of the internal dataloader in SAINTSampler | number of worker processes of the external dataloader in train_sampling.py, which samples subgraphs in the training phase | | |
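The configs themselves live in `config.py` as a plain dict keyed by task name. A minimal sketch of how a config entry is consumed (this mirrors what `train_sampling.py` does; the task key `'ppi_n'` is just an example):

```python
import argparse

from config import CONFIG  # per-task hyper-parameter dicts, e.g. CONFIG['ppi_n']

# Turn the chosen task's config into an argparse-style namespace so the rest of
# the code can read args.lr, args.n_hidden, args.sampler, and so on.
args = argparse.Namespace(**CONFIG['ppi_n'])
print(args.sampler, args.node_budget, args.n_epochs)
```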
## Minibatch training
Run with the following:
```bash
python train_sampling.py --task $task $online
# online sampling: e.g. python train_sampling.py --task ppi_n --online
# offline sampling: e.g. python train_sampling.py --task flickr_e
```
- `$task` is one of `ppi_n, ppi_e, ppi_rw, flickr_n, flickr_e, flickr_rw, reddit_n, reddit_e, reddit_rw, yelp_n, yelp_e, yelp_rw, amazon_n, amazon_e, amazon_rw`. For example, `ppi_n` runs experiments on the dataset `ppi` with the node sampler.
- If `$online` is `--online`, we sample subgraphs on the fly in the training phase and discard the pre-sampled subgraphs. If `$online` is empty, we use the pre-sampled subgraphs in the training phase (see the sketch below).
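Under the hood, the chosen sampler is wrapped in a standard `torch.utils.data.DataLoader`. A condensed sketch of what `train_sampling.py` sets up is shown below; the helper name is hypothetical, the argument values are the `ppi_n` defaults, and `g`/`train_nid` come from `load_data`:

```python
from torch.utils.data import DataLoader
from sampler import SAINTNodeSampler

def build_ppi_node_loader(g, train_nid):
    # The keyword arguments mirror the config entries described above.
    saint_sampler = SAINTNodeSampler(
        node_budget=6000, dn='ppi', g=g, train_nid=train_nid,
        num_workers_sampler=0, num_subg_sampler=10000, batch_size_sampler=200,
        online=True, num_subg=50, full=True)
    # batch_size=1: one subgraph per training iteration; __collate_fn__ unwraps it.
    return DataLoader(saint_sampler, collate_fn=saint_sampler.__collate_fn__,
                      batch_size=1, shuffle=True, num_workers=8, drop_last=False)
```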
## Experiments
* Paper: results from the paper
* Running: results from experiments with the authors' code
* DGL: results from experiments with the DGL example. The experiment config comes from `config.py`. You can modify the parameters in `config.py` to see how different setups perform.
> Note that we implement both offline and online sampling in the training phase. Offline sampling means all subgraphs used in the training phase come from the pre-sampled subgraphs. Online sampling means we discard all pre-sampled subgraphs and re-sample new subgraphs during training.
> Note that the pre-sampling phase always uses offline sampling.

### F1-micro

#### Random node sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Paper | 0.960±0.001 | 0.507±0.001 | 0.962±0.001 | 0.641±0.000 | 0.782±0.004 |
| Running | 0.9628 | 0.5077 | 0.9622 | 0.6393 | 0.7695 |
| DGL_offline | 0.9715 | 0.5024 | 0.9645 | 0.6457 | 0.8051 |
| DGL_online | 0.9730 | 0.5071 | 0.9645 | 0.6444 | 0.8014 |

#### Random edge sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Paper | 0.981±0.007 | 0.510±0.002 | 0.966±0.001 | 0.653±0.003 | 0.807±0.001 |
| Running | 0.9810 | 0.5066 | 0.9656 | 0.6531 | 0.8071 |
| DGL_offline | 0.9817 | 0.5077 | 0.9655 | 0.6530 | 0.8034 |
| DGL_online | 0.9815 | 0.5041 | 0.9653 | 0.6516 | 0.7756 |

#### Random walk sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Paper | 0.981±0.004 | 0.511±0.001 | 0.966±0.001 | 0.653±0.003 | 0.815±0.001 |
| Running | 0.9812 | 0.5104 | 0.9648 | 0.6527 | 0.8131 |
| DGL_offline | 0.9833 | 0.5027 | 0.9582 | 0.6514 | 0.8178 |
| DGL_online | 0.9820 | 0.5110 | 0.9572 | 0.6508 | 0.8157 |
### Sampling time
- Sampling time here includes the time spent pre-sampling subgraphs and computing the normalization coefficients at the beginning (in seconds).

#### Random node sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Running | 1.46 | 3.49 | 19 | 59.01 | 978.62 |
| DGL | 2.51 | 1.12 | 27.32 | 60.15 | 929.24 |

#### Random edge sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Running | 1.4 | 3.18 | 13.88 | 39.02 | |
| DGL | 3.04 | 1.87 | 52.01 | 48.38 | |

#### Random walk sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Running | 1.7 | 3.82 | 16.97 | 43.25 | 355.68 |
| DGL | 3.05 | 2.13 | 11.01 | 22.23 | 151.84 |

## Test std of sampling and normalization time
- We ran each experiment 10 times to measure the average and standard deviation of the sampling and normalization time. Here we only measure time and do not train the model to the end. Moreover, for efficiency, the hardware and config used here are not the same as in the experiments above, so the sampling time may differ slightly from the numbers above. However, the environment is kept consistent across all experiments below.
> The only config values that differ from the section above are `num_workers_sampler`, `batch_size_sampler` and `num_workers`, which only affect the sampling speed. The other parameters are kept consistent across the two sections, so the model's performance is not affected.
> Each value is reported as (average, std), in seconds; a sketch of the measurement loop follows.
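A rough sketch of how such numbers can be collected; this is illustrative only, and `build_sampler` stands in for constructing one of the SAINT samplers (whose pre-sampling and normalization run inside `__init__`):

```python
import time
import numpy as np

def time_presampling(build_sampler, n_runs=10):
    # Time the sampler construction n_runs times and report (average, std) in seconds.
    durations = []
    for _ in range(n_runs):
        start = time.perf_counter()
        build_sampler()  # pre-sampling and normalization happen here
        durations.append(time.perf_counter() - start)
    return float(np.mean(durations)), float(np.std(durations))
```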
### Random node sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| ------------------------- | --------------- | ------------ | ------------- | ------------- | --------------- |
| DGL_Sampling(std) | 2.618, 0.004 | 3.017, 0.507 | 35.356, 2.363 | 69.913, 6.3 | 888.025, 16.004 |
| DGL_Normalization(std) | negligible | 0.008, 0.004 | 0.26, 0.047 | 0.189, 0.0288 | 2.443, 0.124 |
| | | | | | |
| author_Sampling(std) | 0.788, 0.661 | 0.728, 0.367 | 8.931, 3.155 | 27.818, 1.384 | 295.597, 4.928 |
| author_Normalization(std) | 0.665, 0.565 | 4.981, 2.952 | 17.231, 7.116 | 47.449, 2.794 | 279.241, 17.615 |
### Random edge sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| ------------------------- | --------------- | ------------ | ------------- | ------------- | ------ |
| DGL_Sampling(std) | 3.554, 0.292 | 4.722, 0.245 | 47.09, 2.76 | 75.219, 6.442 | |
| DGL_Normalization(std) | negligible | 0.005, 0.007 | 0.235, 0.026 | 0.193, 0.021 | |
| | | | | | |
| author_Sampling(std) | 0.802, 0.667 | 0.761, 0.387 | 6.058, 2.166 | 13.914, 1.864 | |
| author_Normalization(std) | 0.667, 0.570 | 5.180, 3.006 | 15.803, 5.867 | 44.278, 5.853 | |
### Random walk sampler
| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| ------------------------- | --------------- | ------------ | ------------- | ------------- | --------------- |
| DGL_Sampling(std) | 3.304, 0.08 | 5.487, 1.294 | 37.041, 2.083 | 39.951, 3.094 | 179.613, 18.881 |
| DGL_Normalization(std) | negligible | 0.001, 0.003 | 0.235, 0.026 | 0.185, 0.018 | 3.769, 0.326 |
| | | | | | |
| author_Sampling(std) | 0.924, 0.773 | 1.405, 0.718 | 8.608, 3.093 | 19.113, 1.700 | 217.184, 1.546 |
| author_Normalization(std) | 0.701, 0.596 | 5.025, 2.954 | 18.198, 7.223 | 45.874, 8.020 | 128.272, 3.170 |
# Best configs for each task; keys follow '<dataset>_<sampler abbreviation>', e.g. 'ppi_n' = ppi + node sampler.
CONFIG = {
'ppi_n':
{
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'ppi', 'dropout': 0, 'edge_budget': 4000, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 6000,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'ppi_e':
{
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'ppi', 'dropout': 0.1, 'edge_budget': 4000, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 6000,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'ppi_rw':
{
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'ppi', 'dropout': 0.1, 'edge_budget': 4000, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 6000,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'flickr_n':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'flickr', 'dropout': 0.2, 'edge_budget': 6000, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 256, 'no_batch_norm': False, 'node_budget': 8000,
'num_subg': 25, 'num_roots': 6000, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': False
},
'flickr_e':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'flickr', 'dropout': 0.2, 'edge_budget': 6000, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 256, 'no_batch_norm': False, 'node_budget': 8000,
'num_subg': 25, 'num_roots': 6000, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': False
},
'flickr_rw':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'flickr', 'dropout': 0.2, 'edge_budget': 6000, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 50, 'n_hidden': 256, 'no_batch_norm': False, 'node_budget': 8000,
'num_subg': 25, 'num_roots': 6000, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 0,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': False
},
'reddit_n':
{
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'reddit', 'dropout': 0.1, 'edge_budget': 4000, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 20, 'n_hidden': 128, 'no_batch_norm': False, 'node_budget': 8000,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'reddit_e':
{
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'reddit', 'dropout': 0.1, 'edge_budget': 6000, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 20, 'n_hidden': 128, 'no_batch_norm': False, 'node_budget': 8000,
'num_subg': 50, 'num_roots': 3000, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'reddit_rw':
{
'aggr': 'concat', 'arch': '1-0-1-0', 'dataset': 'reddit', 'dropout': 0.1, 'edge_budget': 6000, 'length': 4,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 128, 'no_batch_norm': False, 'node_budget': 8000,
'num_subg': 50, 'num_roots': 200, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'yelp_n':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'yelp', 'dropout': 0.1, 'edge_budget': 6000, 'length': 4,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000,
'num_subg': 50, 'num_roots': 200, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'yelp_e':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'yelp', 'dropout': 0.1, 'edge_budget': 2500, 'length': 4,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000,
'num_subg': 50, 'num_roots': 200, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'yelp_rw':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'yelp', 'dropout': 0.1, 'edge_budget': 2500, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000,
'num_subg': 50, 'num_roots': 1250, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 8,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'amazon_n':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'amazon', 'dropout': 0.1, 'edge_budget': 2500, 'length': 4,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 5, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 4500,
'num_subg': 50, 'num_roots': 200, 'sampler': 'node', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 4,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
},
'amazon_e':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'amazon', 'dropout': 0.1, 'edge_budget': 2000, 'gpu': 0, 'length': 4,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 10, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000,
'num_subg': 50, 'num_roots': 200, 'sampler': 'edge', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 20,
'num_subg_sampler': 5000, 'batch_size_sampler': 50, 'num_workers': 26, 'full': True
},
'amazon_rw':
{
'aggr': 'concat', 'arch': '1-1-0', 'dataset': 'amazon', 'dropout': 0.1, 'edge_budget': 2500, 'gpu': 0, 'length': 2,
'log_dir': 'none', 'lr': 0.01, 'n_epochs': 5, 'n_hidden': 512, 'no_batch_norm': False, 'node_budget': 5000,
'num_subg': 50, 'num_roots': 1500, 'sampler': 'rw', 'use_val': True, 'val_every': 1, 'num_workers_sampler': 4,
'num_subg_sampler': 10000, 'batch_size_sampler': 200, 'num_workers': 8, 'full': True
}
}
...@@ -32,7 +32,7 @@ class GCNLayer(nn.Module):
        for lin in self.lins:
            nn.init.xavier_normal_(lin.weight)

    def feat_trans(self, features, idx):  # linear transformation + activation + batch normalization
        h = self.lins[idx](features) + self.bias[idx]

        if self.act is not None:
...@@ -51,7 +51,7 @@ class GCNLayer(nn.Module):
        h_hop = [h_in]

        D_norm = g.ndata['train_D_norm'] if 'train_D_norm' in g.ndata else g.ndata['full_D_norm']
        for _ in range(self.order):  # forward propagation
            g.ndata['h'] = h_hop[-1]
            if 'w' not in g.edata:
                g.edata['w'] = th.ones((g.num_edges(),)).to(features.device)
......
import math
import os
import time

import torch as th
from torch.utils.data import DataLoader
import random
import numpy as np
import dgl.function as fn
import dgl
from dgl.sampling import random_walk, pack_traces
import scipy
# The base class of the samplers
class SAINTSampler:
    """
    Description
    -----------
    SAINTSampler implements the sampler described in GraphSAINT. It performs offline sampling in the
    pre-sampling phase and supports both fully offline and fully online sampling in the training phase.
    Users can conveniently set the sampler's 'online' parameter to choose between the two modes.

    Parameters
    ----------
    node_budget : int
        the expected number of nodes in each subgraph, as explained in the paper. In practice this
        parameter specifies how many times nodes are sampled from the original graph with replacement.
        The meaning of edge_budget is similar.
    dn : str
        name of the dataset.
    g : DGLGraph
        the full graph.
    train_nid : list
        ids of the training nodes.
    num_workers_sampler : int
        number of processes used to sample subgraphs in the pre-sampling procedure via a torch DataLoader.
    num_subg_sampler : int, optional
        the maximum number of subgraphs sampled in the pre-sampling phase for computing the normalization
        coefficients. This parameter is used as the ``__len__`` of the sampler in the pre-sampling phase.
        Please make sure that num_subg_sampler is no less than batch_size_sampler so that enough subgraphs can be sampled.
        Default: 10000
    batch_size_sampler : int, optional
        the number of subgraphs sampled concurrently by each process in the pre-sampling phase.
        Default: 200
    online : bool, optional
        If True, online sampling is employed in the training phase. Otherwise offline sampling is employed.
        Default: True
    num_subg : int, optional
        the expected number of sampled subgraphs in the pre-sampling phase.
        It is the 'N' in the original paper. Note that this parameter is different from num_subg_sampler:
        it only controls the number of pre-sampled subgraphs.
        Default: 50
    full : bool, optional
        If True, the number of subgraphs used in the training phase equals the number of pre-sampled subgraphs;
        otherwise it is ``math.ceil(self.train_g.num_nodes() / self.node_budget)``, i.e. the number of training
        nodes in the original graph divided by the expected number of nodes in each pre-sampled subgraph,
        rounded up. Please refer to the paper for details.
        Default: True

    Notes
    -----
    For parallel pre-sampling, we utilize a torch DataLoader to speed up the sampling.
    `num_subg_sampler` is the return value of `__len__` in the pre-sampling phase, and `batch_size_sampler`
    determines the batch_size of the internal DataLoader used for pre-sampling. Note that if we want to pass
    the SAINTSampler to a torch DataLoader to sample subgraphs concurrently in the training phase, we need to
    specify the `batch_size` of that DataLoader; that is, `batch_size_sampler` is unrelated to how the sampler
    works in the training procedure.
    """

    def __init__(self, node_budget, dn, g, train_nid, num_workers_sampler, num_subg_sampler=10000,
                 batch_size_sampler=200, online=True, num_subg=50, full=True):
        self.g = g.cpu()
        self.node_budget = node_budget
        self.train_g: dgl.graph = g.subgraph(train_nid)
        self.dn, self.num_subg = dn, num_subg
        self.node_counter = th.zeros((self.train_g.num_nodes(),))
        self.edge_counter = th.zeros((self.train_g.num_edges(),))
        self.prob = None
        self.num_subg_sampler = num_subg_sampler
        self.batch_size_sampler = batch_size_sampler
        self.num_workers_sampler = num_workers_sampler
        self.train = False
        self.online = online
        self.full = full

        assert self.num_subg_sampler >= self.batch_size_sampler, "num_subg_sampler should be no less than batch_size_sampler"
        graph_fn, norm_fn = self.__generate_fn__()

        if os.path.exists(graph_fn):
...@@ -37,31 +94,86 @@ class SAINTSampler(object):
            self.subgraphs = []
            self.N, sampled_nodes = 0, 0
            # N: the number of pre-sampled subgraphs

            # Employ parallelism to speed up the sampling procedure
            loader = DataLoader(self, batch_size=self.batch_size_sampler, shuffle=True,
                                num_workers=self.num_workers_sampler, collate_fn=self.__collate_fn__, drop_last=False)

            t = time.perf_counter()
            for num_nodes, subgraphs_nids, subgraphs_eids in loader:
                self.subgraphs.extend(subgraphs_nids)
                sampled_nodes += num_nodes

                _subgraphs, _node_counts = np.unique(np.concatenate(subgraphs_nids), return_counts=True)
                sampled_nodes_idx = th.from_numpy(_subgraphs)
                _node_counts = th.from_numpy(_node_counts)
                self.node_counter[sampled_nodes_idx] += _node_counts

                _subgraphs_eids, _edge_counts = np.unique(np.concatenate(subgraphs_eids), return_counts=True)
                sampled_edges_idx = th.from_numpy(_subgraphs_eids)
                _edge_counts = th.from_numpy(_edge_counts)
                self.edge_counter[sampled_edges_idx] += _edge_counts

                self.N += len(subgraphs_nids)  # number of subgraphs
                if sampled_nodes > self.train_g.num_nodes() * num_subg:
                    break

            print(f'Sampling time: [{time.perf_counter() - t:.2f}s]')
            np.save(graph_fn, self.subgraphs)

            t = time.perf_counter()
            aggr_norm, loss_norm = self.__compute_norm__()
            print(f'Normalization time: [{time.perf_counter() - t:.2f}s]')
            np.save(norm_fn, (aggr_norm, loss_norm))

        self.train_g.ndata['l_n'] = th.Tensor(loss_norm)
        self.train_g.edata['w'] = th.Tensor(aggr_norm)
        self.__compute_degree_norm()  # basically normalizing the adjacency matrix

        random.shuffle(self.subgraphs)
        self.__clear__()
        print("The number of subgraphs is: ", len(self.subgraphs))
        print("The size of subgraphs is about: ", len(self.subgraphs[-1]))

        self.train = True

    def __len__(self):
        if self.train is False:
            return self.num_subg_sampler
        else:
            if self.full:
                return len(self.subgraphs)
            else:
                return math.ceil(self.train_g.num_nodes() / self.node_budget)

    def __getitem__(self, idx):
        # In the training procedure: with offline sampling we fetch a pre-sampled subgraph by index,
        # while with online sampling we sample a new subgraph on the fly.
        if self.train:
            if self.online:
                subgraph = self.__sample__()
                return dgl.node_subgraph(self.train_g, subgraph)
            else:
                return dgl.node_subgraph(self.train_g, self.subgraphs[idx])
        else:
            subgraph_nids = self.__sample__()
            num_nodes = len(subgraph_nids)
            subgraph_eids = dgl.node_subgraph(self.train_g, subgraph_nids).edata[dgl.EID]
            return num_nodes, subgraph_nids, subgraph_eids

    def __collate_fn__(self, batch):
        if self.train:  # only one subgraph per batch; batch_size in the training phase is 1
            return batch[0]
        else:
            sum_num_nodes = 0
            subgraphs_nids_list = []
            subgraphs_eids_list = []
            for num_nodes, subgraph_nids, subgraph_eids in batch:
                sum_num_nodes += num_nodes
                subgraphs_nids_list.append(subgraph_nids)
                subgraphs_eids_list.append(subgraph_eids)
            return sum_num_nodes, subgraphs_nids_list, subgraphs_eids_list

    def __clear__(self):
        self.prob = None
...@@ -69,20 +181,11 @@ class SAINTSampler(object):
        self.edge_counter = None
        self.g = None

    def __generate_fn__(self):
        raise NotImplementedError

    def __compute_norm__(self):
        self.node_counter[self.node_counter == 0] = 1
        self.edge_counter[self.edge_counter == 0] = 1
...@@ -106,33 +209,28 @@ class SAINTSampler(object):
    def __sample__(self):
        raise NotImplementedError
class SAINTNodeSampler(SAINTSampler):
    """
    Description
    -----------
    GraphSAINT with node sampler.

    Parameters
    ----------
    node_budget : int
        the expected number of nodes in each subgraph, which is specifically explained in the paper.
    """
    def __init__(self, node_budget, **kwargs):
        self.node_budget = node_budget
        super(SAINTNodeSampler, self).__init__(node_budget=node_budget, **kwargs)

    def __generate_fn__(self):
        graph_fn = os.path.join('./subgraphs/{}_Node_{}_{}.npy'.format(self.dn, self.node_budget,
                                                                       self.num_subg))
        norm_fn = os.path.join('./subgraphs/{}_Node_{}_{}_norm.npy'.format(self.dn, self.node_budget,
                                                                           self.num_subg))
        return graph_fn, norm_fn

    def __sample__(self):
...@@ -144,48 +242,83 @@ class SAINTNodeSampler(SAINTSampler):
class SAINTEdgeSampler(SAINTSampler):
    """
    Description
    -----------
    GraphSAINT with edge sampler.

    Parameters
    ----------
    edge_budget : int
        the expected number of edges in each subgraph, which is specifically explained in the paper.
    """
    def __init__(self, edge_budget, **kwargs):
        self.edge_budget = edge_budget
        self.rng = np.random.default_rng()

        super(SAINTEdgeSampler, self).__init__(node_budget=edge_budget * 2, **kwargs)

    def __generate_fn__(self):
        graph_fn = os.path.join('./subgraphs/{}_Edge_{}_{}.npy'.format(self.dn, self.edge_budget,
                                                                       self.num_subg))
        norm_fn = os.path.join('./subgraphs/{}_Edge_{}_{}_norm.npy'.format(self.dn, self.edge_budget,
                                                                           self.num_subg))
        return graph_fn, norm_fn

    # TODO: only sample half of the edges, then add the other half
    # TODO: use numpy to implement the Cython sampling method
    def __sample__(self):
        if self.prob is None:
            src, dst = self.train_g.edges()
            src_degrees, dst_degrees = self.train_g.in_degrees(src).float().clamp(min=1), \
                                       self.train_g.in_degrees(dst).float().clamp(min=1)
            prob_mat = 1. / src_degrees + 1. / dst_degrees
            prob_mat = scipy.sparse.csr_matrix((prob_mat.numpy(), (src.numpy(), dst.numpy())))
            # The edge probabilities here only cover the edges in the upper triangle of the adjacency matrix,
            # because we assume the graph is undirected, i.e. the adjacency matrix is symmetric, so we only
            # need to consider half of the edges in the graph.
            self.prob = th.tensor(scipy.sparse.triu(prob_mat).data)
            self.prob /= self.prob.sum()
            self.adj_nodes = np.stack(prob_mat.nonzero(), axis=1)

        sampled_edges = np.unique(
            dgl.random.choice(len(self.prob), size=self.edge_budget, prob=self.prob, replace=False)
        )
        sampled_nodes = np.unique(self.adj_nodes[sampled_edges].flatten()).astype('long')
        return sampled_nodes
class SAINTRandomWalkSampler(SAINTSampler):
    """
    Description
    -----------
    GraphSAINT with random walk sampler.

    Parameters
    ----------
    num_roots : int
        the number of roots to generate random walks.
    length : int
        the length of each random walk.
    """
    def __init__(self, num_roots, length, **kwargs):
        self.num_roots, self.length = num_roots, length
        super(SAINTRandomWalkSampler, self).__init__(node_budget=num_roots * length, **kwargs)

    def __generate_fn__(self):
        graph_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}.npy'.format(self.dn, self.num_roots,
                                                                        self.length, self.num_subg))
        norm_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}_norm.npy'.format(self.dn, self.num_roots,
                                                                            self.length, self.num_subg))
        return graph_fn, norm_fn

    def __sample__(self):
        sampled_roots = th.randint(0, self.train_g.num_nodes(), (self.num_roots,))
        traces, types = random_walk(self.train_g, nodes=sampled_roots, length=self.length)
        sampled_nodes, _, _, _ = pack_traces(traces, types)
        sampled_nodes = sampled_nodes.unique()
        return sampled_nodes.numpy()
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sampler import SAINTNodeSampler, SAINTEdgeSampler, SAINTRandomWalkSampler
from config import CONFIG
from modules import GCNNet
from utils import Logger, evaluate, save_log_dir, load_data, calc_f1
import warnings


def main(args, task):
    warnings.filterwarnings('ignore')

    multilabel_data = {'ppi', 'yelp', 'amazon'}
    multilabel = args.dataset in multilabel_data

    # cpu_flag is set for datasets that are too large (e.g. amazon), whose graph cannot be moved to a single GPU.
    # In that case we
    # 1. keep the whole graph on CPU and only move the subgraphs to GPU in the training phase
    # 2. keep the model on GPU in the training phase and move it to CPU in the validation/testing phase
    # We need to check cpu_flag and cuda (below) together when moving the model between CPU and GPU
    if args.dataset in ['amazon']:
        cpu_flag = True
    else:
        cpu_flag = False

    # load and preprocess dataset
    data = load_data(args, multilabel)
    g = data.g
...@@ -45,16 +56,23 @@ def main(args):
                                           n_val_samples,
                                           n_test_samples))

    # load sampler
    kwargs = {
        'dn': args.dataset, 'g': g, 'train_nid': train_nid, 'num_workers_sampler': args.num_workers_sampler,
        'num_subg_sampler': args.num_subg_sampler, 'batch_size_sampler': args.batch_size_sampler,
        'online': args.online, 'num_subg': args.num_subg, 'full': args.full
    }
    if args.sampler == "node":
        saint_sampler = SAINTNodeSampler(args.node_budget, **kwargs)
    elif args.sampler == "edge":
        saint_sampler = SAINTEdgeSampler(args.edge_budget, **kwargs)
    elif args.sampler == "rw":
        saint_sampler = SAINTRandomWalkSampler(args.num_roots, args.length, **kwargs)
    else:
        raise NotImplementedError
    loader = DataLoader(saint_sampler, collate_fn=saint_sampler.__collate_fn__, batch_size=1,
                        shuffle=True, num_workers=args.num_workers, drop_last=False)

    # set device for dataset tensors
    if args.gpu < 0:
        cuda = False
...@@ -63,7 +81,8 @@ def main(args):
        torch.cuda.set_device(args.gpu)
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        if not cpu_flag:
            g = g.to('cuda:{}'.format(args.gpu))

    print('labels shape:', g.ndata['label'].shape)
    print("features shape:", g.ndata['feat'].shape)
...@@ -99,8 +118,7 @@ def main(args):
    best_f1 = -1

    for epoch in range(args.n_epochs):
        for j, subg in enumerate(loader):
            if cuda:
                subg = subg.to(torch.cuda.current_device())
            model.train()
...@@ -119,12 +137,20 @@ def main(args):
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), 5)
            optimizer.step()

            if j == len(loader) - 1:
                model.eval()
                with torch.no_grad():
                    train_f1_mic, train_f1_mac = calc_f1(batch_labels.cpu().numpy(),
                                                         pred.cpu().numpy(), multilabel)
                    print(f"epoch:{epoch + 1}/{args.n_epochs}, Iteration {j + 1}/"
                          f"{len(loader)}:training loss", loss.item())
                    print("Train F1-mic {:.4f}, Train F1-mac {:.4f}".format(train_f1_mic, train_f1_mac))

        # evaluate
        if epoch % args.val_every == 0:
            if cpu_flag and cuda:  # move the model back to CPU only if it was moved to GPU
                model = model.to('cpu')
            val_f1_mic, val_f1_mac = evaluate(
                model, g, labels, val_mask, multilabel)
            print(
...@@ -133,7 +159,9 @@ def main(args):
                best_f1 = val_f1_mic
                print('new best val f1:', best_f1)
                torch.save(model.state_dict(), os.path.join(
                    log_dir, 'best_model_{}.pkl'.format(task)))
            if cpu_flag and cuda:
                model.cuda()

    end_time = time.time()
    print(f'training using time {end_time - start_time}')
...@@ -141,63 +169,24 @@ def main(args):
    # test
    if args.use_val:
        model.load_state_dict(torch.load(os.path.join(
            log_dir, 'best_model_{}.pkl'.format(task))))
        if cpu_flag and cuda:
            model = model.to('cpu')
        test_f1_mic, test_f1_mac = evaluate(
            model, g, labels, test_mask, multilabel)
        print("Test F1-mic {:.4f}, Test F1-mac {:.4f}".format(test_f1_mic, test_f1_mac))
if __name__ == '__main__':
    warnings.filterwarnings('ignore')

    parser = argparse.ArgumentParser(description='GraphSAINT')
    parser.add_argument("--task", type=str, default="ppi_n", help="type of task")
    parser.add_argument("--online", dest='online', action='store_true', help="use online sampling in the training phase")
    parser.add_argument("--gpu", type=int, default=0, help="the gpu index")
    task = parser.parse_args().task
    args = argparse.Namespace(**CONFIG[task])
    args.online = parser.parse_args().online
    args.gpu = parser.parse_args().gpu
    print(args)
    main(args, task=task)
...@@ -57,6 +57,10 @@ def evaluate(model, g, labels, mask, multilabel=False):

# load the GraphSAINT data and convert it to the DGL format
def load_data(args, multilabel):
    if not os.path.exists('graphsaintdata') and not os.path.exists('data'):
        raise ValueError("Neither the 'graphsaintdata' nor the 'data' directory exists!")
    elif os.path.exists('graphsaintdata') and not os.path.exists('data'):
        os.rename('graphsaintdata', 'data')
    prefix = "data/{}".format(args.dataset)
    DataType = namedtuple('Dataset', ['num_classes', 'train_nid', 'g'])
......