"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "2d80e99b7a7cb69d2abbdc862901d8f0a1ed4bed"
Unverified Commit 33175226 authored by Hengrui Zhang, committed by GitHub

[Example] Add implementation of InfoGraph (#2644)



* Add implementation of unsupervised model

* [Doc] Update Implementor's information

* [doc] add index of infograph

* [Feature] QM9_v2 Dataset Support

* fix a typo

* move qm9dataset from data to examples

* Update qm9_v2.py

* Infograph -> InfoGraph

* add implementation and results of semi-supervised model

* Update README.md

* Update README.md

* fix a typo

* Remove the duplicated links.

* fix some typos

* fix typos

* update model.py

* update collate fn

* remove unused functions

* Update model.py

* add device option

* Update evaluate_embedding.py

* Update evaluate_embedding.py

* Update unsupervised.py

* Fix typos

* fix bugs

* Update README.md
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent a1f59c3b
@@ -13,6 +13,7 @@ The folder contains example implementations of selected research papers related
| [Graph Convolutional Networks for Graphs with Multi-Dimensionally Weighted Edges](#mwe) | :heavy_check_mark: | | | | :heavy_check_mark: |
| [SIGN: Scalable Inception Graph Neural Networks](#sign) | :heavy_check_mark: | | | | :heavy_check_mark: |
| [Strategies for Pre-training Graph Neural Networks](#prestrategy) | | | :heavy_check_mark: | | |
| [InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization](#infograph) | | | :heavy_check_mark: | | |
| [Graph Neural Networks with convolutional ARMA filters](#arma) | :heavy_check_mark: | | | | |
| [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](#appnp) | :heavy_check_mark: | | | | |
| [Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks](#clustergcn) | :heavy_check_mark: | | | :heavy_check_mark: | :heavy_check_mark: |
@@ -87,125 +88,100 @@ The folder contains example implementations of selected research papers related
- <a name="grand"></a> Feng et al. Graph Random Neural Network for Semi-Supervised Learning on Graphs. [Paper link](https://arxiv.org/abs/2005.11079).
- Example code: [PyTorch](../examples/pytorch/grand)
- Tags: semi-supervised node classification, simplifying graph convolution, data augmentation
- <a name="hgt"></a> Hu et al. Heterogeneous Graph Transformer. [Paper link](https://arxiv.org/abs/2003.01332).
- Example code: [PyTorch](../examples/pytorch/hgt)
- Tags: dynamic heterogeneous graphs, large-scale, node classification, link prediction
- <a name="mwe"></a> Chen. Graph Convolutional Networks for Graphs with Multi-Dimensionally Weighted Edges. [Paper link](https://cims.nyu.edu/~chenzh/files/GCN_with_edge_weights.pdf).
- Example code: [PyTorch on ogbn-proteins](../examples/pytorch/ogb/ogbn-proteins)
- Tags: node classification, weighted graphs, OGB
- <a name="sign"></a> Frasca et al. SIGN: Scalable Inception Graph Neural Networks. [Paper link](https://arxiv.org/abs/2004.11198).
- Example code: [PyTorch on ogbn-arxiv/products/mag](../examples/pytorch/ogb/sign), [PyTorch](../examples/pytorch/sign)
- Tags: node classification, OGB, large-scale, heterogeneous graphs
- <a name="prestrategy"></a> Hu et al. Strategies for Pre-training Graph Neural Networks. [Paper link](https://arxiv.org/abs/1905.12265).
- Example code: [Molecule embedding](https://github.com/awslabs/dgl-lifesci/tree/master/examples/molecule_embeddings), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)
- Tags: molecules, graph classification, unsupervised learning, self-supervised learning, molecular property prediction
- <a name="gnnfilm"></a> Marc Brockschmidt. GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation. [Paper link](https://arxiv.org/abs/1906.12192).
- Example code: [Pytorch](../examples/pytorch/GNN-FiLM)
- Tags: multi-relational graphs, hypernetworks, GNN architectures
- <a name="gxn"></a> Li, Maosen, et al. Graph Cross Networks with Vertex Infomax Pooling. [Paper link](https://arxiv.org/abs/2010.01804).
- Example code: [Pytorch](../examples/pytorch/gxn)
- Tags: pooling, graph classification
- <a name="dagnn"></a> Liu et al. Towards Deeper Graph Neural Networks. [Paper link](https://arxiv.org/abs/2007.09296).
- Example code: [Pytorch](../examples/pytorch/dagnn)
- Tags: over-smoothing, node classification
## 2019
- <a name="infograph"></a> Sun et al. InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization. [Paper link](https://arxiv.org/abs/1908.01000).
- Example code: [PyTorch](../examples/pytorch/infograph)
- Tags: semi-supervised graph regression, unsupervised graph classification
- <a name="arma"></a> Bianchi et al. Graph Neural Networks with Convolutional ARMA Filters. [Paper link](https://arxiv.org/abs/1901.01343).
- Example code: [PyTorch](../examples/pytorch/arma)
- Tags: node classification
- <a name="appnp"></a> Klicpera et al. Predict then Propagate: Graph Neural Networks meet Personalized PageRank. [Paper link](https://arxiv.org/abs/1810.05997).
- Example code: [PyTorch](../examples/pytorch/appnp), [MXNet](../examples/mxnet/appnp)
- Tags: node classification
- <a name="clustergcn"></a> Chiang et al. Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1905.07953).
- Example code: [PyTorch](../examples/pytorch/cluster_gcn), [PyTorch-based GraphSAGE variant on OGB](../examples/pytorch/ogb/cluster-sage), [PyTorch-based GAT variant on OGB](../examples/pytorch/ogb/cluster-gat)
- Tags: graph partition, node classification, large-scale, OGB, sampling
- <a name="dgi"></a> Veličković et al. Deep Graph Infomax. [Paper link](https://arxiv.org/abs/1809.10341).
- Example code: [PyTorch](../examples/pytorch/dgi), [TensorFlow](../examples/tensorflow/dgi)
- Tags: unsupervised learning, node classification
- <a name="diffpool"></a> Ying et al. Hierarchical Graph Representation Learning with Differentiable Pooling. [Paper link](https://arxiv.org/abs/1806.08804).
- Example code: [PyTorch](../examples/pytorch/diffpool)
- Tags: pooling, graph classification, graph coarsening
- <a name="gatne-t"></a> Cen et al. Representation Learning for Attributed Multiplex Heterogeneous Network. [Paper link](https://arxiv.org/abs/1905.01669v2).
- Example code: [PyTorch](../examples/pytorch/GATNE-T)
- Tags: heterogeneous graphs, link prediction, large-scale
- <a name="gin"></a> Xu et al. How Powerful are Graph Neural Networks? [Paper link](https://arxiv.org/abs/1810.00826).
- Example code: [PyTorch on graph classification](../examples/pytorch/gin), [PyTorch on node classification](../examples/pytorch/model_zoo/citation_network), [PyTorch on ogbg-ppa](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/ogbg_ppa), [MXNet](../examples/mxnet/gin)
- Tags: graph classification, node classification, OGB
- <a name="graphwriter"></a> Koncel-Kedziorski et al. Text Generation from Knowledge Graphs with Graph Transformers. [Paper link](https://arxiv.org/abs/1904.02342).
- Example code: [PyTorch](../examples/pytorch/graphwriter)
- Tags: knowledge graph, text generation
- <a name="han"></a> Wang et al. Heterogeneous Graph Attention Network. [Paper link](https://arxiv.org/abs/1903.07293).
- Example code: [PyTorch](../examples/pytorch/han)
- Tags: heterogeneous graphs, node classification
- <a name="lgnn"></a> Chen et al. Supervised Community Detection with Line Graph Neural Networks. [Paper link](https://arxiv.org/abs/1705.08415).
- Example code: [PyTorch](../examples/pytorch/line_graph)
- Tags: line graph, community detection
- <a name="sgc"></a> Wu et al. Simplifying Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1902.07153).
- Example code: [PyTorch](../examples/pytorch/sgc), [MXNet](../examples/mxnet/sgc)
- Tags: node classification
- <a name="dgcnnpoint"></a> Wang et al. Dynamic Graph CNN for Learning on Point Clouds. [Paper link](https://arxiv.org/abs/1801.07829).
- Example code: [PyTorch](../examples/pytorch/pointcloud/edgeconv)
- Tags: point cloud classification
- <a name="scenegraph"></a> Zhang et al. Graphical Contrastive Losses for Scene Graph Parsing. [Paper link](https://arxiv.org/abs/1903.02728).
- Example code: [MXNet](../examples/mxnet/scenegraph)
- Tags: scene graph extraction
- <a name="settrans"></a> Lee et al. Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks. [Paper link](https://arxiv.org/abs/1810.00825).
- Pooling module: [PyTorch encoder](https://docs.dgl.ai/api/python/nn.pytorch.html#settransformerencoder), [PyTorch decoder](https://docs.dgl.ai/api/python/nn.pytorch.html#settransformerdecoder)
- Tags: graph classification
- <a name="wln"></a> Coley et al. A graph-convolutional neural network model for the prediction of chemical reactivity. [Paper link](https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract).
- Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/reaction_prediction/rexgen_direct)
- Tags: molecules, reaction prediction
- <a name="mgcn"></a> Lu et al. Molecular Property Prediction: A Multilevel Quantum Interactions Modeling Perspective. [Paper link](https://arxiv.org/abs/1906.11081).
- Example code: [PyTorch](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/alchemy)
- Tags: molecules, quantum chemistry
- <a name="attentivefp"></a> Xiong et al. Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism. [Paper link](https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959).
- Example code: [PyTorch (with attention visualization)](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/pubchem_aromaticity), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)
- Tags: molecules, molecular property prediction
- <a name="rotate"></a> Sun et al. RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space. [Paper link](https://arxiv.org/pdf/1902.10197.pdf).
- Example code: [PyTorch](https://github.com/awslabs/dgl-ke/tree/master/examples), [PyTorch for custom data](https://aws-dglke.readthedocs.io/en/latest/commands.html)
- Tags: knowledge graph embedding
- <a name="mixhop"></a> Abu-El-Haija et al. MixHop: Higher-Order Graph Convolutional Architectures via Sparsified Neighborhood Mixing. [Paper link](https://arxiv.org/abs/1905.00067).
- Example code: [PyTorch](../examples/pytorch/mixhop)
- Tags: node classification
- <a name="sagpool"></a> Lee, Junhyun, et al. Self-Attention Graph Pooling. [Paper link](https://arxiv.org/abs/1904.08082).
- Example code: [PyTorch](../examples/pytorch/sagpool)
- Tags: graph classification, pooling
- <a name="hgp-sl"></a> Zhang, Zhen, et al. Hierarchical Graph Pooling with Structure Learning. [Paper link](https://arxiv.org/abs/1911.05954).
- Example code: [PyTorch](../examples/pytorch/hgp_sl)
- Tags: graph classification, pooling
- <a name='hardgat'></a> Gao, Hongyang, et al. Graph Representation Learning via Hard and Channel-Wise Attention Networks. [Paper link](https://arxiv.org/abs/1907.04652).
- Example code: [Pytorch](../examples/pytorch/hardgat)
- Tags: node classification, graph attention
- <a name='ngcf'></a> Wang, Xiang, et al. Neural Graph Collaborative Filtering. [Paper link](https://arxiv.org/abs/1905.08108).
- Example code: [Pytorch](../examples/pytorch/NGCF)
- Tags: Collaborative Filtering, Recommendation, Graph Neural Network
@@ -2,16 +2,15 @@
This DGL example implements the GNN model proposed in the paper [Graph Random Neural Network for Semi-Supervised Learning on Graphs](https://arxiv.org/abs/2005.11079).
Author's code: https://github.com/THUDM/GRAND
## Example Implementor
This example was implemented by [Hengrui Zhang](https://github.com/hengruizhang98) when he was an applied scientist intern at AWS Shanghai AI Lab.
## Dependencies
- Python 3.7
- PyTorch 1.7.1
- dgl 0.5.3
## Dataset
# DGL Implementation of InfoGraph
This DGL example implements the model proposed in the paper [InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization](https://arxiv.org/abs/1908.01000).
Author's code: https://github.com/fanyun-sun/InfoGraph
## Example Implementor
This example was implemented by [Hengrui Zhang](https://github.com/hengruizhang98) when he was an applied scientist intern at AWS Shanghai AI Lab.
## Dependencies
- Python 3.7
- PyTorch 1.7.1
- dgl 0.6.0
## Datasets
##### Unsupervised Graph Classification Dataset:
'MUTAG', 'PTC', 'IMDBBINARY' (IMDB-B), 'IMDBMULTI' (IMDB-M), 'REDDITBINARY' (RDT-B), and 'REDDITMULTI5K' (RDT-M5K) from dgl.data.GINDataset; a minimal loading sketch follows the table below.
| Dataset | MUTAG | PTC | RDT-B | RDT-M5K | IMDB-B | IMDB-M |
| --------------- | ----- | ----- | ------ | ------- | ------ | ------ |
| # Graphs | 188 | 344 | 2000 | 4999 | 1000 | 1500 |
| # Classes | 2 | 2 | 2 | 5 | 2 | 3 |
| Avg. Graph Size | 17.93 | 14.29 | 429.63 | 508.52 | 19.77 | 13.00 |
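These datasets are loaded through `dgl.data.GINDataset`, exactly as `unsupervised.py` below does. A minimal loading sketch (not part of the example scripts; MUTAG is downloaded automatically on first use):

```python
from dgl.data import GINDataset

# Load one of the unsupervised graph classification datasets.
dataset = GINDataset('MUTAG', self_loop=False)
print(len(dataset))            # 188 graphs
graph, label = dataset[0]      # each item is a (DGLGraph, label) pair
print(graph.num_nodes(), graph.ndata['attr'].shape, label)
```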
##### Semi-supervised Graph Regression Dataset:
QM9 dataset for graph property prediction (regression); a minimal loading sketch follows the task table below.
| Dataset | # Graphs | # Regression Tasks |
| ------- | -------- | ------------------ |
| QM9 | 130,831 | 12 |
The 12 tasks are:
| Keys | Description |
| ----- | :----------------------------------------- |
| mu | Dipole moment |
| alpha | Isotropic polarizability |
| homo | Highest occupied molecular orbital energy |
| lumo | Lowest unoccupied molecular orbital energy |
| gap | Gap between 'homo' and 'lumo' |
| r2 | Electronic spatial extent |
| zpve | Zero point vibrational energy |
| U0 | Internal energy at 0K |
| U | Internal energy at 298.15K |
| H | Enthalpy at 298.15K |
| G | Free energy at 298.15K |
| Cv | Heat capacity at 298.15K |
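A minimal sketch (not part of the example scripts) of loading QM9 with a single regression target through the `QM9DatasetV2` class shipped with this example (`qm9_v2.py`); the first call downloads the preprocessed graphs, and `to_dense()` appends the pairwise atomic distance as an extra edge feature:

```python
from qm9_v2 import QM9DatasetV2

data = QM9DatasetV2(label_keys=['mu'])   # keep only the dipole moment target
data.to_dense()                          # densify graphs; this pass takes a while
graph, label = data[0]
print(graph.ndata['attr'].shape, graph.edata['edge_attr'].shape, label)
```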
## Arguments
##### Unsupervised Graph Classification:
###### Dataset options
```
--dataname str The graph dataset name. Default is 'MUTAG'.
```
###### GPU options
```
--gpu int GPU index. Default is -1, using CPU.
```
###### Training options
```
--epochs int Number of training epochs. Default is 20.
--batch_size int Size of a training batch. Default is 128.
--lr float Adam optimizer learning rate. Default is 0.01.
--log_interval int Interval between two evaluations. Default is 1.
```
###### Model options
```
--n_layers int Number of GIN layers. Default is 3.
--hid_dim int Dimension of hidden layers. Default is 32.
```
##### Semi-supervised Graph Regression:
###### Dataset options
```
--target str The regression task. Default is 'mu'.
--train_num int Number of supervised examples. Default is 5000.
```
###### GPU options
```
--gpu int GPU index. Default is -1, using CPU.
```
###### Training options
```
--epochs int Number of training epochs. Default is 200.
--batch_size int Size of a training batch. Default is 20.
--val_batch_size int Size of a validation batch. Default is 100.
--lr float Adam optimizer learning rate. Default is 0.001.
```
###### Model options
```
--hid_dim int Dimension of hidden layers. Default is 64.
--reg float Regularization weight. Default is 0.001.
```
## How to run examples
Train and evaluate the unsupervised model on MUTAG.
(The graphs in these datasets are small and sparse, so moving them from CPU to GPU can take longer than the training itself; we therefore recommend using **CPU** for these datasets.)
```bash
# MUTAG:
python unsupervised.py --dataname MUTAG --n_layers 4 --hid_dim 32
```
Replace 'MUTAG' with any of 'PTC', 'IMDBBINARY', 'IMDBMULTI', 'REDDITBINARY', or 'REDDITMULTI5K' if you'd like to try other datasets.
Train and evaluate the semi-supervised model on QM9 for the graph property 'mu' on GPU.
```bash
# QM9:
python semisupervised.py --gpu 0 --target mu
```
Replace 'mu' with any of the other target names listed above.
## Performance
The hyperparameter setting in our implementation is identical to that reported in the paper.
##### Unsupervised Graph Classification:
| Dataset | MUTAG | PTC | RDT-B | RDT-M5K | IMDB-B | IMDB-M |
| :---------------: | :---: | :---: | :---: | :-----: | :----: | :----: |
| Accuracy Reported | 89.01 | 61.65 | 82.50 | 53.46 | 73.03 | 49.69 |
| DGL | 89.88 | 63.54 | 88.50 | 56.27 | 72.70 | 50.13 |
* The REDDITMULTI5K (RDT-M5K) dataset takes quite a long time to load and evaluate.
##### Semi-supervised Graph Regression on QM9:
Here we only report the results for 'mu', 'alpha', and 'homo'.
| Target | mu | alpha | homo |
| :---------------: | :----: | :----: | :----: |
| MAE Reported | 0.3169 | 0.5444 | 0.0060 |
| The authors' code | 0.2411 | 0.5192 | 0.1560 |
| DGL | 0.2355 | 0.5483 | 0.1581 |
* The source of the QM9 dataset has changed, so there is a gap between the MAE reported in the paper and the values we reproduced.
* See this [issue](https://github.com/fanyun-sun/InfoGraph/issues/8) for the authors' response.
''' Evaluate unsupervised embedding using a variety of basic classifiers. '''
''' Credit: https://github.com/fanyun-sun/InfoGraph '''
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.svm import SVC
import numpy as np
import torch
import torch.nn as nn
class LogReg(nn.Module):
def __init__(self, ft_in, nb_classes):
super(LogReg, self).__init__()
self.fc = nn.Linear(ft_in, nb_classes)
def weights_init(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
m.bias.data.fill_(0.0)
def forward(self, seq):
ret = self.fc(seq)
return ret
def logistic_classify(x, y, device = 'cpu'):
nb_classes = np.unique(y).shape[0]
xent = nn.CrossEntropyLoss()
hid_units = x.shape[1]
accs = []
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
for train_index, test_index in kf.split(x, y):
train_embs, test_embs = x[train_index], x[test_index]
train_lbls, test_lbls= y[train_index], y[test_index]
train_embs, train_lbls = torch.from_numpy(train_embs).to(device), torch.from_numpy(train_lbls).to(device)
test_embs, test_lbls = torch.from_numpy(test_embs).to(device), torch.from_numpy(test_lbls).to(device)
log = LogReg(hid_units, nb_classes)
log = log.to(device)
opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
for it in range(100):
log.train()
opt.zero_grad()
logits = log(train_embs)
loss = xent(logits, train_lbls)
loss.backward()
opt.step()
logits = log(test_embs)
preds = torch.argmax(logits, dim=1)
acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
accs.append(acc.item())
return np.mean(accs)
def svc_classify(x, y, search):
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
accuracies = []
for train_index, test_index in kf.split(x, y):
x_train, x_test = x[train_index], x[test_index]
y_train, y_test = y[train_index], y[test_index]
if search:
params = {'C':[0.001, 0.01, 0.1, 1, 10, 100, 1000]}
classifier = GridSearchCV(SVC(), params, cv=5, scoring='accuracy', verbose=0)
else:
classifier = SVC(C=10)
classifier.fit(x_train, y_train)
accuracies.append(accuracy_score(y_test, classifier.predict(x_test)))
return np.mean(accuracies)
def evaluate_embedding(embeddings, labels, search=True, device = 'cpu'):
labels = preprocessing.LabelEncoder().fit_transform(labels)
x, y = np.array(embeddings), np.array(labels)
logreg_accuracy = logistic_classify(x, y, device)
print('LogReg', logreg_accuracy)
svc_accuracy = svc_classify(x, y, search)
print('svc', svc_accuracy)
return logreg_accuracy, svc_accuracy
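A minimal sketch (not part of the file above) of calling `evaluate_embedding` on dummy embeddings, assuming the file is saved as `evaluate_embedding.py` as in this example; the labels are balanced so the 10-fold stratified split is valid:

```python
import numpy as np
from evaluate_embedding import evaluate_embedding

embs = np.random.randn(60, 16).astype(np.float32)   # 60 fake graph embeddings of size 16
labels = np.repeat([0, 1], 30)                       # two balanced classes
logreg_acc, svc_acc = evaluate_embedding(embs, labels, search=False)
```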
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, ModuleList, Linear, GRU, ReLU, BatchNorm1d
from dgl.nn import GINConv, NNConv, Set2Set
from dgl.nn.pytorch.glob import SumPooling
from utils import global_global_loss_, local_global_loss_
''' Feedforward neural network'''
class FeedforwardNetwork(nn.Module):
'''
3-layer feed-forward neural network with a jumping (skip) connection
Parameters
-----------
in_dim: int
Input feature size.
hid_dim: int
Hidden feature size.
Functions
-----------
forward(feat):
feat: Tensor
[N * D], input features
'''
def __init__(self, in_dim, hid_dim):
super(FeedforwardNetwork, self).__init__()
self.block = Sequential(Linear(in_dim, hid_dim),
ReLU(),
Linear(hid_dim, hid_dim),
ReLU(),
Linear(hid_dim, hid_dim),
ReLU()
)
self.jump_con = Linear(in_dim, hid_dim)
def forward(self, feat):
block_out = self.block(feat)
jump_out = self.jump_con(feat)
out = block_out + jump_out
return out
''' Unsupervised Setting '''
class GINEncoder(nn.Module):
'''
Encoder based on dgl.nn.GINConv & dgl.nn.SumPooling
Parameters
-----------
in_dim: int
Input feature size.
hid_dim: int
Hidden feature size.
n_layer: int
Number of GIN layers.
Functions
-----------
forward(graph, feat):
graph: DGLGraph
feat: Tensor
[N * D], node features
'''
def __init__(self, in_dim, hid_dim, n_layer):
super(GINEncoder, self).__init__()
self.n_layer = n_layer
self.convs = ModuleList()
self.bns = ModuleList()
for i in range(n_layer):
if i == 0:
n_in = in_dim
else:
n_in = hid_dim
n_out = hid_dim
block = Sequential(Linear(n_in, n_out),
ReLU(),
Linear(hid_dim, hid_dim)
)
conv = GINConv(apply_func = block, aggregator_type = 'sum')
bn = BatchNorm1d(hid_dim)
self.convs.append(conv)
self.bns.append(bn)
# sum pooling
self.pool = SumPooling()
def forward(self, graph, feat):
xs = []
x = feat
for i in range(self.n_layer):
x = F.relu(self.convs[i](graph, x))
x = self.bns[i](x)
xs.append(x)
local_emb = th.cat(xs, 1) # patch-level embedding
global_emb = self.pool(graph, local_emb) # graph-level embedding
return global_emb, local_emb
class InfoGraph(nn.Module):
r"""
InfoGraph model for unsupervised setting
Parameters
-----------
in_dim: int
Input feature size.
hid_dim: int
Hidden feature size.
n_layer: int
Number of the GNN encoder layers.
Functions
-----------
forward(graph):
graph: DGLGraph
"""
def __init__(self, in_dim, hid_dim, n_layer):
super(InfoGraph, self).__init__()
self.in_dim = in_dim
self.hid_dim = hid_dim
self.n_layer = n_layer
embedding_dim = hid_dim * n_layer
self.encoder = GINEncoder(in_dim, hid_dim, n_layer)
self.local_d = FeedforwardNetwork(embedding_dim, embedding_dim) # local discriminator (node-level)
self.global_d = FeedforwardNetwork(embedding_dim, embedding_dim) # global discriminator (graph-level)
def get_embedding(self, graph, feat):
# get_embedding function for evaluating the learned embeddings
with th.no_grad():
global_emb, _ = self.encoder(graph, feat)
return global_emb
def forward(self, graph, feat, graph_id):
global_emb, local_emb = self.encoder(graph, feat)
global_h = self.global_d(global_emb) # global hidden representation
local_h = self.local_d(local_emb) # local hidden representation
loss = local_global_loss_(local_h, global_h, graph_id)
return loss
''' Semisupervised Setting '''
class NNConvEncoder(nn.Module):
'''
Encoder based on dgl.nn.NNConv, a GRU, and dgl.nn.Set2Set pooling
Parameters
-----------
in_dim: int
Input feature size.
hid_dim: int
Hidden feature size.
Functions
-----------
forward(graph, nfeat, efeat):
graph: DGLGraph
nfeat: Tensor
[N * D1], node features
efeat: Tensor
[E * D2], edge features
'''
def __init__(self, in_dim, hid_dim):
super(NNConvEncoder, self).__init__()
self.lin0 = Linear(in_dim, hid_dim)
# MLP mapping the 5-dimensional edge features (bond attributes plus the pairwise distance added by QM9DatasetV2.to_dense) to NNConv edge weights
block = Sequential(Linear(5, 128), ReLU(), Linear(128, hid_dim * hid_dim))
self.conv = NNConv(hid_dim, hid_dim, edge_func = block, aggregator_type = 'mean', residual = False)
self.gru = GRU(hid_dim, hid_dim)
# set2set pooling
self.set2set = Set2Set(hid_dim, n_iters=3, n_layers=1)
def forward(self, graph, nfeat, efeat):
out = F.relu(self.lin0(nfeat))
h = out.unsqueeze(0)
feat_map = []
# 3 rounds of NNConv message passing, with a GRU carrying the hidden state between rounds
for i in range(3):
m = F.relu(self.conv(graph, out, efeat))
out, h = self.gru(m.unsqueeze(0), h)
out = out.squeeze(0)
feat_map.append(out)
out = self.set2set(graph, out)
# out: global embedding, feat_map[-1]: local embedding
return out, feat_map[-1]
class InfoGraphS(nn.Module):
'''
InfoGraph* model for semi-supervised setting
Parameters
-----------
in_dim: int
Input feature size.
hid_dim: int
Hidden feature size.
Functions
-----------
forward(graph):
graph: DGLGraph
unsupforward(graph):
graph: DGLGraph
'''
def __init__(self, in_dim, hid_dim):
super(InfoGraphS, self).__init__()
self.sup_encoder = NNConvEncoder(in_dim, hid_dim)
self.unsup_encoder = NNConvEncoder(in_dim, hid_dim)
self.fc1 = Linear(2 * hid_dim, hid_dim)
self.fc2 = Linear(hid_dim, 1)
# unsupervised local discriminator and global discriminator for local-global infomax
self.unsup_local_d = FeedforwardNetwork(hid_dim, hid_dim)
self.unsup_global_d = FeedforwardNetwork(2 * hid_dim, hid_dim)
# supervised global discriminator and unsupervised global discriminator for global-global infomax
self.sup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)
self.unsup_d = FeedforwardNetwork(2 * hid_dim, hid_dim)
def forward(self, graph, nfeat, efeat):
sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
sup_global_pred = self.fc2(F.relu(self.fc1(sup_global_emb)))
sup_global_pred = sup_global_pred.view(-1)
return sup_global_pred
def unsup_forward(self, graph, nfeat, efeat, graph_id):
sup_global_emb, sup_local_emb = self.sup_encoder(graph, nfeat, efeat)
unsup_global_emb, unsup_local_emb = self.unsup_encoder(graph, nfeat, efeat)
g_enc = self.unsup_global_d(unsup_global_emb)
l_enc = self.unsup_local_d(unsup_local_emb)
sup_g_enc = self.sup_d(sup_global_emb)
unsup_g_enc = self.unsup_d(unsup_global_emb)
# Calculate loss
unsup_loss = local_global_loss_(l_enc, g_enc, graph_id)
con_loss = global_global_loss_(sup_g_enc, unsup_g_enc)
return unsup_loss, con_loss
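A minimal sketch (not part of the file above) of running the unsupervised `InfoGraph` model on a toy batch; it assumes the file above is saved as `model.py` and that `utils.py` (shown further below) is on the path, as in this example:

```python
import dgl
import torch as th
from model import InfoGraph

g1 = dgl.graph(([0, 1, 2], [1, 2, 0]))
g2 = dgl.graph(([0, 1], [1, 0]))
bg = dgl.batch([g1, g2])
feat = th.randn(bg.num_nodes(), 7)                            # toy node features
graph_id = dgl.broadcast_nodes(bg, th.arange(bg.batch_size))  # node -> graph index, as in collate()
model = InfoGraph(in_dim=7, hid_dim=32, n_layer=3)
loss = model(bg, feat, graph_id)                              # local-global infomax loss
loss.backward()
print(model.get_embedding(bg, feat).shape)                    # (2, 32 * 3)
```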
import numpy as np
import os
from tqdm import tqdm
import torch as th
import dgl
from dgl.data.dgl_dataset import DGLDataset
from dgl.data.utils import download, load_graphs, _get_dgl_url, extract_archive
class QM9DatasetV2(DGLDataset):
r"""QM9 dataset for graph property prediction (regression)
This dataset consists of 130,831 molecules with 19 regression targets.
Node means atom and edge means bond.
Reference: `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_
Atom features come from `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_
Statistics:
- Number of graphs: 130,831
- Number of regression targets: 19
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Keys | Property | Description | Unit |
+========+==================================+===================================================================================+=============================================+
| mu | :math:`\mu` | Dipole moment | :math:`\textrm{D}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| alpha | :math:`\alpha` | Isotropic polarizability | :math:`{a_0}^3` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| homo | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| lumo | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| gap | :math:`\Delta \epsilon` | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| r2 | :math:`\langle R^2 \rangle` | Electronic spatial extent | :math:`{a_0}^2` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| zpve | :math:`\textrm{ZPVE}` | Zero point vibrational energy | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U0 | :math:`U_0` | Internal energy at 0K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U | :math:`U` | Internal energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| H | :math:`H` | Enthalpy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| G | :math:`G` | Free energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Cv | :math:`c_{\textrm{v}}` | Heat capacity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U0_atom| :math:`U_0^{\textrm{ATOM}}` | Atomization energy at 0K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| U_atom | :math:`U^{\textrm{ATOM}}` | Atomization energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| H_atom | :math:`H^{\textrm{ATOM}}` | Atomization enthalpy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| G_atom | :math:`G^{\textrm{ATOM}}` | Atomization free energy at 298.15K | :math:`\textrm{eV}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| A | :math:`A` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| B | :math:`B` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| C | :math:`C` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
Parameters
----------
label_keys: list
Names of the regression property, which should be a subset of the keys in the table above.
If not provided, will load all the labels.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
Attributes
----------
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
Raises
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9DatasetV2(label_keys=['mu', 'alpha'])
>>> data.num_labels
>>> # make each graph dense
>>> data.to_dense()
>>> # iterate over the dataset
>>> for graph, labels in data:
... print(graph) # get information of each graph
... print(labels) # get labels of the corresponding graph
... # your code here...
>>>
"""
def __init__(self,
label_keys = None,
raw_dir=None,
force_reload=False,
verbose=True):
self.label_keys = label_keys
self._url = _get_dgl_url('dataset/qm9_ver2.zip')
super(QM9DatasetV2, self).__init__(name='qm9_v2',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
print('begin loading dataset')
graphs, label_dict = load_graphs(os.path.join(self.raw_dir, 'qm9_v2.bin'))
self.graphs = graphs
if self.label_keys is None:
self.labels = np.stack([label_dict[key] for key in label_dict.keys()], axis=1)
else:
self.labels = np.stack([label_dict[key] for key in self.label_keys], axis=1)
def to_dense(self):
r""" Transform each graph into a dense graph and add an additional edge attribute (the distance between the two atoms)
Note: this operation drops graph.ndata['pos']
"""
n_graph = self.labels.shape[0]
for id in tqdm(range(n_graph), desc = 'processing graphs'):
graph = self.graphs[id]
n_nodes = graph.num_nodes()
row = th.arange(n_nodes, dtype = th.long)
col = th.arange(n_nodes, dtype = th.long)
row = row.view(-1,1).repeat(1, n_nodes).view(-1)
col = col.repeat(n_nodes)
src = graph.edges()[0]
dst = graph.edges()[1]
idx = src * n_nodes + dst
size = list(graph.edata['edge_attr'].size())
size[0] = n_nodes * n_nodes
edge_attr = graph.edata['edge_attr'].new_zeros(size)
edge_attr[idx] = graph.edata['edge_attr']
pos = graph.ndata['pos']
dist = th.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)
new_edge_attr = th.cat([edge_attr, dist.type_as(edge_attr)], dim = -1)
new_graph = dgl.graph((row,col))
new_graph.ndata['attr'] = graph.ndata['attr']
new_graph.edata['edge_attr'] = new_edge_attr
new_graph = new_graph.remove_self_loop()
self.graphs[id] = new_graph
def download(self):
file_path = f'{self.raw_dir}/qm9_v2.zip'
if not os.path.exists(file_path):
download(self._url, path=file_path)
extract_archive(file_path, self.raw_dir, overwrite = True)
@property
def num_labels(self):
r"""
Returns
--------
int
Number of labels for each graph, i.e. number of prediction tasks.
"""
return self.labels.shape[1]
def __getitem__(self, idx):
r""" Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
The graph contains:
- ``ndata['pos']``: the coordinates of each atom
- ``ndata['attr']``: the atomic attributes
- ``edata['edge_attr']``: the bond attributes
Tensor
Property values of molecular graphs
"""
return self.graphs[idx], self.labels[idx]
def __len__(self):
r"""Number of graphs in the dataset.
Returns
-------
int
"""
return self.labels.shape[0]
import numpy as np
import torch as th
import torch.nn.functional as F
import dgl
from dgl.dataloading import GraphDataLoader
from dgl.data.utils import Subset
from qm9_v2 import QM9DatasetV2
from model import InfoGraphS
import argparse
def argument():
parser = argparse.ArgumentParser(description='InfoGraphS')
# data source params
parser.add_argument('--target', type=str, default='mu', help='Choose regression task')
parser.add_argument('--train_num', type=int, default=5000, help='Size of training set')
# training params
parser.add_argument('--gpu', type=int, default=-1, help='GPU index, default:-1, using CPU.')
parser.add_argument('--epochs', type=int, default=200, help='Training epochs.')
parser.add_argument('--batch_size', type=int, default=20, help='Training batch size.')
parser.add_argument('--val_batch_size', type=int, default=100, help='Validation batch size.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
parser.add_argument('--wd', type=float, default=0, help='Weight decay.')
# model params
parser.add_argument('--hid_dim', type=int, default=64, help='Hidden layer dimensionality')
parser.add_argument('--reg', type=float, default=0.001, help='Regularization coefficient')
args = parser.parse_args()
# check cuda
if args.gpu != -1 and th.cuda.is_available():
args.device = 'cuda:{}'.format(args.gpu)
else:
args.device = 'cpu'
return args
def collate(samples):
''' collate function for building graph dataloader '''
# generate batched graphs and labels
graphs, targets = map(list, zip(*samples))
batched_graph = dgl.batch(graphs)
batched_targets = th.Tensor(targets)
n_graphs = len(graphs)
graph_id = th.arange(n_graphs)
graph_id = dgl.broadcast_nodes(batched_graph, graph_id)
batched_graph.ndata['graph_id'] = graph_id
return batched_graph, batched_targets
def evaluate(model, loader, num, device):
error = 0
for graphs, targets in loader:
graphs = graphs.to(device)
nfeat, efeat = graphs.ndata['attr'], graphs.edata['edge_attr']
targets = targets.to(device)
error += (model(graphs, nfeat, efeat) - targets).abs().sum().item()
error = error / num
return error
if __name__ == '__main__':
# Step 1: Prepare graph data ===================================== #
args = argument()
label_keys = [args.target]
print(args)
dataset = QM9DatasetV2(label_keys)
dataset.to_dense()
graphs = dataset.graphs
# Train/Val/Test Splitting
N = len(graphs)
all_idx = np.arange(N)
np.random.shuffle(all_idx)
val_num = 10000
test_num = 10000
val_idx = all_idx[:val_num]
test_idx = all_idx[val_num : val_num + test_num]
train_idx = all_idx[val_num + test_num : val_num + test_num + args.train_num]
train_data = Subset(dataset, train_idx)
val_data = Subset(dataset, val_idx)
test_data = Subset(dataset, test_idx)
unsup_idx = all_idx[val_num + test_num:]
unsup_data = Subset(dataset, unsup_idx)
# generate supervised training dataloader and unsupervised training dataloader
train_loader = GraphDataLoader(train_data,
batch_size=args.batch_size,
collate_fn=collate,
drop_last=False,
shuffle=True)
unsup_loader = GraphDataLoader(unsup_data,
batch_size=args.batch_size,
collate_fn=collate,
drop_last=False,
shuffle=True)
# generate validation & testing dataloader
val_loader = GraphDataLoader(val_data,
batch_size=args.val_batch_size,
collate_fn=collate,
drop_last=False,
shuffle=True)
test_loader = GraphDataLoader(test_data,
batch_size=args.val_batch_size,
collate_fn=collate,
drop_last=False,
shuffle=True)
print('======== target = {} ========'.format(args.target))
mean = dataset.labels.mean().item()
std = dataset.labels.std().item()
print('mean = {:4f}'.format(mean))
print('std = {:4f}'.format(std))
in_dim = dataset[0][0].ndata['attr'].shape[1]
# Step 2: Create model =================================================================== #
model = InfoGraphS(in_dim, args.hid_dim)
model = model.to(args.device)
# Step 3: Create training components ===================================================== #
optimizer = th.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.7, patience=5, min_lr=0.000001
)
# Step 4: training epochs =============================================================== #
best_val_error = float('inf')
test_error = float('inf')
for epoch in range(args.epochs):
''' Training '''
model.train()
lr = scheduler.optimizer.param_groups[0]['lr']
iteration = 0
sup_loss_all = 0
unsup_loss_all = 0
consis_loss_all = 0
for sup_data, unsup_data in zip(train_loader, unsup_loader):
sup_graph, sup_target = sup_data
unsup_graph, _ = unsup_data
sup_graph = sup_graph.to(args.device)
unsup_graph = unsup_graph.to(args.device)
# node features and edge features; edge attributes live in edata, graph ids (from collate) in ndata
sup_nfeat, sup_efeat = sup_graph.ndata['attr'], sup_graph.edata['edge_attr']
unsup_nfeat, unsup_efeat, unsup_graph_id = unsup_graph.ndata['attr'],\
unsup_graph.edata['edge_attr'], unsup_graph.ndata['graph_id']
sup_target = sup_target.to(args.device)
optimizer.zero_grad()
sup_loss = F.mse_loss(model(sup_graph, sup_nfeat, sup_efeat), sup_target)
unsup_loss, consis_loss = model.unsup_forward(unsup_graph, unsup_nfeat, unsup_efeat, unsup_graph_id)
loss = sup_loss + unsup_loss + args.reg * consis_loss
loss.backward()
sup_loss_all += sup_loss.item()
unsup_loss_all += unsup_loss.item()
consis_loss_all += consis_loss.item()
optimizer.step()
print('Epoch: {}, Sup_Loss: {:4f}, Unsup_loss: {:.4f}, Consis_loss: {:.4f}' \
.format(epoch, sup_loss_all, unsup_loss_all, consis_loss_all))
model.eval()
val_error = evaluate(model, val_loader, val_num, args.device)
scheduler.step(val_error)
if val_error < best_val_error:
best_val_error = val_error
test_error = evaluate(model, test_loader, test_num, args.device)
print('Epoch: {}, LR: {}, val_error: {:.4f}, best_test_error: {:.4f}' \
.format(epoch, lr, val_error, test_error))
import torch as th
import dgl
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
from model import InfoGraph
from evaluate_embedding import evaluate_embedding
import argparse
def argument():
parser = argparse.ArgumentParser(description='InfoGraph')
# data source params
parser.add_argument('--dataname', type=str, default='MUTAG', help='Name of dataset.')
# training params
parser.add_argument('--gpu', type=int, default=-1, help='GPU index, default:-1, using CPU.')
parser.add_argument('--epochs', type=int, default=20, help='Training epochs.')
parser.add_argument('--batch_size', type=int, default=128, help='Training batch size.')
parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')
parser.add_argument('--log_interval', type=int, default=1, help='Interval between two evaluations.')
# model params
parser.add_argument('--n_layers', type=int, default=3, help='Number of graph convolution layers before each pooling.')
parser.add_argument('--hid_dim', type=int, default=32, help='Hidden layer dimensionalities.')
args = parser.parse_args()
# check cuda
if args.gpu != -1 and th.cuda.is_available():
args.device = 'cuda:{}'.format(args.gpu)
else:
args.device = 'cpu'
return args
def collate(samples):
''' collate function for building graph dataloader'''
graphs, labels = map(list, zip(*samples))
# generate batched graphs and labels
batched_graph = dgl.batch(graphs)
batched_labels = th.tensor(labels)
n_graphs = len(graphs)
graph_id = th.arange(n_graphs)
graph_id = dgl.broadcast_nodes(batched_graph, graph_id)
batched_graph.ndata['graph_id'] = graph_id
return batched_graph, batched_labels
if __name__ == '__main__':
# Step 1: Prepare graph data ===================================== #
args = argument()
print(args)
# load dataset from dgl.data.GINDataset
dataset = GINDataset(args.dataname, self_loop = False)
# get graphs and labels
graphs, labels = map(list, zip(*dataset))
# generate a full-graph with all examples for evaluation
wholegraph = dgl.batch(graphs)
wholegraph.ndata['attr'] = wholegraph.ndata['attr'].to(th.float32)
# create dataloader for batch training
dataloader = GraphDataLoader(dataset,
batch_size=args.batch_size,
collate_fn=collate,
drop_last=False,
shuffle=True)
in_dim = wholegraph.ndata['attr'].shape[1]
# Step 2: Create model =================================================================== #
model = InfoGraph(in_dim, args.hid_dim, args.n_layers)
model = model.to(args.device)
# Step 3: Create training components ===================================================== #
optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
print('===== Before training ======')
wholegraph = wholegraph.to(args.device)
wholefeat = wholegraph.ndata['attr']
emb = model.get_embedding(wholegraph, wholefeat).cpu()
# Evaluate the initial embeddings with logistic regression and a (non-linear) SVM
res = evaluate_embedding(emb, labels, device=args.device)
print('logreg {:4f}, svc {:4f}'.format(res[0], res[1]))
best_logreg = 0
best_logreg_epoch = 0
best_svc = 0
best_svc_epoch = 0
# Step 4: training epochs =============================================================== #
for epoch in range(args.epochs):
loss_all = 0
model.train()
for graph, label in dataloader:
graph = graph.to(args.device)
feat = graph.ndata['attr']
graph_id = graph.ndata['graph_id']
n_graph = label.shape[0]
optimizer.zero_grad()
loss = model(graph, feat, graph_id)
loss.backward()
optimizer.step()
loss_all += loss.item()
print('Epoch {}, Loss {:.4f}'.format(epoch, loss_all))
if epoch % args.log_interval == 0:
# evaluate embeddings
model.eval()
emb = model.get_embedding(wholegraph, wholefeat).cpu()
res = evaluate_embedding(emb, labels, device=args.device)
if res[0] > best_logreg:
best_logreg = res[0]
best_logreg_epoch = epoch
if res[1] > best_svc:
best_svc = res[1]
best_svc_epoch = epoch
print('best logreg {:4f}, epoch {} | best svc: {:4f}, epoch {}'.format(best_logreg, best_logreg_epoch, best_svc, best_svc_epoch))
print('Training End')
print('best logreg {:4f} ,best svc {:4f}'.format(best_logreg, best_svc))
''' Credit: https://github.com/fanyun-sun/InfoGraph '''
import torch as th
import torch.nn.functional as F
import math
def get_positive_expectation(p_samples, average=True):
"""Computes the positive part of a JS Divergence.
Args:
p_samples: Positive samples.
average: Average the result over samples.
Returns:
th.Tensor
"""
log_2 = math.log(2.)
Ep = log_2 - F.softplus(- p_samples)
if average:
return Ep.mean()
else:
return Ep
def get_negative_expectation(q_samples, average=True):
"""Computes the negative part of a JS Divergence.
Args:
q_samples: Negative samples.
average: Average the result over samples.
Returns:
th.Tensor
"""
log_2 = math.log(2.)
Eq = F.softplus(-q_samples) + q_samples - log_2
if average:
return Eq.mean()
else:
return Eq
def local_global_loss_(l_enc, g_enc, graph_id):
num_graphs = g_enc.shape[0]
num_nodes = l_enc.shape[0]
device = g_enc.device
pos_mask = th.zeros((num_nodes, num_graphs)).to(device)
neg_mask = th.ones((num_nodes, num_graphs)).to(device)
for nodeidx, graphidx in enumerate(graph_id):
pos_mask[nodeidx][graphidx] = 1.
neg_mask[nodeidx][graphidx] = 0.
res = th.mm(l_enc, g_enc.t())
E_pos = get_positive_expectation(res * pos_mask, average=False).sum()
E_pos = E_pos / num_nodes
E_neg = get_negative_expectation(res * neg_mask, average=False).sum()
E_neg = E_neg / (num_nodes * (num_graphs - 1))
return E_neg - E_pos
def global_global_loss_(sup_enc, unsup_enc):
num_graphs = sup_enc.shape[0]
device = sup_enc.device
pos_mask = th.eye(num_graphs).to(device)
neg_mask = 1 - pos_mask
res = th.mm(sup_enc, unsup_enc.t())
E_pos = get_positive_expectation(res * pos_mask, average=False)
E_pos = (E_pos * pos_mask).sum() / pos_mask.sum()
E_neg = get_negative_expectation(res * neg_mask, average=False)
E_neg = (E_neg * neg_mask).sum() / neg_mask.sum()
return E_neg - E_pos
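A minimal numerical sketch (not part of the file above) exercising the two loss helpers, assuming the file is saved as `utils.py` as in this example:

```python
import torch as th
from utils import local_global_loss_, global_global_loss_

l_enc = th.randn(5, 8)                  # 5 node (local) encodings
g_enc = th.randn(2, 8)                  # 2 graph (global) encodings
graph_id = th.tensor([0, 0, 0, 1, 1])   # node -> graph assignment
print(local_global_loss_(l_enc, g_enc, graph_id))
print(global_global_loss_(g_enc, th.randn(2, 8)))
```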